From aab249d78a8a779a3f292824575d34cd427448d1 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 14 Jul 2022 12:14:22 +0200 Subject: [PATCH] refactor(jwt): skip parsing access tokens Access Tokens are not necessarily JWTs. We also don't have to validate them as we only pass it on as an opaque string. This also means that we don't log the JTI access tokens anymore. We also simplify handling of oidc callbacks. --- pkg/jwt/claims.go | 14 --- pkg/jwt/jwt.go | 24 ++++-- pkg/jwt/jwt_access_token.go | 25 ------ pkg/jwt/jwt_id_token.go | 55 ------------ pkg/jwt/jwt_tokens.go | 52 ----------- pkg/loginstatus/loginstatus.go | 9 +- pkg/loginstatus/loginstatus_test.go | 9 +- pkg/openid/client/login_callback.go | 14 +-- pkg/openid/client/login_callback_test.go | 43 ++-------- pkg/openid/config/client.go | 5 +- pkg/openid/{ => config}/redirect_uri.go | 2 +- pkg/openid/{ => config}/redirect_uri_test.go | 6 +- pkg/openid/tokens.go | 90 ++++++++++++++++++++ pkg/router/handler_callback.go | 25 ++---- pkg/router/handler_frontchannellogout.go | 2 +- pkg/router/handler_logout.go | 2 +- pkg/router/session.go | 9 +- pkg/router/session_fallback_test.go | 35 +++----- pkg/session/cookie.go | 28 +++++- pkg/session/memory_test.go | 18 ++-- pkg/session/redis_test.go | 18 ++-- pkg/session/session_data.go | 22 ++--- pkg/session/session_id.go | 4 +- pkg/session/session_id_test.go | 12 +-- 24 files changed, 221 insertions(+), 302 deletions(-) delete mode 100644 pkg/jwt/claims.go delete mode 100644 pkg/jwt/jwt_access_token.go delete mode 100644 pkg/jwt/jwt_id_token.go delete mode 100644 pkg/jwt/jwt_tokens.go rename pkg/openid/{ => config}/redirect_uri.go (96%) rename pkg/openid/{ => config}/redirect_uri_test.go (89%) create mode 100644 pkg/openid/tokens.go diff --git a/pkg/jwt/claims.go b/pkg/jwt/claims.go deleted file mode 100644 index fe3f37d..0000000 --- a/pkg/jwt/claims.go +++ /dev/null @@ -1,14 +0,0 @@ -package jwt - -const ( - JtiClaim = "jti" - SidClaim = "sid" - UtiClaim = "uti" -) - -type Claims struct { - IDTokenJti string `json:"id_token_jti,omitempty"` - IDTokenUti string `json:"id_token_uti,omitempty"` - AccessTokenJti string `json:"access_token_jti,omitempty"` - AccessTokenUti string `json:"access_token_uti,omitempty"` -} diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index e58aaf0..8a8188f 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -11,15 +11,18 @@ import ( const ( AcceptableClockSkew = 10 * time.Second + + JtiClaim = "jti" + SidClaim = "sid" + UtiClaim = "uti" ) type Token interface { GetExpiration() time.Time - GetJtiClaim() string + GetJwtID() string GetSerialized() string GetStringClaim(claim string) (string, error) GetToken() jwt.Token - GetUtiClaim() string } type token struct { @@ -31,8 +34,17 @@ func (in *token) GetExpiration() time.Time { return in.token.Expiration() } -func (in *token) GetJtiClaim() string { - return in.GetStringClaimOrEmpty(JtiClaim) +func (in *token) GetJwtID() string { + jti := in.GetStringClaimOrEmpty(JtiClaim) + uti := in.GetStringClaimOrEmpty(UtiClaim) + + // jti is the standard JWT ID claim + if len(jti) > 0 { + return jti + } + + // else, try to return uti - which seems to be Azure AD's variant + return uti } func (in *token) GetSerialized() string { @@ -70,10 +82,6 @@ func (in *token) GetToken() jwt.Token { return in.token } -func (in *token) GetUtiClaim() string { - return in.GetStringClaimOrEmpty(UtiClaim) -} - func NewToken(raw string, jwtToken jwt.Token) Token { return &token{ serialized: raw, diff --git a/pkg/jwt/jwt_access_token.go b/pkg/jwt/jwt_access_token.go deleted file mode 100644 index 04ae016..0000000 --- a/pkg/jwt/jwt_access_token.go +++ /dev/null @@ -1,25 +0,0 @@ -package jwt - -import ( - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -type AccessToken struct { - Token -} - -func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken { - return &AccessToken{ - NewToken(raw, jwtToken), - } -} - -func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) { - accessToken, err := Parse(raw, jwks) - if err != nil { - return nil, err - } - - return NewAccessToken(raw, accessToken), nil -} diff --git a/pkg/jwt/jwt_id_token.go b/pkg/jwt/jwt_id_token.go deleted file mode 100644 index 502d28e..0000000 --- a/pkg/jwt/jwt_id_token.go +++ /dev/null @@ -1,55 +0,0 @@ -package jwt - -import ( - "time" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" - - openidconfig "github.com/nais/wonderwall/pkg/openid/config" -) - -type IDToken struct { - Token -} - -func (in *IDToken) GetSidClaim() (string, error) { - return in.GetStringClaim(SidClaim) -} - -func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error { - openIDconfig := cfg.Provider() - clientConfig := cfg.Client() - - opts := []jwt.ValidateOption{ - jwt.WithAudience(clientConfig.GetClientID()), - jwt.WithClaimValue("nonce", nonce), - jwt.WithIssuer(openIDconfig.Issuer), - jwt.WithAcceptableSkew(5 * time.Second), - } - - if openIDconfig.SidClaimRequired() { - opts = append(opts, jwt.WithRequiredClaim("sid")) - } - - if len(clientConfig.GetACRValues()) > 0 { - opts = append(opts, jwt.WithRequiredClaim("acr")) - } - - return jwt.Validate(in.GetToken(), opts...) -} - -func NewIDToken(raw string, jwtToken jwt.Token) *IDToken { - return &IDToken{ - NewToken(raw, jwtToken), - } -} - -func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) { - idToken, err := Parse(raw, jwks) - if err != nil { - return nil, err - } - - return NewIDToken(raw, idToken), nil -} diff --git a/pkg/jwt/jwt_tokens.go b/pkg/jwt/jwt_tokens.go deleted file mode 100644 index 5cabe65..0000000 --- a/pkg/jwt/jwt_tokens.go +++ /dev/null @@ -1,52 +0,0 @@ -package jwt - -import ( - "fmt" - - "github.com/lestrrat-go/jwx/v2/jwk" - "golang.org/x/oauth2" -) - -type Tokens struct { - IDToken *IDToken - AccessToken *AccessToken -} - -func (in *Tokens) Claims() Claims { - return Claims{ - IDTokenJti: in.IDToken.GetJtiClaim(), - IDTokenUti: in.IDToken.GetUtiClaim(), - AccessTokenJti: in.AccessToken.GetJtiClaim(), - AccessTokenUti: in.AccessToken.GetUtiClaim(), - } -} - -func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens { - return &Tokens{ - IDToken: idToken, - AccessToken: accessToken, - } -} - -func ParseOauth2Token(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) { - idToken, ok := tokens.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("missing id_token in token response") - } - - return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks) -} - -func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) { - parsedIdToken, err := ParseIDToken(idToken, jwks) - if err != nil { - return nil, fmt.Errorf("id_token: %w", err) - } - - parsedAccessToken, err := ParseAccessToken(accessToken, jwks) - if err != nil { - return nil, fmt.Errorf("access_token: %w", err) - } - - return NewTokens(parsedIdToken, parsedAccessToken), nil -} diff --git a/pkg/loginstatus/loginstatus.go b/pkg/loginstatus/loginstatus.go index d6918b5..ca2f0ca 100644 --- a/pkg/loginstatus/loginstatus.go +++ b/pkg/loginstatus/loginstatus.go @@ -11,7 +11,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" ) const ( @@ -19,7 +18,7 @@ const ( ) type Client interface { - ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) + ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options) HasCookie(r *http.Request) bool ClearCookie(w http.ResponseWriter, opts cookie.Options) @@ -48,7 +47,7 @@ type client struct { httpClient *http.Client } -func (c client) ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) { +func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) { req, err := request(ctx, c.config.TokenURL, accessToken) if err != nil { return nil, fmt.Errorf("creating request %w", err) @@ -101,13 +100,13 @@ func (c client) CookieOptions(opts cookie.Options) cookie.Options { WithPath("/") } -func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Request, error) { +func request(ctx context.Context, url string, token string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized())) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("Accept", "application/json") return req, nil diff --git a/pkg/loginstatus/loginstatus_test.go b/pkg/loginstatus/loginstatus_test.go index 41d77cd..1713075 100644 --- a/pkg/loginstatus/loginstatus_test.go +++ b/pkg/loginstatus/loginstatus_test.go @@ -13,7 +13,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/loginstatus" ) @@ -27,18 +26,18 @@ func TestClient_ExchangeToken(t *testing.T) { client := loginstatus.NewClient(cfg, httpclient) for _, test := range []struct { - token *jwt.AccessToken + token string err error }{ { - token: jwt.NewAccessToken("valid-token", nil), + token: "valid-token", }, { - token: jwt.NewAccessToken("invalid-token", nil), + token: "invalid-token", err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"), }, { - token: jwt.NewAccessToken("internal-server-error", nil), + token: "internal-server-error", err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"), }, } { diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index 8e59350..62d3808 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -9,7 +9,6 @@ import ( "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/provider" ) @@ -17,8 +16,7 @@ import ( type LoginCallback interface { IdentityProviderError() error StateMismatchError() error - ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) - ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) + RedeemTokens(ctx context.Context) (*openid.Tokens, error) } type loginCallback struct { @@ -68,7 +66,7 @@ func (in loginCallback) StateMismatchError() error { return nil } -func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) { +func (in loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) { clientAssertion, err := in.client.MakeAssertion(time.Second * 30) if err != nil { return nil, fmt.Errorf("creating client assertion: %w", err) @@ -81,21 +79,17 @@ func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, er } code := in.requestParams.Get("code") - tokens, err := in.client.AuthCodeGrant(ctx, code, opts) + rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts) if err != nil { return nil, fmt.Errorf("exchanging authorization code for token: %w", err) } - return tokens, nil -} - -func (in loginCallback) ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) { jwkSet, err := in.provider.GetPublicJwkSet(ctx) if err != nil { return nil, fmt.Errorf("getting jwks: %w", err) } - tokens, err := jwt.ParseOauth2Token(rawTokens, *jwkSet) + tokens, err := openid.NewTokens(rawTokens, *jwkSet) if err != nil { // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt _, _ = in.provider.RefreshPublicJwkSet(ctx) diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index 36a0baa..1c94578 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -43,20 +43,20 @@ func TestLoginCallback_IdentityProviderError(t *testing.T) { assert.Error(t, err) } -func TestLoginCallback_ExchangeAuthCode(t *testing.T) { - t.Run("valid code", func(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" +func TestLoginCallback_RedeemTokens(t *testing.T) { + url := "http://wonderwall/oauth2/callback?code=some-code" + t.Run("happy path", func(t *testing.T) { idp, lc := newLoginCallback(t, url) defer idp.Close() - tokens, err := lc.ExchangeAuthCode(context.Background()) + tokens, err := lc.RedeemTokens(context.Background()) assert.NoError(t, err) assert.NotNil(t, tokens) assert.NotEmpty(t, tokens.AccessToken) assert.NotEmpty(t, tokens.RefreshToken) - assert.NotEmpty(t, tokens.Extra("id_token")) + assert.NotEmpty(t, tokens.IDToken.GetSerialized()) assert.NotEmpty(t, tokens.TokenType) assert.NotEmpty(t, tokens.Expiry) @@ -67,8 +67,6 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) { }) t.Run("invalid code", func(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" - idp, lc := newLoginCallback(t, url) defer idp.Close() idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{ @@ -76,38 +74,17 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) { "another-code": {}, } - tokens, err := lc.ExchangeAuthCode(context.Background()) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) -} - -func TestLoginCallback_ProcessTokens(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" - - t.Run("happy path", func(t *testing.T) { - idp, lc := newLoginCallback(t, url) - defer idp.Close() - - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) - assert.NoError(t, err) - assert.NotNil(t, tokens) - }) t.Run("nonce mismatch", func(t *testing.T) { idp, lc := newLoginCallback(t, url) defer idp.Close() idp.ProviderHandler.Codes["some-code"].Nonce = "some-other-nonce" - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) @@ -117,11 +94,7 @@ func TestLoginCallback_ProcessTokens(t *testing.T) { defer idp.Close() idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id" - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index 81b707f..2c1a946 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -7,7 +7,6 @@ import ( log "github.com/sirupsen/logrus" wonderwallconfig "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/scopes" "github.com/nais/wonderwall/pkg/router/paths" ) @@ -97,12 +96,12 @@ func NewClientConfig(cfg *wonderwallconfig.Config) (Client, error) { return nil, fmt.Errorf("missing required config %s", wonderwallconfig.Ingress) } - callbackURI, err := openid.RedirectURI(ingress, paths.Callback) + callbackURI, err := RedirectURI(ingress, paths.Callback) if err != nil { return nil, fmt.Errorf("creating callback URI from ingress: %w", err) } - logoutCallbackURI, err := openid.RedirectURI(ingress, paths.LogoutCallback) + logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback) if err != nil { return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err) } diff --git a/pkg/openid/redirect_uri.go b/pkg/openid/config/redirect_uri.go similarity index 96% rename from pkg/openid/redirect_uri.go rename to pkg/openid/config/redirect_uri.go index 4334971..6b1fcc3 100644 --- a/pkg/openid/redirect_uri.go +++ b/pkg/openid/config/redirect_uri.go @@ -1,4 +1,4 @@ -package openid +package config import ( "fmt" diff --git a/pkg/openid/redirect_uri_test.go b/pkg/openid/config/redirect_uri_test.go similarity index 89% rename from pkg/openid/redirect_uri_test.go rename to pkg/openid/config/redirect_uri_test.go index d4a150b..96f246a 100644 --- a/pkg/openid/redirect_uri_test.go +++ b/pkg/openid/config/redirect_uri_test.go @@ -1,4 +1,4 @@ -package openid_test +package config_test import ( "fmt" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/router/paths" ) @@ -47,7 +47,7 @@ func TestRedirectURI(t *testing.T) { err: fmt.Errorf("ingress cannot be empty"), }, } { - actual, err := openid.RedirectURI(test.input, test.path) + actual, err := config.RedirectURI(test.input, test.path) if test.err != nil { assert.EqualError(t, err, test.err.Error()) } else { diff --git a/pkg/openid/tokens.go b/pkg/openid/tokens.go new file mode 100644 index 0000000..5882b4c --- /dev/null +++ b/pkg/openid/tokens.go @@ -0,0 +1,90 @@ +package openid + +import ( + "fmt" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + jwtlib "github.com/lestrrat-go/jwx/v2/jwt" + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/jwt" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" +) + +type Tokens struct { + AccessToken string + Expiry time.Time + IDToken *IDToken + RefreshToken string + TokenType string +} + +func NewTokens(src *oauth2.Token, jwks jwk.Set) (*Tokens, error) { + idToken, err := ParseIDTokenFrom(src, jwks) + if err != nil { + return nil, fmt.Errorf("parsing id_token: %w", err) + } + + return &Tokens{ + AccessToken: src.AccessToken, + Expiry: src.Expiry, + IDToken: idToken, + RefreshToken: src.RefreshToken, + TokenType: src.TokenType, + }, nil +} + +type IDToken struct { + jwt.Token +} + +func (in *IDToken) GetSidClaim() (string, error) { + return in.GetStringClaim(jwt.SidClaim) +} + +func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error { + openIDconfig := cfg.Provider() + clientConfig := cfg.Client() + + opts := []jwtlib.ValidateOption{ + jwtlib.WithAudience(clientConfig.GetClientID()), + jwtlib.WithClaimValue("nonce", nonce), + jwtlib.WithIssuer(openIDconfig.Issuer), + jwtlib.WithAcceptableSkew(5 * time.Second), + } + + if openIDconfig.SidClaimRequired() { + opts = append(opts, jwtlib.WithRequiredClaim("sid")) + } + + if len(clientConfig.GetACRValues()) > 0 { + opts = append(opts, jwtlib.WithRequiredClaim("acr")) + } + + return jwtlib.Validate(in.GetToken(), opts...) +} + +func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken { + return &IDToken{ + jwt.NewToken(raw, jwtToken), + } +} + +func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) { + idToken, err := jwt.Parse(raw, jwks) + if err != nil { + return nil, err + } + + return NewIDToken(raw, idToken), nil +} + +func ParseIDTokenFrom(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) { + idToken, ok := tokens.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("missing id_token in token response") + } + + return ParseIDToken(idToken, jwks) +} diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 6cfdb3a..3d4fa45 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -9,10 +9,9 @@ import ( "github.com/sethvargo/go-retry" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/loginstatus" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" logentry "github.com/nais/wonderwall/pkg/router/middleware" ) @@ -52,19 +51,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - rawTokens, err := h.exchangeAuthCode(r.Context(), loginCallback) + tokens, err := h.redeemValidTokens(r.Context(), loginCallback) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: %w", err)) return } - tokens, err := loginCallback.ProcessTokens(r.Context(), rawTokens) - if err != nil { - h.InternalError(w, r, fmt.Errorf("callback: %w", err)) - return - } - - err = h.createSession(w, r, tokens, rawTokens) + err = h.createSession(w, r, tokens) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) return @@ -85,12 +78,12 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect) } -func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.LoginCallback) (*oauth2.Token, error) { - var tokens *oauth2.Token +func (h *Handler) redeemValidTokens(ctx context.Context, loginCallback client.LoginCallback) (*openid.Tokens, error) { + var tokens *openid.Tokens var err error retryable := func(ctx context.Context) error { - tokens, err = loginCallback.ExchangeAuthCode(ctx) + tokens, err = loginCallback.RedeemTokens(ctx) if err != nil { log.Warnf("callback: retrying: %+v", err) return retry.RetryableError(err) @@ -107,7 +100,7 @@ func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.Log return tokens, nil } -func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (*loginstatus.TokenResponse, error) { +func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) { var tokenResponse *loginstatus.TokenResponse err := retry.Do(ctx, backoff(), func(ctx context.Context) error { @@ -128,10 +121,10 @@ func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) ( return tokenResponse, nil } -func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) { +func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) { fields := map[string]interface{}{ "redirect_to": referer, - "claims": tokens.Claims(), + "jti": tokens.IDToken.GetJwtID(), } logger := logentry.LogEntryWithFields(r.Context(), fields) diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/router/handler_frontchannellogout.go index dc93c35..9f64e14 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/router/handler_frontchannellogout.go @@ -36,7 +36,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { if err != nil { log.Errorf("front-channel logout: destroying session: %+v", err) } else if sessionData != nil { - log.WithField("claims", sessionData.Claims).Infof("front-channel logout: successful logout") + log.WithField("jti", sessionData.IDTokenJwtID).Infof("front-channel logout: successful logout") } w.WriteHeader(http.StatusOK) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index e8e5661..0d3fc26 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -25,7 +25,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { } fields := map[string]interface{}{ - "claims": sessionData.Claims, + "jti": sessionData.IDTokenJwtID, } logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("logout: successful local logout") diff --git a/pkg/router/session.go b/pkg/router/session.go index 54e8822..2c67680 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -9,10 +9,9 @@ import ( "github.com/go-redis/redis/v8" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -77,7 +76,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration { return defaultSessionLifetime } -func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token) error { +func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *openid.Tokens) error { params := r.URL.Query() externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params) @@ -85,7 +84,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * return fmt.Errorf("generating session ID: %w", err) } - sessionLifetime := h.getSessionLifetime(rawTokens.Expiry) + sessionLifetime := h.getSessionLifetime(tokens.Expiry) opts := h.CookieOptions.WithExpiresIn(sessionLifetime) sessionID := h.localSessionID(externalSessionID) @@ -95,7 +94,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * } sessionMetadata := session.NewMetadata(time.Now().Add(sessionLifetime)) - sessionData := session.NewData(externalSessionID, tokens, rawTokens.RefreshToken, sessionMetadata) + sessionData := session.NewData(externalSessionID, tokens, sessionMetadata) encryptedSessionData, err := sessionData.Encrypt(h.Crypter) if err != nil { diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index 52f6598..7be1dd7 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -13,8 +13,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" ) @@ -40,10 +40,9 @@ func TestHandler_GetSessionFallback(t *testing.T) { sessionData, err := rpHandler.GetSessionFallback(w, r) assert.NoError(t, err) assert.Equal(t, "sid", sessionData.ExternalSessionID) - assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken) + assert.Equal(t, tokens.AccessToken, sessionData.AccessToken) assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken) - assert.Equal(t, "id-token-jti", sessionData.Claims.IDTokenJti) - assert.Equal(t, "access-token-jti", sessionData.Claims.AccessTokenJti) + assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID) assert.Empty(t, sessionData.RefreshToken) }) } @@ -59,7 +58,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { // request should set session cookies in response writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", tokens, "", nil) + data := session.NewData("sid", tokens, nil) err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn) assert.NoError(t, err) @@ -79,7 +78,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { }, { cookieName: "wonderwall-3", - want: tokens.AccessToken.GetSerialized(), + want: tokens.AccessToken, }, } { assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies) @@ -118,10 +117,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) { }) } -func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *jwt.Tokens) *http.Request { +func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request { writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", tokens, "", nil) + data := session.NewData("sid", tokens, nil) err := h.SetSessionFallback(writer, nil, data, expiresIn) assert.NoError(t, err) @@ -163,7 +162,7 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal assert.Equal(t, expectedValue, string(plainbytes)) } -func makeTokens(provider mock.TestProvider) *jwt.Tokens { +func makeTokens(provider mock.TestProvider) *openid.Tokens { jwks := *provider.PrivateJwkSet() jwksPublic, err := provider.GetPublicJwkSet(context.TODO()) if err != nil { @@ -188,20 +187,10 @@ func makeTokens(provider mock.TestProvider) *jwt.Tokens { log.Fatalf("parsing signed id_token: %+v", err) } - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") + accessToken := "some-access-token" - signedAccessToken, err := jwtlib.Sign(accessToken, jwtlib.WithKey(jwa.RS256, signer)) - if err != nil { - log.Fatalf("signing access_token: %+v", err) - } - parsedAccessToken, err := jwtlib.Parse(signedAccessToken, jwtlib.WithKeySet(*jwksPublic)) - if err != nil { - log.Fatalf("parsing signed access_token: %+v", err) - } - - return &jwt.Tokens{ - IDToken: jwt.NewIDToken(string(signedIdToken), parsedIdToken), - AccessToken: jwt.NewAccessToken(string(signedAccessToken), parsedAccessToken), + return &openid.Tokens{ + IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken), + AccessToken: accessToken, } } diff --git a/pkg/session/cookie.go b/pkg/session/cookie.go index deb39aa..d35ab2a 100644 --- a/pkg/session/cookie.go +++ b/pkg/session/cookie.go @@ -7,9 +7,12 @@ import ( "net/http" "time" + "golang.org/x/oauth2" + "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/provider" ) @@ -87,15 +90,34 @@ func (c cookieSessionStore) Read(ctx context.Context) (*Data, error) { return nil, fmt.Errorf("callback: getting jwks: %w", err) } - tokens, err := jwt.ParseTokensFromStrings(idToken, accessToken, *jwkSet) + // TODO: currently a placeholder fallback value, should fetch from metadata cookie + expiry := time.Now().Add(time.Hour) + + // attempt to get expiry from access_token if it is a JWT + parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet) + if err == nil { + expiry = parsedAccessToken.Expiration() + } + + // TODO: set refresh token and metadata + rawTokens := &oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + RefreshToken: "", + Expiry: expiry, + } + rawTokens = rawTokens.WithExtra(map[string]interface{}{ + "id_token": idToken, + }) + + tokens, err := openid.NewTokens(rawTokens, *jwkSet) if err != nil { // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt _, _ = c.provider.RefreshPublicJwkSet(ctx) return nil, fmt.Errorf("parsing tokens: %w", err) } - // TODO: set refresh token and metadata - return NewData(externalSessionID, tokens, "", nil), nil + return NewData(externalSessionID, tokens, nil), nil } func (c cookieSessionStore) Delete() { diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index e7d62ab..cbef10a 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -22,16 +22,16 @@ func TestMemory(t *testing.T) { idToken := jwtlib.New() idToken.Set("jti", "id-token-jti") - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") - - tokens := &jwt.Tokens{ - IDToken: jwt.NewIDToken("id_token", idToken), - AccessToken: jwt.NewAccessToken("access_token", accessToken), - } + accessToken := "some-access-token" refreshToken := "some-refresh-token" + + tokens := &openid.Tokens{ + AccessToken: accessToken, + IDToken: openid.NewIDToken("id_token", idToken), + RefreshToken: refreshToken, + } metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, refreshToken, metadata) + data := session.NewData("myid", tokens, metadata) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index d2dea90..4728280 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -24,16 +24,16 @@ func TestRedis(t *testing.T) { idToken := jwtlib.New() idToken.Set("jti", "id-token-jti") - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") - - tokens := &jwt.Tokens{ - IDToken: jwt.NewIDToken("id_token", idToken), - AccessToken: jwt.NewAccessToken("access_token", accessToken), - } + accessToken := "some-access-token" refreshToken := "some-refresh-token" + + tokens := &openid.Tokens{ + AccessToken: accessToken, + IDToken: openid.NewIDToken("id_token", idToken), + RefreshToken: refreshToken, + } metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, refreshToken, metadata) + data := session.NewData("myid", tokens, metadata) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/session_data.go b/pkg/session/session_data.go index 130e081..5010401 100644 --- a/pkg/session/session_data.go +++ b/pkg/session/session_data.go @@ -7,7 +7,7 @@ import ( "time" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" ) type EncryptedData struct { @@ -46,21 +46,21 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) { } type Data struct { - ExternalSessionID string `json:"external_session_id"` - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - Claims jwt.Claims `json:"claims"` - Metadata Metadata `json:"metadata"` + ExternalSessionID string `json:"external_session_id"` + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + IDTokenJwtID string `json:"id_token_jwt_id"` + Metadata Metadata `json:"metadata"` } -func NewData(externalSessionID string, tokens *jwt.Tokens, refreshToken string, metadata *Metadata) *Data { +func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data { data := &Data{ ExternalSessionID: externalSessionID, - AccessToken: tokens.AccessToken.GetSerialized(), + AccessToken: tokens.AccessToken, IDToken: tokens.IDToken.GetSerialized(), - RefreshToken: refreshToken, - Claims: tokens.Claims(), + IDTokenJwtID: tokens.IDToken.GetJwtID(), + RefreshToken: tokens.RefreshToken, } if metadata != nil { diff --git a/pkg/session/session_id.go b/pkg/session/session_id.go index 20a91db..6a88bad 100644 --- a/pkg/session/session_id.go +++ b/pkg/session/session_id.go @@ -7,7 +7,7 @@ import ( "io" "net/url" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" ) @@ -15,7 +15,7 @@ const ( SessionStateParamKey = "session_state" ) -func NewSessionID(cfg *config.Provider, idToken *jwt.IDToken, params url.Values) (string, error) { +func NewSessionID(cfg *config.Provider, idToken *openid.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token sessionID, err := idToken.GetSidClaim() if err == nil { diff --git a/pkg/session/session_id_test.go b/pkg/session/session_id_test.go index 772597a..56ca83a 100644 --- a/pkg/session/session_id_test.go +++ b/pkg/session/session_id_test.go @@ -8,7 +8,7 @@ import ( jwtlib "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/session" ) @@ -17,7 +17,7 @@ func TestSessionID(t *testing.T) { for _, test := range []struct { name string config *config.Provider - idToken *jwt.IDToken + idToken *openid.IDToken params url.Values want string exactMatch bool @@ -136,7 +136,7 @@ func params(key, value string) url.Values { return values } -func newIDToken(extraClaims map[string]string) *jwt.IDToken { +func newIDToken(extraClaims map[string]string) *openid.IDToken { idToken := jwtlib.New() idToken.Set("sub", "test") idToken.Set("iss", "test") @@ -155,15 +155,15 @@ func newIDToken(extraClaims map[string]string) *jwt.IDToken { panic(err) } - return jwt.NewIDToken(string(serialized), idToken) + return openid.NewIDToken(string(serialized), idToken) } -func idTokenWithSid(sid string) *jwt.IDToken { +func idTokenWithSid(sid string) *openid.IDToken { return newIDToken(map[string]string{ "sid": sid, }) } -func idToken() *jwt.IDToken { +func idToken() *openid.IDToken { return newIDToken(nil) }