diff --git a/pkg/openid/tokens_test.go b/pkg/openid/tokens_test.go index cd98d40..3a48d3f 100644 --- a/pkg/openid/tokens_test.go +++ b/pkg/openid/tokens_test.go @@ -12,9 +12,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/openid" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" ) var jwks *crypto.JwkSet @@ -112,6 +114,38 @@ func TestIDToken_GetSidClaim(t *testing.T) { } func TestIDToken_Validate(t *testing.T) { + defaultConfig := func() *config.Config { + cfg := mock.Config() + cfg.OpenID.ACRValues = "" + cfg.OpenID.ClientID = "some-client-id" + cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"} + + return cfg + } + + defaultOpenIdConfig := func(cfg *config.Config) *mock.TestConfiguration { + openidcfg := mock.NewTestConfiguration(cfg) + openidcfg.TestProvider.SetIssuer("https://some-issuer") + + return openidcfg + } + + defaultClaims := func(cfg openidconfig.Config) *claims { + return &claims{ + set: map[string]any{ + "aud": cfg.Client().ClientID(), + "iss": cfg.Provider().Issuer(), + }, + remove: []string{}, + } + } + + defaultCookie := func() *openid.LoginCookie { + return &openid.LoginCookie{ + Nonce: "some-nonce", + } + } + for _, tt := range []struct { name string claims *claims @@ -253,51 +287,22 @@ func TestIDToken_Validate(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - cfg := mock.Config() - cfg.OpenID.ACRValues = "" - cfg.OpenID.ClientID = "some-client-id" - cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"} + cfg := defaultConfig() + openidcfg := defaultOpenIdConfig(cfg) + cookie := defaultCookie() - if tt.requireAcr { - cfg.OpenID.ACRValues = "some-acr" - } - - openidcfg := mock.NewTestConfiguration(cfg) - openidcfg.TestProvider.SetIssuer("https://some-issuer") - cookie := &openid.LoginCookie{ - Nonce: "some-nonce", - } - c := &claims{ - set: map[string]any{ - "aud": openidcfg.Client().ClientID(), - "iss": openidcfg.Provider().Issuer(), - }, - remove: []string{}, - } - - if tt.claims != nil { - if tt.claims.set != nil { - for k, v := range tt.claims.set { - c.set[k] = v - } - } - if len(tt.claims.remove) > 0 { - c.remove = append(c.remove, tt.claims.remove...) - } - } + c := defaultClaims(openidcfg) + c.merge(tt.claims) if tt.requireSid { openidcfg.TestProvider.WithFrontChannelLogoutSupport() // sid claim is required - if _, ok := c.set["sid"]; !ok { - c.set["sid"] = "some-sid" - } + c.setIfUnset("sid", "some-sid") } if tt.requireAcr { + cfg.OpenID.ACRValues = "some-acr" cookie.Acr = "some-acr" - if _, ok := c.set["acr"]; !ok { - c.set["acr"] = "some-acr" - } + c.setIfUnset("acr", "some-acr") } idToken, err := makeIDToken(c) @@ -319,6 +324,28 @@ type claims struct { remove []string } +func (in *claims) setIfUnset(key, value string) { + if _, ok := in.set[key]; !ok { + in.set[key] = value + } +} + +func (in *claims) merge(other *claims) { + if other == nil { + return + } + + if other.set != nil { + for k, v := range other.set { + in.set[k] = v + } + } + + if len(other.remove) > 0 { + in.remove = append(in.remove, other.remove...) + } +} + func makeIDToken(claims *claims) (*openid.IDToken, error) { iat := time.Now().Truncate(time.Second).UTC() exp := iat.Add(5 * time.Second)