feat(session): update id_token in session if returned from refresh grant

Co-authored-by: Thomas Krampl <thomas.siegfried.krampl@nav.no>
This commit is contained in:
Trong Huu Nguyen
2025-05-22 11:51:06 +02:00
parent 192cd86022
commit c5ec362e60
5 changed files with 302 additions and 22 deletions

View File

@@ -12,7 +12,9 @@ import (
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/nais/wonderwall/pkg/retry"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
@@ -94,7 +96,7 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
return c.oauth2Config.Exchange(ctx, code, opts...)
}
func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) {
func (c *Client) RefreshGrant(ctx context.Context, refreshToken, previousIDToken, expectedAcr string) (*openid.TokenResponse, error) {
ctx, span := otel.StartSpan(ctx, "Client.RefreshGrant")
defer span.End()
clientAuth, err := c.ClientAuthenticationParams()
@@ -116,6 +118,24 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
return nil, fmt.Errorf("unmarshalling token response: %w", err)
}
// id_tokens may not always be returned from a refresh grant (OpenID Connect Core 12.1)
if tokenResponse.IDToken != "" {
jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx)
if err != nil {
return nil, fmt.Errorf("getting jwks: %w", err)
}
err = openid.ValidateRefreshedIDToken(c.cfg, previousIDToken, tokenResponse.IDToken, expectedAcr, jwkSet)
if err != nil {
if errors.Is(err, jws.VerificationError()) {
// JWKS might not be up to date, so we'll want to force a refresh for the next attempt
_, _ = c.jwksProvider.RefreshPublicJwkSet(ctx)
return nil, retry.RetryableError(err)
}
return nil, fmt.Errorf("validating refreshed id token: %w", err)
}
}
return &tokenResponse, nil
}

View File

@@ -13,6 +13,7 @@ import (
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token,omitempty"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
}

View File

