refactor and add unit test for AuditRequestParams()

This commit is contained in:
Ryan Richard
2024-11-13 12:50:17 -08:00
committed by Joshua Casey
parent c06141c871
commit 51d1cc7a96
11 changed files with 372 additions and 254 deletions

View File

@@ -3,12 +3,6 @@
package auditevent package auditevent
import (
"net/url"
"k8s.io/apimachinery/pkg/util/sets"
)
type Message string type Message string
const ( const (
@@ -32,41 +26,3 @@ const (
TokenCredentialRequestUnsupportedUserInfo Message = "TokenCredentialRequest Unsupported UserInfo" //nolint:gosec // this is not a credential TokenCredentialRequestUnsupportedUserInfo Message = "TokenCredentialRequest Unsupported UserInfo" //nolint:gosec // this is not a credential
IncorrectUsernameOrPassword Message = "Incorrect Username Or Password" //nolint:gosec // this is not a credential IncorrectUsernameOrPassword Message = "Incorrect Username Or Password" //nolint:gosec // this is not a credential
) )
// SanitizeParams can be used to redact all params not included in the allowedKeys set.
// Useful when audit logging HTTPRequestParameters events.
func SanitizeParams(inputParams url.Values, allowedKeys sets.Set[string]) []any {
params := make(map[string]string)
multiValueParams := make(url.Values)
transform := func(key, value string) string {
if !allowedKeys.Has(key) {
return "redacted"
}
unescape, err := url.QueryUnescape(value)
if err != nil {
// ignore these errors and just use the original query parameter
unescape = value
}
return unescape
}
for key := range inputParams {
for i, p := range inputParams[key] {
transformed := transform(key, p)
if i == 0 {
params[key] = transformed
}
if len(inputParams[key]) > 1 {
multiValueParams[key] = append(multiValueParams[key], transformed)
}
}
}
if len(multiValueParams) > 0 {
return []any{"params", params, "multiValueParams", multiValueParams}
}
return []any{"params", params}
}

View File

@@ -1,171 +0,0 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package auditevent
import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
)
func TestSanitizeParams(t *testing.T) {
tests := []struct {
name string
params url.Values
allowedKeys sets.Set[string]
want []any
}{
{
name: "nil values",
params: nil,
allowedKeys: nil,
want: []any{
"params",
map[string]string{},
},
},
{
name: "empty values",
params: url.Values{},
allowedKeys: nil,
want: []any{
"params",
map[string]string{},
},
},
{
name: "all allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New("foo", "bar"),
want: []any{
"params",
map[string]string{
"bar": "d",
"foo": "a",
},
"multiValueParams",
url.Values{
"bar": []string{"d", "e", "f"},
"foo": []string{"a", "b", "c"},
},
},
},
{
name: "all allowed values with single values",
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
allowedKeys: sets.New("foo", "bar"),
want: []any{
"params",
map[string]string{
"foo": "a",
"bar": "d",
},
},
},
{
name: "some allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a",
},
"multiValueParams",
url.Values{
"bar": []string{"redacted", "redacted", "redacted"},
"foo": []string{"a", "b", "c"},
},
},
},
{
name: "some allowed values with single values",
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a",
},
},
},
{
name: "no allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New[string](),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "redacted",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"redacted", "redacted", "redacted"},
},
},
},
{
name: "nil allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: nil,
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "redacted",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"redacted", "redacted", "redacted"},
},
},
},
{
name: "url decodes allowed values",
params: url.Values{
"foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"},
"bar": []string{"d", "e", "f"},
},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a:b",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"},
},
},
},
{
name: "ignores url decode errors",
params: url.Values{
"bad_encoding": []string{"%.."},
},
allowedKeys: sets.New("bad_encoding"),
want: []any{
"params",
map[string]string{
"bad_encoding": "%..",
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// This comparison should require the exact order
require.Equal(t, test.want, SanitizeParams(test.params, test.allowedKeys))
})
}
}

View File

@@ -0,0 +1,38 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package auditid
import (
"net/http"
"github.com/google/uuid"
"k8s.io/apimachinery/pkg/types"
apiserveraudit "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
)
// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID().
func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) {
ctx := audit.WithAuditContext(r.Context())
r = r.WithContext(ctx)
auditID := newAuditIDFunc()
audit.WithAuditID(ctx, types.UID(auditID))
return r, auditID
}
func WithAuditID(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add a randomly generated request ID to the context for this request.
r, auditID := NewRequestWithAuditID(r, func() string {
return uuid.New().String()
})
// Send the Audit-ID response header.
w.Header().Set(apiserveraudit.HeaderAuditID, auditID)
handler.ServeHTTP(w, r)
})
}

