mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-02-14 10:00:05 +00:00
refactor and add unit test for AuditRequestParams()
This commit is contained in:
committed by
Joshua Casey
parent
c06141c871
commit
51d1cc7a96
@@ -3,12 +3,6 @@
|
||||
|
||||
package auditevent
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
)
|
||||
|
||||
type Message string
|
||||
|
||||
const (
|
||||
@@ -32,41 +26,3 @@ const (
|
||||
TokenCredentialRequestUnsupportedUserInfo Message = "TokenCredentialRequest Unsupported UserInfo" //nolint:gosec // this is not a credential
|
||||
IncorrectUsernameOrPassword Message = "Incorrect Username Or Password" //nolint:gosec // this is not a credential
|
||||
)
|
||||
|
||||
// SanitizeParams can be used to redact all params not included in the allowedKeys set.
|
||||
// Useful when audit logging HTTPRequestParameters events.
|
||||
func SanitizeParams(inputParams url.Values, allowedKeys sets.Set[string]) []any {
|
||||
params := make(map[string]string)
|
||||
multiValueParams := make(url.Values)
|
||||
|
||||
transform := func(key, value string) string {
|
||||
if !allowedKeys.Has(key) {
|
||||
return "redacted"
|
||||
}
|
||||
|
||||
unescape, err := url.QueryUnescape(value)
|
||||
if err != nil {
|
||||
// ignore these errors and just use the original query parameter
|
||||
unescape = value
|
||||
}
|
||||
return unescape
|
||||
}
|
||||
|
||||
for key := range inputParams {
|
||||
for i, p := range inputParams[key] {
|
||||
transformed := transform(key, p)
|
||||
if i == 0 {
|
||||
params[key] = transformed
|
||||
}
|
||||
|
||||
if len(inputParams[key]) > 1 {
|
||||
multiValueParams[key] = append(multiValueParams[key], transformed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(multiValueParams) > 0 {
|
||||
return []any{"params", params, "multiValueParams", multiValueParams}
|
||||
}
|
||||
return []any{"params", params}
|
||||
}
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auditevent
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
)
|
||||
|
||||
func TestSanitizeParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params url.Values
|
||||
allowedKeys sets.Set[string]
|
||||
want []any
|
||||
}{
|
||||
{
|
||||
name: "nil values",
|
||||
params: nil,
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty values",
|
||||
params: url.Values{},
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New("foo", "bar"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "d",
|
||||
"foo": "a",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": []string{"d", "e", "f"},
|
||||
"foo": []string{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all allowed values with single values",
|
||||
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
|
||||
allowedKeys: sets.New("foo", "bar"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"foo": "a",
|
||||
"bar": "d",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": []string{"redacted", "redacted", "redacted"},
|
||||
"foo": []string{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some allowed values with single values",
|
||||
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New[string](),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "redacted",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"redacted", "redacted", "redacted"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "redacted",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"redacted", "redacted", "redacted"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "url decodes allowed values",
|
||||
params: url.Values{
|
||||
"foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"},
|
||||
"bar": []string{"d", "e", "f"},
|
||||
},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a:b",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ignores url decode errors",
|
||||
params: url.Values{
|
||||
"bad_encoding": []string{"%.."},
|
||||
},
|
||||
allowedKeys: sets.New("bad_encoding"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bad_encoding": "%..",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// This comparison should require the exact order
|
||||
require.Equal(t, test.want, SanitizeParams(test.params, test.allowedKeys))
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/auditid/auditid.go
Normal file
38
internal/auditid/auditid.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auditid
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
apiserveraudit "k8s.io/apiserver/pkg/apis/audit"
|
||||
"k8s.io/apiserver/pkg/audit"
|
||||
)
|
||||
|
||||
// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID().
|
||||
func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) {
|
||||
ctx := audit.WithAuditContext(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
auditID := newAuditIDFunc()
|
||||
audit.WithAuditID(ctx, types.UID(auditID))
|
||||
|
||||
return r, auditID
|
||||
}
|
||||
|
||||
func WithAuditID(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add a randomly generated request ID to the context for this request.
|
||||
r, auditID := NewRequestWithAuditID(r, func() string {
|
||||
return uuid.New().String()
|
||||
})
|
||||
|
||||
// Send the Audit-ID response header.
|
||||
w.Header().Set(apiserveraudit.HeaderAuditID, auditID)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -30,12 +30,12 @@ import (
|
||||
|
||||
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
|
||||
"go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1"
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/authenticators"
|
||||
"go.pinniped.dev/internal/federationdomain/csrftoken"
|
||||
"go.pinniped.dev/internal/federationdomain/endpoints/jwks"
|
||||
"go.pinniped.dev/internal/federationdomain/oidc"
|
||||
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
|
||||
"go.pinniped.dev/internal/federationdomain/requestlogger"
|
||||
"go.pinniped.dev/internal/federationdomain/stateparam"
|
||||
"go.pinniped.dev/internal/federationdomain/storage"
|
||||
"go.pinniped.dev/internal/here"
|
||||
@@ -4118,7 +4118,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo
|
||||
if test.customPasswordHeader != nil {
|
||||
req.Header.Set("Pinniped-Password", *test.customPasswordHeader)
|
||||
}
|
||||
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
rsp := httptest.NewRecorder()
|
||||
|
||||
subject.ServeHTTP(rsp, req)
|
||||
|
||||
@@ -22,10 +22,10 @@ import (
|
||||
|
||||
supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1"
|
||||
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/federationdomain/endpoints/jwks"
|
||||
"go.pinniped.dev/internal/federationdomain/oidc"
|
||||
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
|
||||
"go.pinniped.dev/internal/federationdomain/requestlogger"
|
||||
"go.pinniped.dev/internal/federationdomain/stateparam"
|
||||
"go.pinniped.dev/internal/federationdomain/storage"
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
@@ -1979,7 +1979,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
if test.csrfCookie != "" {
|
||||
req.Header.Set("Cookie", test.csrfCookie)
|
||||
}
|
||||
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
rsp := httptest.NewRecorder()
|
||||
subject.ServeHTTP(rsp, req)
|
||||
t.Logf("response: %#v", rsp)
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/federationdomain/oidc"
|
||||
"go.pinniped.dev/internal/federationdomain/requestlogger"
|
||||
"go.pinniped.dev/internal/federationdomain/stateparam"
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
@@ -412,7 +412,7 @@ func TestLoginEndpoint(t *testing.T) {
|
||||
if test.csrfCookie != "" {
|
||||
req.Header.Set("Cookie", test.csrfCookie)
|
||||
}
|
||||
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "fake-audit-id" })
|
||||
rsp := httptest.NewRecorder()
|
||||
|
||||
testGetHandler := func(
|
||||
|
||||
@@ -19,12 +19,12 @@ import (
|
||||
|
||||
supervisorconfigv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1"
|
||||
supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/authenticators"
|
||||
"go.pinniped.dev/internal/celtransformer"
|
||||
"go.pinniped.dev/internal/federationdomain/endpoints/jwks"
|
||||
"go.pinniped.dev/internal/federationdomain/oidc"
|
||||
"go.pinniped.dev/internal/federationdomain/oidcclientvalidator"
|
||||
"go.pinniped.dev/internal/federationdomain/requestlogger"
|
||||
"go.pinniped.dev/internal/federationdomain/storage"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/psession"
|
||||
@@ -1342,7 +1342,7 @@ func TestPostLoginEndpoint(t *testing.T) {
|
||||
if tt.reqURIQuery != nil {
|
||||
req.URL.RawQuery = tt.reqURIQuery.Encode()
|
||||
}
|
||||
req, _ = requestlogger.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
|
||||
rsp := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
|
||||
"go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/typed/config/v1alpha1"
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/config/supervisor"
|
||||
"go.pinniped.dev/internal/federationdomain/csrftoken"
|
||||
"go.pinniped.dev/internal/federationdomain/dynamiccodec"
|
||||
@@ -197,7 +198,7 @@ func (m *Manager) buildHandlerChain(nextHandler http.Handler, auditInternalPaths
|
||||
// Log all requests, including audit ID.
|
||||
handler = requestlogger.WithHTTPRequestAuditLogging(handler, m.auditLogger, auditInternalPathsCfg)
|
||||
// Add random audit ID to request context and response headers.
|
||||
handler = requestlogger.WithAuditID(handler)
|
||||
handler = auditid.WithAuditID(handler)
|
||||
m.handlerChain = handler
|
||||
}
|
||||
|
||||
|
||||
@@ -11,10 +11,6 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
apisaudit "k8s.io/apiserver/pkg/apis/audit"
|
||||
"k8s.io/apiserver/pkg/audit"
|
||||
"k8s.io/apiserver/pkg/endpoints/responsewriter"
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
@@ -24,31 +20,6 @@ import (
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
// NewRequestWithAuditID is public for use in unit tests. Production code should use WithAuditID().
|
||||
func NewRequestWithAuditID(r *http.Request, newAuditIDFunc func() string) (*http.Request, string) {
|
||||
ctx := audit.WithAuditContext(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
auditID := newAuditIDFunc()
|
||||
audit.WithAuditID(ctx, types.UID(auditID))
|
||||
|
||||
return r, auditID
|
||||
}
|
||||
|
||||
func WithAuditID(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add a randomly generated request ID to the context for this request.
|
||||
r, auditID := NewRequestWithAuditID(r, func() string {
|
||||
return uuid.New().String()
|
||||
})
|
||||
|
||||
// Send the Audit-ID response header.
|
||||
w.Header().Set(apisaudit.HeaderAuditID, auditID)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func WithHTTPRequestAuditLogging(handler http.Handler, auditLogger plog.AuditLogger, auditInternalPathsCfg supervisor.AuditInternalPaths) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
rl := newRequestLogger(req, w, auditLogger, time.Now(), auditInternalPathsCfg)
|
||||
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"slices"
|
||||
@@ -205,7 +206,7 @@ func (a *auditLogger) AuditRequestParams(r *http.Request, reqParamsSafeToLog set
|
||||
|
||||
a.Audit(auditevent.HTTPRequestParameters, &AuditParams{
|
||||
ReqCtx: r.Context(),
|
||||
KeysAndValues: auditevent.SanitizeParams(r.Form, reqParamsSafeToLog),
|
||||
KeysAndValues: sanitizeRequestParams(r.Form, reqParamsSafeToLog),
|
||||
})
|
||||
|
||||
return nil
|
||||
@@ -539,3 +540,41 @@ func (p *piiKeysAndValues) asJSONValue(v any) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeRequestParams can be used to redact all params not included in the allowedKeys set.
|
||||
// Useful when audit logging HTTPRequestParameters events.
|
||||
func sanitizeRequestParams(inputParams url.Values, allowedKeys sets.Set[string]) []any {
|
||||
params := make(map[string]string)
|
||||
multiValueParams := make(url.Values)
|
||||
|
||||
transform := func(key, value string) string {
|
||||
if !allowedKeys.Has(key) {
|
||||
return "redacted"
|
||||
}
|
||||
|
||||
unescape, err := url.QueryUnescape(value)
|
||||
if err != nil {
|
||||
// ignore these errors and just use the original query parameter
|
||||
unescape = value
|
||||
}
|
||||
return unescape
|
||||
}
|
||||
|
||||
for key := range inputParams {
|
||||
for i, p := range inputParams[key] {
|
||||
transformed := transform(key, p)
|
||||
if i == 0 {
|
||||
params[key] = transformed
|
||||
}
|
||||
|
||||
if len(inputParams[key]) > 1 {
|
||||
multiValueParams[key] = append(multiValueParams[key], transformed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(multiValueParams) > 0 {
|
||||
return []any{"params", params, "multiValueParams", multiValueParams}
|
||||
}
|
||||
return []any{"params", params}
|
||||
}
|
||||
|
||||
@@ -6,15 +6,21 @@ package plog
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-semver/semver"
|
||||
"github.com/ory/fosite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apiserver/pkg/audit"
|
||||
|
||||
"go.pinniped.dev/internal/auditid"
|
||||
"go.pinniped.dev/internal/here"
|
||||
)
|
||||
|
||||
@@ -196,6 +202,125 @@ func TestAudit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditRequestParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req func() *http.Request
|
||||
paramsSafeToLog sets.Set[string]
|
||||
want string
|
||||
wantErr *fosite.RFC6749Error
|
||||
}{
|
||||
{
|
||||
name: "get request",
|
||||
req: func() *http.Request {
|
||||
params := url.Values{
|
||||
"foo": []string{"bar1", "bar2"},
|
||||
"baz": []string{"baz1", "baz2"},
|
||||
}
|
||||
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?"+params.Encode(), nil)
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
return req
|
||||
},
|
||||
paramsSafeToLog: sets.New("foo"),
|
||||
want: here.Doc(`
|
||||
{"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:<line>$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}}
|
||||
`),
|
||||
},
|
||||
{
|
||||
name: "post request with urlencoded form in body",
|
||||
req: func() *http.Request {
|
||||
params := url.Values{
|
||||
"foo": []string{"bar1", "bar2"},
|
||||
"baz": []string{"baz1", "baz2"},
|
||||
}
|
||||
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader(params.Encode()))
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
paramsSafeToLog: sets.New("foo"),
|
||||
want: here.Doc(`
|
||||
{"level":"info","timestamp":"2099-08-08T13:57:36.123456Z","caller":"plog/plog.go:<line>$plog.(*auditLogger).AuditRequestParams","message":"HTTP Request Parameters","auditEvent":true,"auditID":"some-audit-id","params":{"baz":"redacted","foo":"bar1"},"multiValueParams":{"baz":["redacted","redacted"],"foo":["bar1","bar2"]}}
|
||||
`),
|
||||
},
|
||||
{
|
||||
name: "get request with bad form",
|
||||
req: func() *http.Request {
|
||||
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?invalid;;;form", nil)
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
return req
|
||||
},
|
||||
paramsSafeToLog: sets.New("foo"),
|
||||
wantErr: &fosite.RFC6749Error{
|
||||
CodeField: fosite.ErrInvalidRequest.CodeField,
|
||||
ErrorField: fosite.ErrInvalidRequest.ErrorField,
|
||||
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
|
||||
HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.",
|
||||
DebugField: "invalid semicolon separator in query",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "post request with bad urlencoded form in body",
|
||||
req: func() *http.Request {
|
||||
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("invalid;;;form"))
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
paramsSafeToLog: sets.New("foo"),
|
||||
wantErr: &fosite.RFC6749Error{
|
||||
CodeField: fosite.ErrInvalidRequest.CodeField,
|
||||
ErrorField: fosite.ErrInvalidRequest.ErrorField,
|
||||
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
|
||||
HintField: "Unable to parse form params, make sure to send a properly formatted query params or form request body.",
|
||||
DebugField: "invalid semicolon separator in query",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "post request with bad multipart form in body",
|
||||
req: func() *http.Request {
|
||||
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("this is not a valid multipart form"))
|
||||
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
paramsSafeToLog: sets.New("foo"),
|
||||
wantErr: &fosite.RFC6749Error{
|
||||
CodeField: fosite.ErrInvalidRequest.CodeField,
|
||||
ErrorField: fosite.ErrInvalidRequest.ErrorField,
|
||||
DescriptionField: fosite.ErrInvalidRequest.DescriptionField,
|
||||
HintField: "Unable to parse multipart HTTP body, make sure to send a properly formatted form request body.",
|
||||
DebugField: "no multipart boundary param in Content-Type",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l, actualAuditLogs := TestAuditLogger(t)
|
||||
|
||||
rawErr := l.AuditRequestParams(test.req(), test.paramsSafeToLog)
|
||||
|
||||
if test.wantErr == nil {
|
||||
require.NoError(t, rawErr)
|
||||
} else {
|
||||
require.Error(t, rawErr)
|
||||
err, ok := rawErr.(*fosite.RFC6749Error)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, test.wantErr.CodeField, err.CodeField)
|
||||
require.Equal(t, test.wantErr.ErrorField, err.ErrorField)
|
||||
require.Equal(t, test.wantErr.DescriptionField, err.DescriptionField)
|
||||
require.Equal(t, test.wantErr.HintField, err.HintField)
|
||||
require.Equal(t, test.wantErr.DebugField, err.DebugField)
|
||||
}
|
||||
|
||||
require.Equal(t, strings.TrimSpace(test.want), strings.TrimSpace(actualAuditLogs.String()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlog(t *testing.T) {
|
||||
runtimeVersion := runtime.Version()
|
||||
if strings.HasPrefix(runtimeVersion, "go") {
|
||||
@@ -565,3 +690,162 @@ func testAllPlogMethods(l Logger) {
|
||||
l.All("all", "panda", 2)
|
||||
l.Always("always", "panda", 2)
|
||||
}
|
||||
|
||||
func TestSanitizeRequestParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params url.Values
|
||||
allowedKeys sets.Set[string]
|
||||
want []any
|
||||
}{
|
||||
{
|
||||
name: "nil values",
|
||||
params: nil,
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty values",
|
||||
params: url.Values{},
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New("foo", "bar"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "d",
|
||||
"foo": "a",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": []string{"d", "e", "f"},
|
||||
"foo": []string{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all allowed values with single values",
|
||||
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
|
||||
allowedKeys: sets.New("foo", "bar"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"foo": "a",
|
||||
"bar": "d",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": []string{"redacted", "redacted", "redacted"},
|
||||
"foo": []string{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some allowed values with single values",
|
||||
params: url.Values{"foo": []string{"a"}, "bar": []string{"d"}},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: sets.New[string](),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "redacted",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"redacted", "redacted", "redacted"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil allowed values",
|
||||
params: url.Values{"foo": []string{"a", "b", "c"}, "bar": []string{"d", "e", "f"}},
|
||||
allowedKeys: nil,
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "redacted",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"redacted", "redacted", "redacted"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "url decodes allowed values",
|
||||
params: url.Values{
|
||||
"foo": []string{"a%3Ab", "c", "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"},
|
||||
"bar": []string{"d", "e", "f"},
|
||||
},
|
||||
allowedKeys: sets.New("foo"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bar": "redacted",
|
||||
"foo": "a:b",
|
||||
},
|
||||
"multiValueParams",
|
||||
url.Values{
|
||||
"bar": {"redacted", "redacted", "redacted"},
|
||||
"foo": {"a:b", "c", "urn:ietf:params:oauth:grant-type:token-exchange"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ignores url decode errors",
|
||||
params: url.Values{
|
||||
"bad_encoding": []string{"%.."},
|
||||
},
|
||||
allowedKeys: sets.New("bad_encoding"),
|
||||
want: []any{
|
||||
"params",
|
||||
map[string]string{
|
||||
"bad_encoding": "%..",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// This comparison should require the exact order
|
||||
require.Equal(t, test.want, sanitizeRequestParams(test.params, test.allowedKeys))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user