diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index c052a3e..af9980b 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -24,6 +24,10 @@ import ( urlpkg "github.com/nais/wonderwall/pkg/url" ) +func init() { + jwt.Settings(jwt.WithFlattenAudience(true)) +} + var ( ErrOpenIDClient = errors.New("client error") ErrOpenIDServer = errors.New("server error") @@ -141,20 +145,16 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { iat := time.Now().Add(-5 * time.Second).Truncate(time.Second) exp := iat.Add(expiration) - errs := make([]error, 0) - - tok := jwt.New() - errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID())) - errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID())) - errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer())) - errs = append(errs, tok.Set(jwt.IssuedAtKey, iat)) - errs = append(errs, tok.Set(jwt.ExpirationKey, exp)) - errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String())) - - for _, err := range errs { - if err != nil { - return "", fmt.Errorf("setting claim for client assertion: %w", err) - } + tok, err := jwt.NewBuilder(). + Issuer(clientCfg.ClientID()). + Subject(clientCfg.ClientID()). + Audience([]string{providerCfg.Issuer()}). // the aud claim is flattened to a single string value on serialization + IssuedAt(iat). + Expiration(exp). + JwtID(uuid.New().String()). + Build() + if err != nil { + return "", fmt.Errorf("building client assertion: %w", err) } encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key)) diff --git a/pkg/openid/client/client_test.go b/pkg/openid/client/client_test.go index b6828e3..1ff89a6 100644 --- a/pkg/openid/client/client_test.go +++ b/pkg/openid/client/client_test.go @@ -1,6 +1,9 @@ package client_test import ( + "encoding/base64" + "encoding/json" + "strings" "testing" "time" @@ -20,9 +23,11 @@ func TestMakeAssertion(t *testing.T) { c := newTestClientWithConfig(openidConfig) expiry := 30 * time.Second - assertionString, err := c.MakeAssertion(expiry) + jwtAssertion, err := c.MakeAssertion(expiry) assert.NoError(t, err) + assertFlattenedAudience(t, jwtAssertion) + key := openidConfig.Client().ClientJWK() publicKey, err := key.PublicKey() assert.NoError(t, err) @@ -32,8 +37,7 @@ func TestMakeAssertion(t *testing.T) { jwt.WithRequiredClaim(jwt.ExpirationKey), jwt.WithRequiredClaim(jwt.JwtIDKey), } - - assertion, err := jwt.Parse([]byte(assertionString), opts...) + assertion, err := jwt.ParseString(jwtAssertion, opts...) assert.NoError(t, err) assert.ElementsMatch(t, []string{"some-issuer"}, assertion.Audience()) @@ -45,6 +49,22 @@ func TestMakeAssertion(t *testing.T) { assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry))) } +// assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value. +// We do this as the jwx library only exposes the audience as a slice of strings for parsed JWTs. +func assertFlattenedAudience(t *testing.T, jwtAssertion string) { + parts := strings.Split(jwtAssertion, ".") + assert.Len(t, parts, 3) + + rawClaims, err := base64.RawURLEncoding.DecodeString(parts[1]) + assert.NoError(t, err) + + claims := make(map[string]any) + err = json.Unmarshal(rawClaims, &claims) + assert.NoError(t, err) + + assert.Equal(t, "some-issuer", claims["aud"]) +} + func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client { jwksProvider := mock.NewTestJwksProvider() return client.NewClient(config, jwksProvider)