@@ -2,6 +2,7 @@ package openid
import (
"fmt"
"slices"
"strings"
"time"
@@ -44,7 +45,9 @@ func NewTokens(src *oauth2.Token, jwks *jwk.Set, cfg openidconfig.Config, cookie
return nil, fmt.Errorf("parsing id_token: %w", err)
}
if err := idToken.Validate(cfg, cookie, jwks); err != nil {
expectedAcr := cookie.Acr
expectedNonce := cookie.Nonce
if err := idToken.Validate(cfg, expectedAcr, expectedNonce, jwks); err != nil {
return nil, fmt.Errorf("validating id_token: %w", err)
}
@@ -84,7 +87,7 @@ type IDToken struct {
jwt.Token
}
func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *jwk.Set) error {
func (in *IDToken) Validate(cfg openidconfig.Config, expectedAcr, expectedNonce string, jwks *jwk.Set) error {
openIDconfig := cfg.Provider()
clientConfig := cfg.Client()
@@ -107,12 +110,15 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *
// The Client MUST validate that the `aud` (audience) Claim contains its `client_id` value registered at the Issuer identified by the `iss` (issuer) Claim as an audience.
// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience
jwt.WithAudience(clientConfig.ClientID()),
// OpenID Connect Core section 3.1.3.7, step 11.
// If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request.
jwt.WithClaimValue("nonce", cookie.Nonce),
// Skew tolerance for time-based claims (exp, iat, nbf)
jwt.WithAcceptableSkew(AcceptableSkew),
}
if expectedNonce != "" {
// OpenID Connect Core section 3.1.3.7, step 11.
// If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request.
opts = append(opts, jwt.WithClaimValue("nonce", expectedNonce))
}
if openIDconfig.SidClaimRequired() {
opts = append(opts, jwt.WithRequiredClaim(SidClaim))
@@ -122,12 +128,8 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *
// If the `acr` Claim was requested, the Client SHOULD check that the asserted Claim Value is appropriate.
if len(clientConfig.ACRValues()) > 0 {
opts = append(opts, jwt.WithRequiredClaim(AcrClaim))
if len(cookie.Acr) > 0 {
actual := in.Acr()
expected := cookie.Acr
err := acr.Validate(expected, actual)
if expectedAcr != "" {
err := acr.Validate(expectedAcr, in.Acr())
if err != nil {
return err
}
@@ -280,3 +282,91 @@ func (in *IDToken) TimeClaim(claim string) time.Time {
// time claims are NumericDate, which is the number of seconds since Epoch.
return time.Unix(int64(claimTime), 0)
}
// ValidateRefreshedIDToken validates a refreshed id_token against the previous one, as per OpenID Connect Core, section 12.2
func ValidateRefreshedIDToken(cfg openidconfig.Config, previous, refreshed, expectedAcr string, jwks *jwk.Set) error {
previousToken, err := ParseIDToken(previous)
if err != nil {
return fmt.Errorf("parsing previous id_token: %w", err)
}
refreshedToken, err := ParseIDToken(refreshed)
if err != nil {
return fmt.Errorf("parsing current id_token: %w", err)
}
// its iss Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
previousIssuer, ok := previousToken.Issuer()
if !ok {
return fmt.Errorf("missing required 'iss' claim in previous id_token")
}
refreshedIssuer, ok := refreshedToken.Issuer()
if !ok {
return fmt.Errorf("missing required 'iss' claim in refreshed id_token")
}
if previousIssuer != refreshedIssuer {
return fmt.Errorf("'iss' claim mismatch, expected %q, got %q", previousIssuer, refreshedIssuer)
}
// its sub Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
previousSubject, ok := previousToken.Subject()
if !ok {
return fmt.Errorf("missing required 'sub' claim in previous id_token")
}
refreshedSubject, ok := refreshedToken.Subject()
if !ok {
return fmt.Errorf("missing required 'sub' claim in refreshed id_token")
}
if previousSubject != refreshedSubject {
return fmt.Errorf("'sub' claim mismatch, expected %q, got %q", previousSubject, refreshedSubject)
}
// its iat Claim MUST represent the time that the new ID Token is issued
previousIat, ok := previousToken.IssuedAt()
if !ok {
return fmt.Errorf("missing required 'iat' claim in previous id_token")
}
refreshedIat, ok := refreshedToken.IssuedAt()
if !ok {
return fmt.Errorf("missing required 'iat' claim in refreshed id_token")
}
if refreshedIat.Equal(previousIat) || refreshedIat.Before(previousIat) {
return fmt.Errorf("'iat' claim in refreshed id_token must be greater than previous id_token, expected > %q, got %q", previousIat, refreshedIat)
}
// its aud Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
previousAudience, ok := previousToken.Audience()
if !ok {
return fmt.Errorf("missing required 'aud' claim in previous id_token")
}
refreshedAudience, ok := refreshedToken.Audience()
if !ok {
return fmt.Errorf("missing required 'aud' claim in refreshed id_token")
}
slices.Sort(previousAudience)
slices.Sort(refreshedAudience)
if !slices.Equal(previousAudience, refreshedAudience) {
return fmt.Errorf("'aud' claim mismatch, expected %q, got %q", previousAudience, refreshedAudience)
}
// if the ID Token contains an auth_time Claim, its value MUST represent the time of the original authentication - not the time that the new ID token is issued
if refreshedAuthTime := refreshedToken.AuthTime(); !refreshedAuthTime.IsZero() {
previousAuthTime := previousToken.AuthTime()
if !refreshedAuthTime.Equal(previousAuthTime) {
return fmt.Errorf("'auth_time' claim mismatch, expected %q, got %q", previousAuthTime, refreshedAuthTime)
}
}
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of the original authentication contained nonce;
// however, if it is present, its value MUST be the same as in the ID Token issued at the time of the original authentication
refreshedNonce := refreshedToken.StringClaimOrEmpty("nonce")
if refreshedNonce != "" {
previousNonce := previousToken.StringClaimOrEmpty("nonce")
if previousNonce != refreshedNonce {
return fmt.Errorf("'nonce' claim mismatch, expected %q, got %q", previousNonce, refreshedNonce)
}
}
// otherwise, the same rules apply as apply when issuing an ID Token at the time of the original authentication
return refreshedToken.Validate(cfg, expectedAcr, refreshedNonce, jwks)
}

View File

@@ -188,12 +188,6 @@ func TestIDToken_Validate(t *testing.T) {
}
}
defaultCookie := func() *openid.LoginCookie {
return &openid.LoginCookie{
Nonce: "some-nonce",
}
}
for _, tt := range []struct {
name string
claims *claims
@@ -364,7 +358,8 @@ func TestIDToken_Validate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
cfg := defaultConfig()
openidcfg := defaultOpenIdConfig(cfg)
cookie := defaultCookie()
expectedNonce := "some-nonce"
expectedAcr := ""
c := defaultClaims(openidcfg)
c.merge(tt.claims)
@@ -376,14 +371,184 @@ func TestIDToken_Validate(t *testing.T) {
if tt.requireAcr {
cfg.OpenID.ACRValues = "some-acr"
cookie.Acr = "some-acr"
expectedAcr = "some-acr"
c.setIfUnset("acr", "some-acr")
}
idToken, err := makeIDToken(c)
require.NoError(t, err)
err = idToken.Validate(openidcfg, cookie, &jwks.Public)
err = idToken.Validate(openidcfg, expectedAcr, expectedNonce, &jwks.Public)
if tt.expectErr != "" {
assert.ErrorContains(t, err, tt.expectErr)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateRefreshedIDToken(t *testing.T) {
for _, tt := range []struct {
name string
previous *claims
refreshed *claims
requireAcr bool
expectErr string
}{
{
name: "happy path",
},
{
name: "issuer mismatch",
refreshed: &claims{
set: map[string]any{
"iss": "https://some-other-issuer",
},
},
expectErr: `'iss' claim mismatch`,
},
{
name: "subject mismatch",
refreshed: &claims{
set: map[string]any{
"sub": "some-other-sub",
},
},
expectErr: `'sub' claim mismatch`,
},
{
name: "iat unchanged",
previous: &claims{
set: map[string]any{
"iat": time.Now().Unix(),
},
},
refreshed: &claims{
set: map[string]any{
"iat": time.Now().Unix(),
},
},
expectErr: "'iat' claim in refreshed id_token must be greater than previous id_token",
},
{
name: "audience mismatch",
refreshed: &claims{
set: map[string]any{
"aud": []string{"some-client id", "trusted-id-1"},
},
},
expectErr: `'aud' claim mismatch`,
},
{
name: "auth_time mismatch",
previous: &claims{
set: map[string]any{
"auth_time": time.Now().Unix(),
},
},
refreshed: &claims{
set: map[string]any{
"auth_time": time.Now().Add(5 * time.Second).Unix(),
},
},
expectErr: "'auth_time' claim mismatch",
},
{
name: "nonce mismatch",
previous: &claims{
set: map[string]any{
"nonce": "some-nonce",
},
},
refreshed: &claims{
set: map[string]any{
"nonce": "some-other-nonce",
},
},
expectErr: "'nonce' claim mismatch",
},
{
name: "acr mismatch",
previous: &claims{
set: map[string]any{
"acr": "some-acr",
},
},
refreshed: &claims{
set: map[string]any{
"acr": "some-other-acr",
},
},
requireAcr: true,
expectErr: `invalid acr: got "some-other-acr", expected "some-acr"`,
},
{
name: "iat is in the future",
refreshed: &claims{
set: map[string]any{
"iat": time.Now().Add(openid.AcceptableSkew + 5*time.Second).Unix(),
},
},
expectErr: `"iat" not satisfied`,
},
{
name: "exp is in the past",
refreshed: &claims{
set: map[string]any{
"exp": time.Now().Add(-openid.AcceptableSkew - 5*time.Second).Unix(),
},
},
expectErr: `"exp" not satisfied`,
},
} {
t.Run(tt.name, func(t *testing.T) {
cfg := mock.Config()
cfg.OpenID.ACRValues = ""
cfg.OpenID.ClientID = "some-client-id"
cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"}
openidcfg := mock.NewTestConfiguration(cfg)
openidcfg.TestProvider.SetIssuer("https://some-issuer")
previous := &claims{
set: map[string]any{
"aud": openidcfg.Client().ClientID(),
"iss": openidcfg.Provider().Issuer(),
"sub": "some-sub",
},
}
previous.merge(tt.previous)
previousIDToken, err := makeIDToken(previous)
require.NoError(t, err)
previousIssuedAt, ok := previousIDToken.IssuedAt()
require.True(t, ok)
previousExpiry, ok := previousIDToken.Expiration()
require.True(t, ok)
refreshed := &claims{
set: map[string]any{
"aud": openidcfg.Client().ClientID(),
"iss": openidcfg.Provider().Issuer(),
"sub": "some-sub",
"iat": previousIssuedAt.Add(5 * time.Second).Unix(),
"exp": previousExpiry.Add(5 * time.Second).Unix(),
},
}
refreshed.merge(tt.refreshed)
refreshedIDToken, err := makeIDToken(refreshed)
require.NoError(t, err)
expectedAcr := ""
if tt.requireAcr {
cfg.OpenID.ACRValues = "some-acr"
expectedAcr = "some-acr"
}
err = openid.ValidateRefreshedIDToken(openidcfg,
previousIDToken.Serialized(),
refreshedIDToken.Serialized(),
expectedAcr,
&jwks.Public)
if tt.expectErr != "" {
assert.ErrorContains(t, err, tt.expectErr)
} else {

View File

@@ -210,7 +210,7 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
logger.Debug("session: performing refresh grant...")
resp, err := retry.DoValue(ctx, func(ctx context.Context) (*openid.TokenResponse, error) {
resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken)
resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken, sess.data.IDToken, sess.data.Acr)
if errors.Is(err, openidclient.ErrOpenIDServer) {
return nil, retry.RetryableError(err)
}
@@ -226,6 +226,10 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
return nil, fmt.Errorf("performing refresh: %w", err)
}
// id_tokens may not always be returned from a refresh grant (OpenID Connect 12.1)
if resp.IDToken != "" {
sess.data.IDToken = resp.IDToken
}
sess.data.AccessToken = resp.AccessToken
sess.data.RefreshToken = resp.RefreshToken
sess.data.Metadata.Refresh(resp.ExpiresIn)