mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 08:27:10 +00:00
feat(session): update id_token in session if returned from refresh grant
Co-authored-by: Thomas Krampl <thomas.siegfried.krampl@nav.no>
This commit is contained in:
@@ -12,7 +12,9 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/nais/wonderwall/pkg/retry"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
@@ -94,7 +96,7 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
|
||||
return c.oauth2Config.Exchange(ctx, code, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) {
|
||||
func (c *Client) RefreshGrant(ctx context.Context, refreshToken, previousIDToken, expectedAcr string) (*openid.TokenResponse, error) {
|
||||
ctx, span := otel.StartSpan(ctx, "Client.RefreshGrant")
|
||||
defer span.End()
|
||||
clientAuth, err := c.ClientAuthenticationParams()
|
||||
@@ -116,6 +118,24 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
|
||||
return nil, fmt.Errorf("unmarshalling token response: %w", err)
|
||||
}
|
||||
|
||||
// id_tokens may not always be returned from a refresh grant (OpenID Connect Core 12.1)
|
||||
if tokenResponse.IDToken != "" {
|
||||
jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting jwks: %w", err)
|
||||
}
|
||||
|
||||
err = openid.ValidateRefreshedIDToken(c.cfg, previousIDToken, tokenResponse.IDToken, expectedAcr, jwkSet)
|
||||
if err != nil {
|
||||
if errors.Is(err, jws.VerificationError()) {
|
||||
// JWKS might not be up to date, so we'll want to force a refresh for the next attempt
|
||||
_, _ = c.jwksProvider.RefreshPublicJwkSet(ctx)
|
||||
return nil, retry.RetryableError(err)
|
||||
}
|
||||
return nil, fmt.Errorf("validating refreshed id token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openid
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -44,7 +45,9 @@ func NewTokens(src *oauth2.Token, jwks *jwk.Set, cfg openidconfig.Config, cookie
|
||||
return nil, fmt.Errorf("parsing id_token: %w", err)
|
||||
}
|
||||
|
||||
if err := idToken.Validate(cfg, cookie, jwks); err != nil {
|
||||
expectedAcr := cookie.Acr
|
||||
expectedNonce := cookie.Nonce
|
||||
if err := idToken.Validate(cfg, expectedAcr, expectedNonce, jwks); err != nil {
|
||||
return nil, fmt.Errorf("validating id_token: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,7 +87,7 @@ type IDToken struct {
|
||||
jwt.Token
|
||||
}
|
||||
|
||||
func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *jwk.Set) error {
|
||||
func (in *IDToken) Validate(cfg openidconfig.Config, expectedAcr, expectedNonce string, jwks *jwk.Set) error {
|
||||
openIDconfig := cfg.Provider()
|
||||
clientConfig := cfg.Client()
|
||||
|
||||
@@ -107,12 +110,15 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *
|
||||
// The Client MUST validate that the `aud` (audience) Claim contains its `client_id` value registered at the Issuer identified by the `iss` (issuer) Claim as an audience.
|
||||
// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience
|
||||
jwt.WithAudience(clientConfig.ClientID()),
|
||||
// OpenID Connect Core section 3.1.3.7, step 11.
|
||||
// If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request.
|
||||
jwt.WithClaimValue("nonce", cookie.Nonce),
|
||||
|
||||
// Skew tolerance for time-based claims (exp, iat, nbf)
|
||||
jwt.WithAcceptableSkew(AcceptableSkew),
|
||||
}
|
||||
if expectedNonce != "" {
|
||||
// OpenID Connect Core section 3.1.3.7, step 11.
|
||||
// If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request.
|
||||
opts = append(opts, jwt.WithClaimValue("nonce", expectedNonce))
|
||||
}
|
||||
|
||||
if openIDconfig.SidClaimRequired() {
|
||||
opts = append(opts, jwt.WithRequiredClaim(SidClaim))
|
||||
@@ -122,12 +128,8 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *
|
||||
// If the `acr` Claim was requested, the Client SHOULD check that the asserted Claim Value is appropriate.
|
||||
if len(clientConfig.ACRValues()) > 0 {
|
||||
opts = append(opts, jwt.WithRequiredClaim(AcrClaim))
|
||||
|
||||
if len(cookie.Acr) > 0 {
|
||||
actual := in.Acr()
|
||||
expected := cookie.Acr
|
||||
|
||||
err := acr.Validate(expected, actual)
|
||||
if expectedAcr != "" {
|
||||
err := acr.Validate(expectedAcr, in.Acr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -280,3 +282,91 @@ func (in *IDToken) TimeClaim(claim string) time.Time {
|
||||
// time claims are NumericDate, which is the number of seconds since Epoch.
|
||||
return time.Unix(int64(claimTime), 0)
|
||||
}
|
||||
|
||||
// ValidateRefreshedIDToken validates a refreshed id_token against the previous one, as per OpenID Connect Core, section 12.2
|
||||
func ValidateRefreshedIDToken(cfg openidconfig.Config, previous, refreshed, expectedAcr string, jwks *jwk.Set) error {
|
||||
previousToken, err := ParseIDToken(previous)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing previous id_token: %w", err)
|
||||
}
|
||||
|
||||
refreshedToken, err := ParseIDToken(refreshed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing current id_token: %w", err)
|
||||
}
|
||||
|
||||
// its iss Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
|
||||
previousIssuer, ok := previousToken.Issuer()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'iss' claim in previous id_token")
|
||||
}
|
||||
refreshedIssuer, ok := refreshedToken.Issuer()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'iss' claim in refreshed id_token")
|
||||
}
|
||||
if previousIssuer != refreshedIssuer {
|
||||
return fmt.Errorf("'iss' claim mismatch, expected %q, got %q", previousIssuer, refreshedIssuer)
|
||||
}
|
||||
|
||||
// its sub Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
|
||||
previousSubject, ok := previousToken.Subject()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'sub' claim in previous id_token")
|
||||
}
|
||||
refreshedSubject, ok := refreshedToken.Subject()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'sub' claim in refreshed id_token")
|
||||
}
|
||||
if previousSubject != refreshedSubject {
|
||||
return fmt.Errorf("'sub' claim mismatch, expected %q, got %q", previousSubject, refreshedSubject)
|
||||
}
|
||||
|
||||
// its iat Claim MUST represent the time that the new ID Token is issued
|
||||
previousIat, ok := previousToken.IssuedAt()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'iat' claim in previous id_token")
|
||||
}
|
||||
refreshedIat, ok := refreshedToken.IssuedAt()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'iat' claim in refreshed id_token")
|
||||
}
|
||||
if refreshedIat.Equal(previousIat) || refreshedIat.Before(previousIat) {
|
||||
return fmt.Errorf("'iat' claim in refreshed id_token must be greater than previous id_token, expected > %q, got %q", previousIat, refreshedIat)
|
||||
}
|
||||
|
||||
// its aud Claim Value MUST be the same as in the ID Token issued when the original authentication occurred
|
||||
previousAudience, ok := previousToken.Audience()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'aud' claim in previous id_token")
|
||||
}
|
||||
refreshedAudience, ok := refreshedToken.Audience()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'aud' claim in refreshed id_token")
|
||||
}
|
||||
slices.Sort(previousAudience)
|
||||
slices.Sort(refreshedAudience)
|
||||
if !slices.Equal(previousAudience, refreshedAudience) {
|
||||
return fmt.Errorf("'aud' claim mismatch, expected %q, got %q", previousAudience, refreshedAudience)
|
||||
}
|
||||
|
||||
// if the ID Token contains an auth_time Claim, its value MUST represent the time of the original authentication - not the time that the new ID token is issued
|
||||
if refreshedAuthTime := refreshedToken.AuthTime(); !refreshedAuthTime.IsZero() {
|
||||
previousAuthTime := previousToken.AuthTime()
|
||||
if !refreshedAuthTime.Equal(previousAuthTime) {
|
||||
return fmt.Errorf("'auth_time' claim mismatch, expected %q, got %q", previousAuthTime, refreshedAuthTime)
|
||||
}
|
||||
}
|
||||
|
||||
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of the original authentication contained nonce;
|
||||
// however, if it is present, its value MUST be the same as in the ID Token issued at the time of the original authentication
|
||||
refreshedNonce := refreshedToken.StringClaimOrEmpty("nonce")
|
||||
if refreshedNonce != "" {
|
||||
previousNonce := previousToken.StringClaimOrEmpty("nonce")
|
||||
if previousNonce != refreshedNonce {
|
||||
return fmt.Errorf("'nonce' claim mismatch, expected %q, got %q", previousNonce, refreshedNonce)
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise, the same rules apply as apply when issuing an ID Token at the time of the original authentication
|
||||
return refreshedToken.Validate(cfg, expectedAcr, refreshedNonce, jwks)
|
||||
}
|
||||
|
||||
@@ -188,12 +188,6 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
defaultCookie := func() *openid.LoginCookie {
|
||||
return &openid.LoginCookie{
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
claims *claims
|
||||
@@ -364,7 +358,8 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
openidcfg := defaultOpenIdConfig(cfg)
|
||||
cookie := defaultCookie()
|
||||
expectedNonce := "some-nonce"
|
||||
expectedAcr := ""
|
||||
|
||||
c := defaultClaims(openidcfg)
|
||||
c.merge(tt.claims)
|
||||
@@ -376,14 +371,184 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
|
||||
if tt.requireAcr {
|
||||
cfg.OpenID.ACRValues = "some-acr"
|
||||
cookie.Acr = "some-acr"
|
||||
expectedAcr = "some-acr"
|
||||
c.setIfUnset("acr", "some-acr")
|
||||
}
|
||||
|
||||
idToken, err := makeIDToken(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = idToken.Validate(openidcfg, cookie, &jwks.Public)
|
||||
err = idToken.Validate(openidcfg, expectedAcr, expectedNonce, &jwks.Public)
|
||||
if tt.expectErr != "" {
|
||||
assert.ErrorContains(t, err, tt.expectErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRefreshedIDToken(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
previous *claims
|
||||
refreshed *claims
|
||||
requireAcr bool
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
},
|
||||
{
|
||||
name: "issuer mismatch",
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"iss": "https://some-other-issuer",
|
||||
},
|
||||
},
|
||||
expectErr: `'iss' claim mismatch`,
|
||||
},
|
||||
{
|
||||
name: "subject mismatch",
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"sub": "some-other-sub",
|
||||
},
|
||||
},
|
||||
expectErr: `'sub' claim mismatch`,
|
||||
},
|
||||
{
|
||||
name: "iat unchanged",
|
||||
previous: &claims{
|
||||
set: map[string]any{
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
expectErr: "'iat' claim in refreshed id_token must be greater than previous id_token",
|
||||
},
|
||||
{
|
||||
name: "audience mismatch",
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"aud": []string{"some-client id", "trusted-id-1"},
|
||||
},
|
||||
},
|
||||
expectErr: `'aud' claim mismatch`,
|
||||
},
|
||||
{
|
||||
name: "auth_time mismatch",
|
||||
previous: &claims{
|
||||
set: map[string]any{
|
||||
"auth_time": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"auth_time": time.Now().Add(5 * time.Second).Unix(),
|
||||
},
|
||||
},
|
||||
expectErr: "'auth_time' claim mismatch",
|
||||
},
|
||||
{
|
||||
name: "nonce mismatch",
|
||||
previous: &claims{
|
||||
set: map[string]any{
|
||||
"nonce": "some-nonce",
|
||||
},
|
||||
},
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"nonce": "some-other-nonce",
|
||||
},
|
||||
},
|
||||
expectErr: "'nonce' claim mismatch",
|
||||
},
|
||||
{
|
||||
name: "acr mismatch",
|
||||
previous: &claims{
|
||||
set: map[string]any{
|
||||
"acr": "some-acr",
|
||||
},
|
||||
},
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"acr": "some-other-acr",
|
||||
},
|
||||
},
|
||||
requireAcr: true,
|
||||
expectErr: `invalid acr: got "some-other-acr", expected "some-acr"`,
|
||||
},
|
||||
{
|
||||
name: "iat is in the future",
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"iat": time.Now().Add(openid.AcceptableSkew + 5*time.Second).Unix(),
|
||||
},
|
||||
},
|
||||
expectErr: `"iat" not satisfied`,
|
||||
},
|
||||
{
|
||||
name: "exp is in the past",
|
||||
refreshed: &claims{
|
||||
set: map[string]any{
|
||||
"exp": time.Now().Add(-openid.AcceptableSkew - 5*time.Second).Unix(),
|
||||
},
|
||||
},
|
||||
expectErr: `"exp" not satisfied`,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.ACRValues = ""
|
||||
cfg.OpenID.ClientID = "some-client-id"
|
||||
cfg.OpenID.Audiences = []string{"trusted-id-1", "trusted-id-2"}
|
||||
|
||||
openidcfg := mock.NewTestConfiguration(cfg)
|
||||
openidcfg.TestProvider.SetIssuer("https://some-issuer")
|
||||
|
||||
previous := &claims{
|
||||
set: map[string]any{
|
||||
"aud": openidcfg.Client().ClientID(),
|
||||
"iss": openidcfg.Provider().Issuer(),
|
||||
"sub": "some-sub",
|
||||
},
|
||||
}
|
||||
previous.merge(tt.previous)
|
||||
previousIDToken, err := makeIDToken(previous)
|
||||
require.NoError(t, err)
|
||||
|
||||
previousIssuedAt, ok := previousIDToken.IssuedAt()
|
||||
require.True(t, ok)
|
||||
previousExpiry, ok := previousIDToken.Expiration()
|
||||
require.True(t, ok)
|
||||
refreshed := &claims{
|
||||
set: map[string]any{
|
||||
"aud": openidcfg.Client().ClientID(),
|
||||
"iss": openidcfg.Provider().Issuer(),
|
||||
"sub": "some-sub",
|
||||
"iat": previousIssuedAt.Add(5 * time.Second).Unix(),
|
||||
"exp": previousExpiry.Add(5 * time.Second).Unix(),
|
||||
},
|
||||
}
|
||||
refreshed.merge(tt.refreshed)
|
||||
refreshedIDToken, err := makeIDToken(refreshed)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedAcr := ""
|
||||
if tt.requireAcr {
|
||||
cfg.OpenID.ACRValues = "some-acr"
|
||||
expectedAcr = "some-acr"
|
||||
}
|
||||
err = openid.ValidateRefreshedIDToken(openidcfg,
|
||||
previousIDToken.Serialized(),
|
||||
refreshedIDToken.Serialized(),
|
||||
expectedAcr,
|
||||
&jwks.Public)
|
||||
if tt.expectErr != "" {
|
||||
assert.ErrorContains(t, err, tt.expectErr)
|
||||
} else {
|
||||
|
||||
@@ -210,7 +210,7 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
|
||||
|
||||
logger.Debug("session: performing refresh grant...")
|
||||
resp, err := retry.DoValue(ctx, func(ctx context.Context) (*openid.TokenResponse, error) {
|
||||
resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken)
|
||||
resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken, sess.data.IDToken, sess.data.Acr)
|
||||
if errors.Is(err, openidclient.ErrOpenIDServer) {
|
||||
return nil, retry.RetryableError(err)
|
||||
}
|
||||
@@ -226,6 +226,10 @@ func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
|
||||
return nil, fmt.Errorf("performing refresh: %w", err)
|
||||
}
|
||||
|
||||
// id_tokens may not always be returned from a refresh grant (OpenID Connect 12.1)
|
||||
if resp.IDToken != "" {
|
||||
sess.data.IDToken = resp.IDToken
|
||||
}
|
||||
sess.data.AccessToken = resp.AccessToken
|
||||
sess.data.RefreshToken = resp.RefreshToken
|
||||
sess.data.Metadata.Refresh(resp.ExpiresIn)
|
||||
|
||||
Reference in New Issue
Block a user