View File

@@ -30,12 +30,12 @@ import (
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
"go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1" "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/federationdomain/csrftoken" "go.pinniped.dev/internal/federationdomain/csrftoken"
"go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/endpoints/jwks"
"go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidc"
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
"go.pinniped.dev/internal/federationdomain/requestlogger"
"go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/federationdomain/stateparam"
"go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/federationdomain/storage"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
@@ -4118,7 +4118,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo
if test.customPasswordHeader != nil { if test.customPasswordHeader != nil {
req.Header.Set("Pinniped-Password", *test.customPasswordHeader) req.Header.Set("Pinniped-Password", *test.customPasswordHeader)
} }
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)

View File

@@ -22,10 +22,10 @@ import (
supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1"
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/endpoints/jwks"
"go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidc"
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
"go.pinniped.dev/internal/federationdomain/requestlogger"
"go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/federationdomain/stateparam"
"go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/federationdomain/storage"
"go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/federationdomain/upstreamprovider"
@@ -1979,7 +1979,7 @@ func TestCallbackEndpoint(t *testing.T) {
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)
} }
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)

View File

@@ -13,8 +13,8 @@ import (
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidc"
"go.pinniped.dev/internal/federationdomain/requestlogger"
"go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/federationdomain/stateparam"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
@@ -412,7 +412,7 @@ func TestLoginEndpoint(t *testing.T) {
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)
} }
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" }) req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
testGetHandler := func( testGetHandler := func(

View File

@@ -19,12 +19,12 @@ import (
supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1"
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/celtransformer" "go.pinniped.dev/internal/celtransformer"
"go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/endpoints/jwks"
"go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/oidc"
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
"go.pinniped.dev/internal/federationdomain/requestlogger"
"go.pinniped.dev/internal/federationdomain/storage" "go.pinniped.dev/internal/federationdomain/storage"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
@@ -1342,7 +1342,7 @@ func TestPostLoginEndpoint(t *testing.T) {
if tt.reqURIQuery != nil { if tt.reqURIQuery != nil {
req.URL.RawQuery = tt.reqURIQuery.Encode() req.URL.RawQuery = tt.reqURIQuery.Encode()
} }
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()

View File

@@ -11,6 +11,7 @@ import (
corev1client "k8s.io/client-go/kubernetes/typed/core/v1" corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1" "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/config/supervisor" "go.pinniped.dev/internal/config/supervisor"
"go.pinniped.dev/internal/federationdomain/csrftoken" "go.pinniped.dev/internal/federationdomain/csrftoken"
"go.pinniped.dev/internal/federationdomain/dynamiccodec" "go.pinniped.dev/internal/federationdomain/dynamiccodec"
@@ -197,7 +198,7 @@ func (m *Manager) buildHandlerChain(nextHandler http.Handler, auditInternalPaths
// Log all requests, including audit ID. // Log all requests, including audit ID.
handler = requestlogger.WithHTTPRequestAuditLogging(handler, m.auditLogger, auditInternalPathsCfg) handler = requestlogger.WithHTTPRequestAuditLogging(handler, m.auditLogger, auditInternalPathsCfg)
// Add random audit ID to request context and response headers. // Add random audit ID to request context and response headers.
handler = requestlogger.WithAuditID(handler) handler = auditid.WithAuditID(handler)
m.handlerChain = handler m.handlerChain = handler
} }

View File

@@ -11,10 +11,6 @@ import (
"slices" "slices"
"time" "time"
"github.com/google/uuid"
"k8s.io/apimachinery/pkg/types"
apisaudit "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/endpoints/responsewriter" "k8s.io/apiserver/pkg/endpoints/responsewriter"
"k8s.io/utils/clock" "k8s.io/utils/clock"
@@ -24,31 +20,6 @@ import (
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID().
func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) {
ctx := audit.WithAuditContext(r.Context())
r = r.WithContext(ctx)
auditID := newAuditIDFunc()
audit.WithAuditID(ctx, types.UID(auditID))
return r, auditID
}
func WithAuditID(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add a randomly generated request ID to the context for this request.
r, auditID := NewRequestWithAuditID(r, func() string {
return uuid.New().String()
})
// Send the Audit-ID response header.
w.Header().Set(apisaudit.HeaderAuditID, auditID)
handler.ServeHTTP(w, r)
})
}
func WithHTTPRequestAuditLogging(handler http.Handler, auditLogger plog.AuditLogger, auditInternalPathsCfg supervisor.AuditInternalPaths) http.Handler { func WithHTTPRequestAuditLogging(handler http.Handler, auditLogger plog.AuditLogger, auditInternalPathsCfg supervisor.AuditInternalPaths) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
rl := newRequestLogger(req, w, auditLogger, time.Now(), auditInternalPathsCfg) rl := newRequestLogger(req, w, auditLogger, time.Now(), auditInternalPathsCfg)

View File

@@ -33,6 +33,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"os" "os"
"reflect" "reflect"
"slices" "slices"
@@ -205,7 +206,7 @@ func (a *auditLogger) AuditRequestParams(r *http.Request, reqParamsSafeToLog set
a.Audit(auditevent.HTTPRequestParameters, &AuditParams{ a.Audit(auditevent.HTTPRequestParameters, &AuditParams{
ReqCtx: r.Context(), ReqCtx: r.Context(),
KeysAndValues: auditevent.SanitizeParams(r.Form, reqParamsSafeToLog), KeysAndValues: sanitizeRequestParams(r.Form, reqParamsSafeToLog),
}) })
return nil return nil
@@ -539,3 +540,41 @@ func (p *piiKeysAndValues) asJSONValue(v any) string {
} }
} }
} }
// sanitizeRequestParams can be used to redact all params not included in the allowedKeys set.
// Useful when audit logging HTTPRequestParameters events.
func sanitizeRequestParams(inputParams url.Values, allowedKeys sets.Set[string]) []any {
params := make(map[string]string)
multiValueParams := make(url.Values)
transform := func(key, value string) string {
if !allowedKeys.Has(key) {
return "redacted"
}
unescape, err := url.QueryUnescape(value)
if err != nil {
// ignore these errors and just use the original query parameter
unescape = value
}
return unescape
}
for key := range inputParams {
for i, p := range inputParams[key] {
transformed := transform(key, p)
if i == 0 {
params[key] = transformed
}
if len(inputParams[key]) > 1 {
multiValueParams[key] = append(multiValueParams[key], transformed)
}
}
}
if len(multiValueParams) > 0 {
return []any{"params", params, "multiValueParams", multiValueParams}
}
return []any{"params", params}
}

