fix(openid/client): flatten audience for client assertion

In accordance with OpenID Connect 1.0 Core, draft 36 incorporating
errata set 3:

> aud
>    REQUIRED. Audience. The aud (audience) Claim. [...] The Audience value MUST be the OP's Issuer Identifier passed as a string, and not a single-element array.
This commit is contained in:
Trong Huu Nguyen
2025-04-02 13:44:37 +02:00
parent 01241f91ac
commit ca8c09ae10
2 changed files with 37 additions and 17 deletions

View File

@@ -24,6 +24,10 @@ import (
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func init() {
jwt.Settings(jwt.WithFlattenAudience(true))
}
var (
ErrOpenIDClient = errors.New("client error")
ErrOpenIDServer = errors.New("server error")
@@ -141,20 +145,16 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
iat := time.Now().Add(-5 * time.Second).Truncate(time.Second)
exp := iat.Add(expiration)
errs := make([]error, 0)
tok := jwt.New()
errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer()))
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))
for _, err := range errs {
if err != nil {
return "", fmt.Errorf("setting claim for client assertion: %w", err)
}
tok, err := jwt.NewBuilder().
Issuer(clientCfg.ClientID()).
Subject(clientCfg.ClientID()).
Audience([]string{providerCfg.Issuer()}). // the aud claim is flattened to a single string value on serialization
IssuedAt(iat).
Expiration(exp).
JwtID(uuid.New().String()).
Build()
if err != nil {
return "", fmt.Errorf("building client assertion: %w", err)
}
encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key))

View File

@@ -1,6 +1,9 @@
package client_test
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
"time"
@@ -20,9 +23,11 @@ func TestMakeAssertion(t *testing.T) {
c := newTestClientWithConfig(openidConfig)
expiry := 30 * time.Second
assertionString, err := c.MakeAssertion(expiry)
jwtAssertion, err := c.MakeAssertion(expiry)
assert.NoError(t, err)
assertFlattenedAudience(t, jwtAssertion)
key := openidConfig.Client().ClientJWK()
publicKey, err := key.PublicKey()
assert.NoError(t, err)
@@ -32,8 +37,7 @@ func TestMakeAssertion(t *testing.T) {
jwt.WithRequiredClaim(jwt.ExpirationKey),
jwt.WithRequiredClaim(jwt.JwtIDKey),
}
assertion, err := jwt.Parse([]byte(assertionString), opts...)
assertion, err := jwt.ParseString(jwtAssertion, opts...)
assert.NoError(t, err)
assert.ElementsMatch(t, []string{"some-issuer"}, assertion.Audience())
@@ -45,6 +49,22 @@ func TestMakeAssertion(t *testing.T) {
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
}
// assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value.
// We do this as the jwx library only exposes the audience as a slice of strings for parsed JWTs.
func assertFlattenedAudience(t *testing.T, jwtAssertion string) {
parts := strings.Split(jwtAssertion, ".")
assert.Len(t, parts, 3)
rawClaims, err := base64.RawURLEncoding.DecodeString(parts[1])
assert.NoError(t, err)
claims := make(map[string]any)
err = json.Unmarshal(rawClaims, &claims)
assert.NoError(t, err)
assert.Equal(t, "some-issuer", claims["aud"])
}
func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client {
jwksProvider := mock.NewTestJwksProvider()
return client.NewClient(config, jwksProvider)