mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 08:27:10 +00:00
refactor(openid): clean up tests
This commit is contained in:
@@ -12,9 +12,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
)
|
||||
|
||||
var jwks *crypto.JwkSet
|
||||
@@ -112,6 +114,38 @@ func TestIDToken_GetSidClaim(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIDToken_Validate(t *testing.T) {
|
||||
defaultConfig := func() *config.Config {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.ACRValues = ""
|
||||
cfg.OpenID.ClientID = "some-client-id"
|
||||
cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
defaultOpenIdConfig := func(cfg *config.Config) *mock.TestConfiguration {
|
||||
openidcfg := mock.NewTestConfiguration(cfg)
|
||||
openidcfg.TestProvider.SetIssuer("https://some-issuer")
|
||||
|
||||
return openidcfg
|
||||
}
|
||||
|
||||
defaultClaims := func(cfg openidconfig.Config) *claims {
|
||||
return &claims{
|
||||
set: map[string]any{
|
||||
"aud": cfg.Client().ClientID(),
|
||||
"iss": cfg.Provider().Issuer(),
|
||||
},
|
||||
remove: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
defaultCookie := func() *openid.LoginCookie {
|
||||
return &openid.LoginCookie{
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
claims *claims
|
||||
@@ -253,51 +287,22 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.ACRValues = ""
|
||||
cfg.OpenID.ClientID = "some-client-id"
|
||||
cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"}
|
||||
cfg := defaultConfig()
|
||||
openidcfg := defaultOpenIdConfig(cfg)
|
||||
cookie := defaultCookie()
|
||||
|
||||
if tt.requireAcr {
|
||||
cfg.OpenID.ACRValues = "some-acr"
|
||||
}
|
||||
|
||||
openidcfg := mock.NewTestConfiguration(cfg)
|
||||
openidcfg.TestProvider.SetIssuer("https://some-issuer")
|
||||
cookie := &openid.LoginCookie{
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
c := &claims{
|
||||
set: map[string]any{
|
||||
"aud": openidcfg.Client().ClientID(),
|
||||
"iss": openidcfg.Provider().Issuer(),
|
||||
},
|
||||
remove: []string{},
|
||||
}
|
||||
|
||||
if tt.claims != nil {
|
||||
if tt.claims.set != nil {
|
||||
for k, v := range tt.claims.set {
|
||||
c.set[k] = v
|
||||
}
|
||||
}
|
||||
if len(tt.claims.remove) > 0 {
|
||||
c.remove = append(c.remove, tt.claims.remove...)
|
||||
}
|
||||
}
|
||||
c := defaultClaims(openidcfg)
|
||||
c.merge(tt.claims)
|
||||
|
||||
if tt.requireSid {
|
||||
openidcfg.TestProvider.WithFrontChannelLogoutSupport() // sid claim is required
|
||||
if _, ok := c.set["sid"]; !ok {
|
||||
c.set["sid"] = "some-sid"
|
||||
}
|
||||
c.setIfUnset("sid", "some-sid")
|
||||
}
|
||||
|
||||
if tt.requireAcr {
|
||||
cfg.OpenID.ACRValues = "some-acr"
|
||||
cookie.Acr = "some-acr"
|
||||
if _, ok := c.set["acr"]; !ok {
|
||||
c.set["acr"] = "some-acr"
|
||||
}
|
||||
c.setIfUnset("acr", "some-acr")
|
||||
}
|
||||
|
||||
idToken, err := makeIDToken(c)
|
||||
@@ -319,6 +324,28 @@ type claims struct {
|
||||
remove []string
|
||||
}
|
||||
|
||||
func (in *claims) setIfUnset(key, value string) {
|
||||
if _, ok := in.set[key]; !ok {
|
||||
in.set[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (in *claims) merge(other *claims) {
|
||||
if other == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if other.set != nil {
|
||||
for k, v := range other.set {
|
||||
in.set[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if len(other.remove) > 0 {
|
||||
in.remove = append(in.remove, other.remove...)
|
||||
}
|
||||
}
|
||||
|
||||
func makeIDToken(claims *claims) (*openid.IDToken, error) {
|
||||
iat := time.Now().Truncate(time.Second).UTC()
|
||||
exp := iat.Add(5 * time.Second)
|
||||
|
||||
Reference in New Issue
Block a user