View File

@@ -6,15 +6,21 @@ package plog
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/http/httptest"
"net/url"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/coreos/go-semver/semver" "github.com/coreos/go-semver/semver"
"github.com/ory/fosite"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/audit"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
) )
@@ -196,6 +202,125 @@ func TestAudit(t *testing.T) {
} }
} }
func TestAuditRequestParams(t *testing.T) {
tests := []struct {
name string
req func() *http.Request
paramsSafeToLog sets.Set[string]
want string
wantErr *fosite.RFC6749Error
}{
{
name: "get request",
req: func() *http.Request {
params := url.Values{
"foo": []string{"bar1", "bar2"},
"baz": []string{"baz1", "baz2"},
}
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?"+params.Encode(), nil)
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
return req
},
paramsSafeToLog: sets.New("foo"),
want: here.Doc(`
{"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:<line>$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}}
`),
},
{
name: "post request with urlencoded form in body",
req: func() *http.Request {
params := url.Values{
"foo": []string{"bar1", "bar2"},
"baz": []string{"baz1", "baz2"},
}
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader(params.Encode()))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
paramsSafeToLog: sets.New("foo"),
want: here.Doc(`
{"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:<line>$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}}
`),
},
{
name: "get request with bad form",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?invalid;;;form", nil)
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
return req
},
paramsSafeToLog: sets.New("foo"),
wantErr: &fosite.RFC6749Error{
CodeField: fosite.ErrInvalidRequest.CodeField,
ErrorField: fosite.ErrInvalidRequest.ErrorField,
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.",
DebugField: "invalid semicolon separator in query",
},
},
{
name: "post request with bad urlencoded form in body",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("invalid;;;form"))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
paramsSafeToLog: sets.New("foo"),
wantErr: &fosite.RFC6749Error{
CodeField: fosite.ErrInvalidRequest.CodeField,
ErrorField: fosite.ErrInvalidRequest.ErrorField,
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.",
DebugField: "invalid semicolon separator in query",
},
},
{
name: "post request with bad multipart form in body",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("this is not a valid multipart form"))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
paramsSafeToLog: sets.New("foo"),
wantErr: &fosite.RFC6749Error{
CodeField: fosite.ErrInvalidRequest.CodeField,
ErrorField: fosite.ErrInvalidRequest.ErrorField,
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
HintField: "Unable to parse multipart HTTP body, make sure to send a properly formatted form request body.",
DebugField: "no multipart boundary param in Content-Type",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
l, actualAuditLogs := TestAuditLogger(t)
rawErr := l.AuditRequestParams(test.req(), test.paramsSafeToLog)
if test.wantErr == nil {
require.NoError(t, rawErr)
} else {
require.Error(t, rawErr)
err, ok := rawErr.(*fosite.RFC6749Error)
require.True(t, ok)
require.Equal(t, test.wantErr.CodeField, err.CodeField)
require.Equal(t, test.wantErr.ErrorField, err.ErrorField)
require.Equal(t, test.wantErr.DescriptionField, err.DescriptionField)
require.Equal(t, test.wantErr.HintField, err.HintField)
require.Equal(t, test.wantErr.DebugField, err.DebugField)
}
require.Equal(t, strings.TrimSpace(test.want), strings.TrimSpace(actualAuditLogs.String()))
})
}
}
func TestPlog(t *testing.T) { func TestPlog(t *testing.T) {
runtimeVersion := runtime.Version() runtimeVersion := runtime.Version()
if strings.HasPrefix(runtimeVersion, "go") { if strings.HasPrefix(runtimeVersion, "go") {
@@ -565,3 +690,162 @@ func testAllPlogMethods(l Logger) {
l.All("all", "panda", 2) l.All("all", "panda", 2)
l.Always("always", "panda", 2) l.Always("always", "panda", 2)
} }
func TestSanitizeRequestParams(t *testing.T) {
tests := []struct {
name string
params url.Values
allowedKeys sets.Set[string]
want []any
}{
{
name: "nil values",
params: nil,
allowedKeys: nil,
want: []any{
"params",
map[string]string{},
},
},
{
name: "empty values",
params: url.Values{},
allowedKeys: nil,
want: []any{
"params",
map[string]string{},
},
},
{
name: "all allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New("foo", "bar"),
want: []any{
"params",
map[string]string{
"bar": "d",
"foo": "a",
},
"multiValueParams",
url.Values{
"bar": []string{"d", "e", "f"},
"foo": []string{"a", "b", "c"},
},
},
},
{
name: "all allowed values with single values",
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
allowedKeys: sets.New("foo", "bar"),
want: []any{
"params",
map[string]string{
"foo": "a",
"bar": "d",
},
},
},
{
name: "some allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a",
},
"multiValueParams",
url.Values{
"bar": []string{"redacted", "redacted", "redacted"},
"foo": []string{"a", "b", "c"},
},
},
},
{
name: "some allowed values with single values",
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a",
},
},
},
{
name: "no allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: sets.New[string](),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "redacted",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"redacted", "redacted", "redacted"},
},
},
},
{
name: "nil allowed values",
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
allowedKeys: nil,
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "redacted",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"redacted", "redacted", "redacted"},
},
},
},
{
name: "url decodes allowed values",
params: url.Values{
"foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"},
"bar": []string{"d", "e", "f"},
},
allowedKeys: sets.New("foo"),
want: []any{
"params",
map[string]string{
"bar": "redacted",
"foo": "a:b",
},
"multiValueParams",
url.Values{
"bar": {"redacted", "redacted", "redacted"},
"foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"},
},
},
},
{
name: "ignores url decode errors",
params: url.Values{
"bad_encoding": []string{"%.."},
},
allowedKeys: sets.New("bad_encoding"),
want: []any{
"params",
map[string]string{
"bad_encoding": "%..",
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// This comparison should require the exact order
require.Equal(t, test.want, sanitizeRequestParams(test.params, test.allowedKeys))
})
}
}