From c5ec362e609c0ecfbf81acace9bf128cffab681b Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 22 May 2025 11:51:06 +0200 Subject: [PATCH] feat(session): update id_token in session if returned from refresh grant Co-authored-by: Thomas Krampl --- pkg/openid/client/client.go | 22 +++- pkg/openid/oauth2.go | 1 + pkg/openid/tokens.go | 112 ++++++++++++++++++-- pkg/openid/tokens_test.go | 183 +++++++++++++++++++++++++++++++-- pkg/session/session_manager.go | 6 +- 5 files changed, 302 insertions(+), 22 deletions(-) diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index 3ffa4e9..f5b7e41 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -12,7 +12,9 @@ import ( "github.com/google/uuid" "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jws" "github.com/lestrrat-go/jwx/v3/jwt" + "github.com/nais/wonderwall/pkg/retry" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" @@ -94,7 +96,7 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A return c.oauth2Config.Exchange(ctx, code, opts...) } -func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) { +func (c *Client) RefreshGrant(ctx context.Context, refreshToken, previousIDToken, expectedAcr string) (*openid.TokenResponse, error) { ctx, span := otel.StartSpan(ctx, "Client.RefreshGrant") defer span.End() clientAuth, err := c.ClientAuthenticationParams() @@ -116,6 +118,24 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid return nil, fmt.Errorf("unmarshalling token response: %w", err) } + // id_tokens may not always be returned from a refresh grant (OpenID Connect Core 12.1) + if tokenResponse.IDToken != "" { + jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx) + if err != nil { + return nil, fmt.Errorf("getting jwks: %w", err) + } + + err = openid.ValidateRefreshedIDToken(c.cfg, previousIDToken, tokenResponse.IDToken, expectedAcr, jwkSet) + if err != nil { + if errors.Is(err, jws.VerificationError()) { + // JWKS might not be up to date, so we'll want to force a refresh for the next attempt + _, _ = c.jwksProvider.RefreshPublicJwkSet(ctx) + return nil, retry.RetryableError(err) + } + return nil, fmt.Errorf("validating refreshed id token: %w", err) + } + } + return &tokenResponse, nil } diff --git a/pkg/openid/oauth2.go b/pkg/openid/oauth2.go index 00f5662..020c754 100644 --- a/pkg/openid/oauth2.go +++ b/pkg/openid/oauth2.go @@ -13,6 +13,7 @@ import ( type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token,omitempty"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` } diff --git a/pkg/openid/tokens.go b/pkg/openid/tokens.go index b2aba3d..e8404d4 100644 --- a/pkg/openid/tokens.go +++ b/pkg/openid/tokens.go @@ -2,6 +2,7 @@ package openid import ( "fmt" + "slices" "strings" "time" @@ -44,7 +45,9 @@ func NewTokens(src *oauth2.Token, jwks *jwk.Set, cfg openidconfig.Config, cookie return nil, fmt.Errorf("parsing id_token: %w", err) } - if err := idToken.Validate(cfg, cookie, jwks); err != nil { + expectedAcr := cookie.Acr + expectedNonce := cookie.Nonce + if err := idToken.Validate(cfg, expectedAcr, expectedNonce, jwks); err != nil { return nil, fmt.Errorf("validating id_token: %w", err) } @@ -84,7 +87,7 @@ type IDToken struct { jwt.Token } -func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *jwk.Set) error { +func (in *IDToken) Validate(cfg openidconfig.Config, expectedAcr, expectedNonce string, jwks *jwk.Set) error { openIDconfig := cfg.Provider() clientConfig := cfg.Client() @@ -107,12 +110,15 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks * // The Client MUST validate that the `aud` (audience) Claim contains its `client_id` value registered at the Issuer identified by the `iss` (issuer) Claim as an audience. // The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience jwt.WithAudience(clientConfig.ClientID()), - // OpenID Connect Core section 3.1.3.7, step 11. - // If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request. - jwt.WithClaimValue("nonce", cookie.Nonce), + // Skew tolerance for time-based claims (exp, iat, nbf) jwt.WithAcceptableSkew(AcceptableSkew), } + if expectedNonce != "" { + // OpenID Connect Core section 3.1.3.7, step 11. + // If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request. + opts = append(opts, jwt.WithClaimValue("nonce", expectedNonce)) + } if openIDconfig.SidClaimRequired() { opts = append(opts, jwt.WithRequiredClaim(SidClaim)) @@ -122,12 +128,8 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks * // If the `acr` Claim was requested, the Client SHOULD check that the asserted Claim Value is appropriate. if len(clientConfig.ACRValues()) > 0 { opts = append(opts, jwt.WithRequiredClaim(AcrClaim)) - - if len(cookie.Acr) > 0 { - actual := in.Acr() - expected := cookie.Acr - - err := acr.Validate(expected, actual) + if expectedAcr != "" { + err := acr.Validate(expectedAcr, in.Acr()) if err != nil { return err } @@ -280,3 +282,91 @@ func (in *IDToken) TimeClaim(claim string) time.Time { // time claims are NumericDate, which is the number of seconds since Epoch. return time.Unix(int64(claimTime), 0) } + +// ValidateRefreshedIDToken validates a refreshed id_token against the previous one, as per OpenID Connect Core, section 12.2 +func ValidateRefreshedIDToken(cfg openidconfig.Config, previous, refreshed, expectedAcr string, jwks *jwk.Set) error { + previousToken, err := ParseIDToken(previous) + if err != nil { + return fmt.Errorf("parsing previous id_token: %w", err) + } + + refreshedToken, err := ParseIDToken(refreshed) + if err != nil { + return fmt.Errorf("parsing current id_token: %w", err) + } + + // its iss Claim Value MUST be the same as in the ID Token issued when the original authentication occurred + previousIssuer, ok := previousToken.Issuer() + if !ok { + return fmt.Errorf("missing required 'iss' claim in previous id_token") + } + refreshedIssuer, ok := refreshedToken.Issuer() + if !ok { + return fmt.Errorf("missing required 'iss' claim in refreshed id_token") + } + if previousIssuer != refreshedIssuer { + return fmt.Errorf("'iss' claim mismatch, expected %q, got %q", previousIssuer, refreshedIssuer) + } + + // its sub Claim Value MUST be the same as in the ID Token issued when the original authentication occurred + previousSubject, ok := previousToken.Subject() + if !ok { + return fmt.Errorf("missing required 'sub' claim in previous id_token") + } + refreshedSubject, ok := refreshedToken.Subject() + if !ok { + return fmt.Errorf("missing required 'sub' claim in refreshed id_token") + } + if previousSubject != refreshedSubject { + return fmt.Errorf("'sub' claim mismatch, expected %q, got %q", previousSubject, refreshedSubject) + } + + // its iat Claim MUST represent the time that the new ID Token is issued + previousIat, ok := previousToken.IssuedAt() + if !ok { + return fmt.Errorf("missing required 'iat' claim in previous id_token") + } + refreshedIat, ok := refreshedToken.IssuedAt() + if !ok { + return fmt.Errorf("missing required 'iat' claim in refreshed id_token") + } + if refreshedIat.Equal(previousIat) || refreshedIat.Before(previousIat) { + return fmt.Errorf("'iat' claim in refreshed id_token must be greater than previous id_token, expected > %q, got %q", previousIat, refreshedIat) + } + + // its aud Claim Value MUST be the same as in the ID Token issued when the original authentication occurred + previousAudience, ok := previousToken.Audience() + if !ok { + return fmt.Errorf("missing required 'aud' claim in previous id_token") + } + refreshedAudience, ok := refreshedToken.Audience() + if !ok { + return fmt.Errorf("missing required 'aud' claim in refreshed id_token") + } + slices.Sort(previousAudience) + slices.Sort(refreshedAudience) + if !slices.Equal(previousAudience, refreshedAudience) { + return fmt.Errorf("'aud' claim mismatch, expected %q, got %q", previousAudience, refreshedAudience) + } + + // if the ID Token contains an auth_time Claim, its value MUST represent the time of the original authentication - not the time that the new ID token is issued + if refreshedAuthTime := refreshedToken.AuthTime(); !refreshedAuthTime.IsZero() { + previousAuthTime := previousToken.AuthTime() + if !refreshedAuthTime.Equal(previousAuthTime) { + return fmt.Errorf("'auth_time' claim mismatch, expected %q, got %q", previousAuthTime, refreshedAuthTime) + } + } + + // it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of the original authentication contained nonce; + // however, if it is present, its value MUST be the same as in the ID Token issued at the time of the original authentication + refreshedNonce := refreshedToken.StringClaimOrEmpty("nonce") + if refreshedNonce != "" { + previousNonce := previousToken.StringClaimOrEmpty("nonce") + if previousNonce != refreshedNonce { + return fmt.Errorf("'nonce' claim mismatch, expected %q, got %q", previousNonce, refreshedNonce) + } + } + + // otherwise, the same rules apply as apply when issuing an ID Token at the time of the original authentication + return refreshedToken.Validate(cfg, expectedAcr, refreshedNonce, jwks) +} diff --git a/pkg/openid/tokens_test.go b/pkg/openid/tokens_test.go index d4b5462..77e92d4 100644 --- a/pkg/openid/tokens_test.go +++ b/pkg/openid/tokens_test.go @@ -188,12 +188,6 @@ func TestIDToken_Validate(t *testing.T) { } } - defaultCookie := func() *openid.LoginCookie { - return &openid.LoginCookie{ - Nonce: "some-nonce", - } - } - for _, tt := range []struct { name string claims *claims @@ -364,7 +358,8 @@ func TestIDToken_Validate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfg := defaultConfig() openidcfg := defaultOpenIdConfig(cfg) - cookie := defaultCookie() + expectedNonce := "some-nonce" + expectedAcr := "" c := defaultClaims(openidcfg) c.merge(tt.claims) @@ -376,14 +371,184 @@ func TestIDToken_Validate(t *testing.T) { if tt.requireAcr { cfg.OpenID.ACRValues = "some-acr" - cookie.Acr = "some-acr" + expectedAcr = "some-acr" c.setIfUnset("acr", "some-acr") } idToken, err := makeIDToken(c) require.NoError(t, err) - err = idToken.Validate(openidcfg, cookie, &jwks.Public) + err = idToken.Validate(openidcfg, expectedAcr, expectedNonce, &jwks.Public) + if tt.expectErr != "" { + assert.ErrorContains(t, err, tt.expectErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateRefreshedIDToken(t *testing.T) { + for _, tt := range []struct { + name string + previous *claims + refreshed *claims + requireAcr bool + expectErr string + }{ + { + name: "happy path", + }, + { + name: "issuer mismatch", + refreshed: &claims{ + set: map[string]any{ + "iss": "https://some-other-issuer", + }, + }, + expectErr: `'iss' claim mismatch`, + }, + { + name: "subject mismatch", + refreshed: &claims{ + set: map[string]any{ + "sub": "some-other-sub", + }, + }, + expectErr: `'sub' claim mismatch`, + }, + { + name: "iat unchanged", + previous: &claims{ + set: map[string]any{ + "iat": time.Now().Unix(), + }, + }, + refreshed: &claims{ + set: map[string]any{ + "iat": time.Now().Unix(), + }, + }, + expectErr: "'iat' claim in refreshed id_token must be greater than previous id_token", + }, + { + name: "audience mismatch", + refreshed: &claims{ + set: map[string]any{ + "aud": []string{"some-client id", "trusted-id-1"}, + }, + }, + expectErr: `'aud' claim mismatch`, + }, + { + name: "auth_time mismatch", + previous: &claims{ + set: map[string]any{ + "auth_time": time.Now().Unix(), + }, + }, + refreshed: &claims{ + set: map[string]any{ + "auth_time": time.Now().Add(5 * time.Second).Unix(), + }, + }, + expectErr: "'auth_time' claim mismatch", + }, + { + name: "nonce mismatch", + previous: &claims{ + set: map[string]any{ + "nonce": "some-nonce", + }, + }, + refreshed: &claims{ + set: map[string]any{ + "nonce": "some-other-nonce", + }, + }, + expectErr: "'nonce' claim mismatch", + }, + { + name: "acr mismatch", + previous: &claims{ + set: map[string]any{ + "acr": "some-acr", + }, + }, + refreshed: &claims{ + set: map[string]any{ + "acr": "some-other-acr", + }, + }, + requireAcr: true, + expectErr: `invalid acr: got "some-other-acr", expected "some-acr"`, + }, + { + name: "iat is in the future", + refreshed: &claims{ + set: map[string]any{ + "iat": time.Now().Add(openid.AcceptableSkew + 5*time.Second).Unix(), + }, + }, + expectErr: `"iat" not satisfied`, + }, + { + name: "exp is in the past", + refreshed: &claims{ + set: map[string]any{ + "exp": time.Now().Add(-openid.AcceptableSkew - 5*time.Second).Unix(), + }, + }, + expectErr: `"exp" not satisfied`, + }, + } { + t.Run(tt.name, func(t *testing.T) { + cfg := mock.Config() + cfg.OpenID.ACRValues = "" + cfg.OpenID.ClientID = "some-client-id" + cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"} + + openidcfg := mock.NewTestConfiguration(cfg) + openidcfg.TestProvider.SetIssuer("https://some-issuer") + + previous := &claims{ + set: map[string]any{ + "aud": openidcfg.Client().ClientID(), + "iss": openidcfg.Provider().Issuer(), + "sub": "some-sub", + }, + } + previous.merge(tt.previous) + previousIDToken, err := makeIDToken(previous) + require.NoError(t, err) + + previousIssuedAt, ok := previousIDToken.IssuedAt() + require.True(t, ok) + previousExpiry, ok := previousIDToken.Expiration() + require.True(t, ok) + refreshed := &claims{ + set: map[string]any{ + "aud": openidcfg.Client().ClientID(), + "iss": openidcfg.Provider().Issuer(), + "sub": "some-sub", + "iat": previousIssuedAt.Add(5 * time.Second).Unix(), + "exp": previousExpiry.Add(5 * time.Second).Unix(), + }, + } + refreshed.merge(tt.refreshed) + refreshedIDToken, err := makeIDToken(refreshed) + require.NoError(t, err) + + expectedAcr := "" + if tt.requireAcr { + cfg.OpenID.ACRValues = "some-acr" + expectedAcr = "some-acr" + } + err = openid.ValidateRefreshedIDToken(openidcfg, + previousIDToken.Serialized(), + refreshedIDToken.Serialized(), + expectedAcr, + &jwks.Public) if tt.expectErr != "" { assert.ErrorContains(t, err, tt.expectErr) } else { diff --git a/pkg/session/session_manager.go b/pkg/session/session_manager.go index 8c5ed0c..dd3af6a 100644 --- a/pkg/session/session_manager.go +++ b/pkg/session/session_manager.go @@ -210,7 +210,7 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) { logger.Debug("session: performing refresh grant...") resp, err := retry.DoValue(ctx, func(ctx context.Context) (*openid.TokenResponse, error) { - resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken) + resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken, sess.data.IDToken, sess.data.Acr) if errors.Is(err, openidclient.ErrOpenIDServer) { return nil, retry.RetryableError(err) } @@ -226,6 +226,10 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) { return nil, fmt.Errorf("performing refresh: %w", err) } + // id_tokens may not always be returned from a refresh grant (OpenID Connect 12.1) + if resp.IDToken != "" { + sess.data.IDToken = resp.IDToken + } sess.data.AccessToken = resp.AccessToken sess.data.RefreshToken = resp.RefreshToken sess.data.Metadata.Refresh(resp.ExpiresIn)