diff --git a/internal/auditevent/audit_event.go b/internal/auditevent/audit_event.go index b03f82cca..3077f5263 100644 --- a/internal/auditevent/audit_event.go +++ b/internal/auditevent/audit_event.go @@ -3,12 +3,6 @@ package auditevent -import ( - "net/url" - - "k8s.io/apimachinery/pkg/util/sets" -) - type Message string const ( @@ -32,41 +26,3 @@ const ( TokenCredentialRequestUnsupportedUserInfo Message = "TokenCredentialRequest Unsupported UserInfo" //nolint:gosec // this is not a credential IncorrectUsernameOrPassword Message = "Incorrect Username Or Password" //nolint:gosec // this is not a credential ) - -// SanitizeParams can be used to redact all params not included in the allowedKeys set. -// Useful when audit logging HTTPRequestParameters events. -func SanitizeParams(inputParams url.Values, allowedKeys sets.Set[string]) []any { - params := make(map[string]string) - multiValueParams := make(url.Values) - - transform := func(key, value string) string { - if !allowedKeys.Has(key) { - return "redacted" - } - - unescape, err := url.QueryUnescape(value) - if err != nil { - // ignore these errors and just use the original query parameter - unescape = value - } - return unescape - } - - for key := range inputParams { - for i, p := range inputParams[key] { - transformed := transform(key, p) - if i == 0 { - params[key] = transformed - } - - if len(inputParams[key]) > 1 { - multiValueParams[key] = append(multiValueParams[key], transformed) - } - } - } - - if len(multiValueParams) > 0 { - return []any{"params", params, "multiValueParams", multiValueParams} - } - return []any{"params", params} -} diff --git a/internal/auditevent/audit_event_test.go b/internal/auditevent/audit_event_test.go deleted file mode 100644 index 51924710c..000000000 --- a/internal/auditevent/audit_event_test.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2024 the Pinniped contributors. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package auditevent - -import ( - "net/url" - "testing" - - "github.com/stretchr/testify/require" - "k8s.io/apimachinery/pkg/util/sets" -) - -func TestSanitizeParams(t *testing.T) { - tests := []struct { - name string - params url.Values - allowedKeys sets.Set[string] - want []any - }{ - { - name: "nil values", - params: nil, - allowedKeys: nil, - want: []any{ - "params", - map[string]string{}, - }, - }, - { - name: "empty values", - params: url.Values{}, - allowedKeys: nil, - want: []any{ - "params", - map[string]string{}, - }, - }, - { - name: "all allowed values", - params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, - allowedKeys: sets.New("foo", "bar"), - want: []any{ - "params", - map[string]string{ - "bar": "d", - "foo": "a", - }, - "multiValueParams", - url.Values{ - "bar": []string{"d", "e", "f"}, - "foo": []string{"a", "b", "c"}, - }, - }, - }, - { - name: "all allowed values with single values", - params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, - allowedKeys: sets.New("foo", "bar"), - want: []any{ - "params", - map[string]string{ - "foo": "a", - "bar": "d", - }, - }, - }, - { - name: "some allowed values", - params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, - allowedKeys: sets.New("foo"), - want: []any{ - "params", - map[string]string{ - "bar": "redacted", - "foo": "a", - }, - "multiValueParams", - url.Values{ - "bar": []string{"redacted", "redacted", "redacted"}, - "foo": []string{"a", "b", "c"}, - }, - }, - }, - { - name: "some allowed values with single values", - params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, - allowedKeys: sets.New("foo"), - want: []any{ - "params", - map[string]string{ - "bar": "redacted", - "foo": "a", - }, - }, - }, - { - name: "no allowed values", - params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, - allowedKeys: sets.New[string](), - want: []any{ - "params", - map[string]string{ - "bar": "redacted", - "foo": "redacted", - }, - "multiValueParams", - url.Values{ - "bar": {"redacted", "redacted", "redacted"}, - "foo": {"redacted", "redacted", "redacted"}, - }, - }, - }, - { - name: "nil allowed values", - params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, - allowedKeys: nil, - want: []any{ - "params", - map[string]string{ - "bar": "redacted", - "foo": "redacted", - }, - "multiValueParams", - url.Values{ - "bar": {"redacted", "redacted", "redacted"}, - "foo": {"redacted", "redacted", "redacted"}, - }, - }, - }, - { - name: "url decodes allowed values", - params: url.Values{ - "foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"}, - "bar": []string{"d", "e", "f"}, - }, - allowedKeys: sets.New("foo"), - want: []any{ - "params", - map[string]string{ - "bar": "redacted", - "foo": "a:b", - }, - "multiValueParams", - url.Values{ - "bar": {"redacted", "redacted", "redacted"}, - "foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"}, - }, - }, - }, - { - name: "ignores url decode errors", - params: url.Values{ - "bad_encoding": []string{"%.."}, - }, - allowedKeys: sets.New("bad_encoding"), - want: []any{ - "params", - map[string]string{ - "bad_encoding": "%..", - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // This comparison should require the exact order - require.Equal(t, test.want, SanitizeParams(test.params, test.allowedKeys)) - }) - } -} diff --git a/internal/auditid/auditid.go b/internal/auditid/auditid.go new file mode 100644 index 000000000..4c9605df7 --- /dev/null +++ b/internal/auditid/auditid.go @@ -0,0 +1,38 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package auditid + +import ( + "net/http" + + "github.com/google/uuid" + "k8s.io/apimachinery/pkg/types" + apiserveraudit "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/audit" +) + +// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID(). +func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) { + ctx := audit.WithAuditContext(r.Context()) + r = r.WithContext(ctx) + + auditID := newAuditIDFunc() + audit.WithAuditID(ctx, types.UID(auditID)) + + return r, auditID +} + +func WithAuditID(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add a randomly generated request ID to the context for this request. + r, auditID := NewRequestWithAuditID(r, func() string { + return uuid.New().String() + }) + + // Send the Audit-ID response header. + w.Header().Set(apiserveraudit.HeaderAuditID, auditID) + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/federationdomain/endpoints/auth/auth_handler_test.go b/internal/federationdomain/endpoints/auth/auth_handler_test.go index 12bffeda9..8b8350d61 100644 --- a/internal/federationdomain/endpoints/auth/auth_handler_test.go +++ b/internal/federationdomain/endpoints/auth/auth_handler_test.go @@ -30,12 +30,12 @@ import ( supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/federationdomain/csrftoken" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" - "go.pinniped.dev/internal/federationdomain/requestlogger" "go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/here" @@ -4118,7 +4118,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo if test.customPasswordHeader != nil { req.Header.Set("Pinniped-Password", *test.customPasswordHeader) } - req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) diff --git a/internal/federationdomain/endpoints/callback/callback_handler_test.go b/internal/federationdomain/endpoints/callback/callback_handler_test.go index 6000ccb43..d20a157c4 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_test.go @@ -22,10 +22,10 @@ import ( supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" - "go.pinniped.dev/internal/federationdomain/requestlogger" "go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/federationdomain/upstreamprovider" @@ -1979,7 +1979,7 @@ func TestCallbackEndpoint(t *testing.T) { if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) } - req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) t.Logf("response: %#v", rsp) diff --git a/internal/federationdomain/endpoints/login/login_handler_test.go b/internal/federationdomain/endpoints/login/login_handler_test.go index 3182e096f..e3a463a76 100644 --- a/internal/federationdomain/endpoints/login/login_handler_test.go +++ b/internal/federationdomain/endpoints/login/login_handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/gorilla/securecookie" "github.com/stretchr/testify/require" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/federationdomain/oidc" - "go.pinniped.dev/internal/federationdomain/requestlogger" "go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/plog" @@ -412,7 +412,7 @@ func TestLoginEndpoint(t *testing.T) { if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) } - req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) rsp := httptest.NewRecorder() testGetHandler := func( diff --git a/internal/federationdomain/endpoints/login/post_login_handler_test.go b/internal/federationdomain/endpoints/login/post_login_handler_test.go index d64c7f396..06cdc73f7 100644 --- a/internal/federationdomain/endpoints/login/post_login_handler_test.go +++ b/internal/federationdomain/endpoints/login/post_login_handler_test.go @@ -19,12 +19,12 @@ import ( supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/celtransformer" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" - "go.pinniped.dev/internal/federationdomain/requestlogger" "go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" @@ -1342,7 +1342,7 @@ func TestPostLoginEndpoint(t *testing.T) { if tt.reqURIQuery != nil { req.URL.RawQuery = tt.reqURIQuery.Encode() } - req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) rsp := httptest.NewRecorder() diff --git a/internal/federationdomain/endpointsmanager/manager.go b/internal/federationdomain/endpointsmanager/manager.go index d270203af..a1dc50558 100644 --- a/internal/federationdomain/endpointsmanager/manager.go +++ b/internal/federationdomain/endpointsmanager/manager.go @@ -11,6 +11,7 @@ import ( corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/config/supervisor" "go.pinniped.dev/internal/federationdomain/csrftoken" "go.pinniped.dev/internal/federationdomain/dynamiccodec" @@ -197,7 +198,7 @@ func (m *Manager) buildHandlerChain(nextHandler http.Handler, auditInternalPaths // Log all requests, including audit ID. handler = requestlogger.WithHTTPRequestAuditLogging(handler, m.auditLogger, auditInternalPathsCfg) // Add random audit ID to request context and response headers. - handler = requestlogger.WithAuditID(handler) + handler = auditid.WithAuditID(handler) m.handlerChain = handler } diff --git a/internal/federationdomain/requestlogger/request_logger.go b/internal/federationdomain/requestlogger/request_logger.go index b73a34cd4..1560379a9 100644 --- a/internal/federationdomain/requestlogger/request_logger.go +++ b/internal/federationdomain/requestlogger/request_logger.go @@ -11,10 +11,6 @@ import ( "slices" "time" - "github.com/google/uuid" - "k8s.io/apimachinery/pkg/types" - apisaudit "k8s.io/apiserver/pkg/apis/audit" - "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/endpoints/responsewriter" "k8s.io/utils/clock" @@ -24,31 +20,6 @@ import ( "go.pinniped.dev/internal/plog" ) -// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID(). -func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) { - ctx := audit.WithAuditContext(r.Context()) - r = r.WithContext(ctx) - - auditID := newAuditIDFunc() - audit.WithAuditID(ctx, types.UID(auditID)) - - return r, auditID -} - -func WithAuditID(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Add a randomly generated request ID to the context for this request. - r, auditID := NewRequestWithAuditID(r, func() string { - return uuid.New().String() - }) - - // Send the Audit-ID response header. - w.Header().Set(apisaudit.HeaderAuditID, auditID) - - handler.ServeHTTP(w, r) - }) -} - func WithHTTPRequestAuditLogging(handler http.Handler, auditLogger plog.AuditLogger, auditInternalPathsCfg supervisor.AuditInternalPaths) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { rl := newRequestLogger(req, w, auditLogger, time.Now(), auditInternalPathsCfg) diff --git a/internal/plog/plog.go b/internal/plog/plog.go index c97e450fa..061b5fe9b 100644 --- a/internal/plog/plog.go +++ b/internal/plog/plog.go @@ -33,6 +33,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "reflect" "slices" @@ -205,7 +206,7 @@ func (a *auditLogger) AuditRequestParams(r *http.Request, reqParamsSafeToLog set a.Audit(auditevent.HTTPRequestParameters, &AuditParams{ ReqCtx: r.Context(), - KeysAndValues: auditevent.SanitizeParams(r.Form, reqParamsSafeToLog), + KeysAndValues: sanitizeRequestParams(r.Form, reqParamsSafeToLog), }) return nil @@ -539,3 +540,41 @@ func (p *piiKeysAndValues) asJSONValue(v any) string { } } } + +// sanitizeRequestParams can be used to redact all params not included in the allowedKeys set. +// Useful when audit logging HTTPRequestParameters events. +func sanitizeRequestParams(inputParams url.Values, allowedKeys sets.Set[string]) []any { + params := make(map[string]string) + multiValueParams := make(url.Values) + + transform := func(key, value string) string { + if !allowedKeys.Has(key) { + return "redacted" + } + + unescape, err := url.QueryUnescape(value) + if err != nil { + // ignore these errors and just use the original query parameter + unescape = value + } + return unescape + } + + for key := range inputParams { + for i, p := range inputParams[key] { + transformed := transform(key, p) + if i == 0 { + params[key] = transformed + } + + if len(inputParams[key]) > 1 { + multiValueParams[key] = append(multiValueParams[key], transformed) + } + } + } + + if len(multiValueParams) > 0 { + return []any{"params", params, "multiValueParams", multiValueParams} + } + return []any{"params", params} +} diff --git a/internal/plog/plog_test.go b/internal/plog/plog_test.go index 8af8cb3af..e087e531c 100644 --- a/internal/plog/plog_test.go +++ b/internal/plog/plog_test.go @@ -6,15 +6,21 @@ package plog import ( "context" "fmt" + "net/http" + "net/http/httptest" + "net/url" "runtime" "strings" "testing" "time" "github.com/coreos/go-semver/semver" + "github.com/ory/fosite" "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/audit" + "go.pinniped.dev/internal/auditid" "go.pinniped.dev/internal/here" ) @@ -196,6 +202,125 @@ func TestAudit(t *testing.T) { } } +func TestAuditRequestParams(t *testing.T) { + tests := []struct { + name string + req func() *http.Request + paramsSafeToLog sets.Set[string] + want string + wantErr *fosite.RFC6749Error + }{ + { + name: "get request", + req: func() *http.Request { + params := url.Values{ + "foo": []string{"bar1", "bar2"}, + "baz": []string{"baz1", "baz2"}, + } + req := httptest.NewRequestWithContext(context.Background(), "GET", "/?"+params.Encode(), nil) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + return req + }, + paramsSafeToLog: sets.New("foo"), + want: here.Doc(` + {"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}} + `), + }, + { + name: "post request with urlencoded form in body", + req: func() *http.Request { + params := url.Values{ + "foo": []string{"bar1", "bar2"}, + "baz": []string{"baz1", "baz2"}, + } + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader(params.Encode())) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + paramsSafeToLog: sets.New("foo"), + want: here.Doc(` + {"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}} + `), + }, + { + name: "get request with bad form", + req: func() *http.Request { + req := httptest.NewRequestWithContext(context.Background(), "GET", "/?invalid;;;form", nil) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + return req + }, + paramsSafeToLog: sets.New("foo"), + wantErr: &fosite.RFC6749Error{ + CodeField: fosite.ErrInvalidRequest.CodeField, + ErrorField: fosite.ErrInvalidRequest.ErrorField, + DescriptionField: fosite.ErrInvalidRequest.DescriptionField, + HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.", + DebugField: "invalid semicolon separator in query", + }, + }, + { + name: "post request with bad urlencoded form in body", + req: func() *http.Request { + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("invalid;;;form")) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + paramsSafeToLog: sets.New("foo"), + wantErr: &fosite.RFC6749Error{ + CodeField: fosite.ErrInvalidRequest.CodeField, + ErrorField: fosite.ErrInvalidRequest.ErrorField, + DescriptionField: fosite.ErrInvalidRequest.DescriptionField, + HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.", + DebugField: "invalid semicolon separator in query", + }, + }, + { + name: "post request with bad multipart form in body", + req: func() *http.Request { + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("this is not a valid multipart form")) + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + req.Header.Set("Content-Type", "multipart/form-data") + return req + }, + paramsSafeToLog: sets.New("foo"), + wantErr: &fosite.RFC6749Error{ + CodeField: fosite.ErrInvalidRequest.CodeField, + ErrorField: fosite.ErrInvalidRequest.ErrorField, + DescriptionField: fosite.ErrInvalidRequest.DescriptionField, + HintField: "Unable to parse multipart HTTP body, make sure to send a properly formatted form request body.", + DebugField: "no multipart boundary param in Content-Type", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + l, actualAuditLogs := TestAuditLogger(t) + + rawErr := l.AuditRequestParams(test.req(), test.paramsSafeToLog) + + if test.wantErr == nil { + require.NoError(t, rawErr) + } else { + require.Error(t, rawErr) + err, ok := rawErr.(*fosite.RFC6749Error) + require.True(t, ok) + require.Equal(t, test.wantErr.CodeField, err.CodeField) + require.Equal(t, test.wantErr.ErrorField, err.ErrorField) + require.Equal(t, test.wantErr.DescriptionField, err.DescriptionField) + require.Equal(t, test.wantErr.HintField, err.HintField) + require.Equal(t, test.wantErr.DebugField, err.DebugField) + } + + require.Equal(t, strings.TrimSpace(test.want), strings.TrimSpace(actualAuditLogs.String())) + }) + } +} + func TestPlog(t *testing.T) { runtimeVersion := runtime.Version() if strings.HasPrefix(runtimeVersion, "go") { @@ -565,3 +690,162 @@ func testAllPlogMethods(l Logger) { l.All("all", "panda", 2) l.Always("always", "panda", 2) } + +func TestSanitizeRequestParams(t *testing.T) { + tests := []struct { + name string + params url.Values + allowedKeys sets.Set[string] + want []any + }{ + { + name: "nil values", + params: nil, + allowedKeys: nil, + want: []any{ + "params", + map[string]string{}, + }, + }, + { + name: "empty values", + params: url.Values{}, + allowedKeys: nil, + want: []any{ + "params", + map[string]string{}, + }, + }, + { + name: "all allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New("foo", "bar"), + want: []any{ + "params", + map[string]string{ + "bar": "d", + "foo": "a", + }, + "multiValueParams", + url.Values{ + "bar": []string{"d", "e", "f"}, + "foo": []string{"a", "b", "c"}, + }, + }, + }, + { + name: "all allowed values with single values", + params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, + allowedKeys: sets.New("foo", "bar"), + want: []any{ + "params", + map[string]string{ + "foo": "a", + "bar": "d", + }, + }, + }, + { + name: "some allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New("foo"), + want: []any{ + "params", + map[string]string{ + "bar": "redacted", + "foo": "a", + }, + "multiValueParams", + url.Values{ + "bar": []string{"redacted", "redacted", "redacted"}, + "foo": []string{"a", "b", "c"}, + }, + }, + }, + { + name: "some allowed values with single values", + params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}}, + allowedKeys: sets.New("foo"), + want: []any{ + "params", + map[string]string{ + "bar": "redacted", + "foo": "a", + }, + }, + }, + { + name: "no allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: sets.New[string](), + want: []any{ + "params", + map[string]string{ + "bar": "redacted", + "foo": "redacted", + }, + "multiValueParams", + url.Values{ + "bar": {"redacted", "redacted", "redacted"}, + "foo": {"redacted", "redacted", "redacted"}, + }, + }, + }, + { + name: "nil allowed values", + params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}}, + allowedKeys: nil, + want: []any{ + "params", + map[string]string{ + "bar": "redacted", + "foo": "redacted", + }, + "multiValueParams", + url.Values{ + "bar": {"redacted", "redacted", "redacted"}, + "foo": {"redacted", "redacted", "redacted"}, + }, + }, + }, + { + name: "url decodes allowed values", + params: url.Values{ + "foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"}, + "bar": []string{"d", "e", "f"}, + }, + allowedKeys: sets.New("foo"), + want: []any{ + "params", + map[string]string{ + "bar": "redacted", + "foo": "a:b", + }, + "multiValueParams", + url.Values{ + "bar": {"redacted", "redacted", "redacted"}, + "foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"}, + }, + }, + }, + { + name: "ignores url decode errors", + params: url.Values{ + "bad_encoding": []string{"%.."}, + }, + allowedKeys: sets.New("bad_encoding"), + want: []any{ + "params", + map[string]string{ + "bad_encoding": "%..", + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // This comparison should require the exact order + require.Equal(t, test.want, sanitizeRequestParams(test.params, test.allowedKeys)) + }) + } +}