diff --git a/pkg/jwt/claims.go b/pkg/jwt/claims.go deleted file mode 100644 index fe3f37d..0000000 --- a/pkg/jwt/claims.go +++ /dev/null @@ -1,14 +0,0 @@ -package jwt - -const ( - JtiClaim = "jti" - SidClaim = "sid" - UtiClaim = "uti" -) - -type Claims struct { - IDTokenJti string `json:"id_token_jti,omitempty"` - IDTokenUti string `json:"id_token_uti,omitempty"` - AccessTokenJti string `json:"access_token_jti,omitempty"` - AccessTokenUti string `json:"access_token_uti,omitempty"` -} diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index e58aaf0..8a8188f 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -11,15 +11,18 @@ import ( const ( AcceptableClockSkew = 10 * time.Second + + JtiClaim = "jti" + SidClaim = "sid" + UtiClaim = "uti" ) type Token interface { GetExpiration() time.Time - GetJtiClaim() string + GetJwtID() string GetSerialized() string GetStringClaim(claim string) (string, error) GetToken() jwt.Token - GetUtiClaim() string } type token struct { @@ -31,8 +34,17 @@ func (in *token) GetExpiration() time.Time { return in.token.Expiration() } -func (in *token) GetJtiClaim() string { - return in.GetStringClaimOrEmpty(JtiClaim) +func (in *token) GetJwtID() string { + jti := in.GetStringClaimOrEmpty(JtiClaim) + uti := in.GetStringClaimOrEmpty(UtiClaim) + + // jti is the standard JWT ID claim + if len(jti) > 0 { + return jti + } + + // else, try to return uti - which seems to be Azure AD's variant + return uti } func (in *token) GetSerialized() string { @@ -70,10 +82,6 @@ func (in *token) GetToken() jwt.Token { return in.token } -func (in *token) GetUtiClaim() string { - return in.GetStringClaimOrEmpty(UtiClaim) -} - func NewToken(raw string, jwtToken jwt.Token) Token { return &token{ serialized: raw, diff --git a/pkg/jwt/jwt_access_token.go b/pkg/jwt/jwt_access_token.go deleted file mode 100644 index 04ae016..0000000 --- a/pkg/jwt/jwt_access_token.go +++ /dev/null @@ -1,25 +0,0 @@ -package jwt - -import ( - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -type AccessToken struct { - Token -} - -func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken { - return &AccessToken{ - NewToken(raw, jwtToken), - } -} - -func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) { - accessToken, err := Parse(raw, jwks) - if err != nil { - return nil, err - } - - return NewAccessToken(raw, accessToken), nil -} diff --git a/pkg/jwt/jwt_id_token.go b/pkg/jwt/jwt_id_token.go deleted file mode 100644 index 502d28e..0000000 --- a/pkg/jwt/jwt_id_token.go +++ /dev/null @@ -1,55 +0,0 @@ -package jwt - -import ( - "time" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" - - openidconfig "github.com/nais/wonderwall/pkg/openid/config" -) - -type IDToken struct { - Token -} - -func (in *IDToken) GetSidClaim() (string, error) { - return in.GetStringClaim(SidClaim) -} - -func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error { - openIDconfig := cfg.Provider() - clientConfig := cfg.Client() - - opts := []jwt.ValidateOption{ - jwt.WithAudience(clientConfig.GetClientID()), - jwt.WithClaimValue("nonce", nonce), - jwt.WithIssuer(openIDconfig.Issuer), - jwt.WithAcceptableSkew(5 * time.Second), - } - - if openIDconfig.SidClaimRequired() { - opts = append(opts, jwt.WithRequiredClaim("sid")) - } - - if len(clientConfig.GetACRValues()) > 0 { - opts = append(opts, jwt.WithRequiredClaim("acr")) - } - - return jwt.Validate(in.GetToken(), opts...) -} - -func NewIDToken(raw string, jwtToken jwt.Token) *IDToken { - return &IDToken{ - NewToken(raw, jwtToken), - } -} - -func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) { - idToken, err := Parse(raw, jwks) - if err != nil { - return nil, err - } - - return NewIDToken(raw, idToken), nil -} diff --git a/pkg/jwt/jwt_tokens.go b/pkg/jwt/jwt_tokens.go deleted file mode 100644 index 5cabe65..0000000 --- a/pkg/jwt/jwt_tokens.go +++ /dev/null @@ -1,52 +0,0 @@ -package jwt - -import ( - "fmt" - - "github.com/lestrrat-go/jwx/v2/jwk" - "golang.org/x/oauth2" -) - -type Tokens struct { - IDToken *IDToken - AccessToken *AccessToken -} - -func (in *Tokens) Claims() Claims { - return Claims{ - IDTokenJti: in.IDToken.GetJtiClaim(), - IDTokenUti: in.IDToken.GetUtiClaim(), - AccessTokenJti: in.AccessToken.GetJtiClaim(), - AccessTokenUti: in.AccessToken.GetUtiClaim(), - } -} - -func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens { - return &Tokens{ - IDToken: idToken, - AccessToken: accessToken, - } -} - -func ParseOauth2Token(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) { - idToken, ok := tokens.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("missing id_token in token response") - } - - return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks) -} - -func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) { - parsedIdToken, err := ParseIDToken(idToken, jwks) - if err != nil { - return nil, fmt.Errorf("id_token: %w", err) - } - - parsedAccessToken, err := ParseAccessToken(accessToken, jwks) - if err != nil { - return nil, fmt.Errorf("access_token: %w", err) - } - - return NewTokens(parsedIdToken, parsedAccessToken), nil -} diff --git a/pkg/loginstatus/loginstatus.go b/pkg/loginstatus/loginstatus.go index d6918b5..ca2f0ca 100644 --- a/pkg/loginstatus/loginstatus.go +++ b/pkg/loginstatus/loginstatus.go @@ -11,7 +11,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" ) const ( @@ -19,7 +18,7 @@ const ( ) type Client interface { - ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) + ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options) HasCookie(r *http.Request) bool ClearCookie(w http.ResponseWriter, opts cookie.Options) @@ -48,7 +47,7 @@ type client struct { httpClient *http.Client } -func (c client) ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) { +func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) { req, err := request(ctx, c.config.TokenURL, accessToken) if err != nil { return nil, fmt.Errorf("creating request %w", err) @@ -101,13 +100,13 @@ func (c client) CookieOptions(opts cookie.Options) cookie.Options { WithPath("/") } -func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Request, error) { +func request(ctx context.Context, url string, token string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized())) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("Accept", "application/json") return req, nil diff --git a/pkg/loginstatus/loginstatus_test.go b/pkg/loginstatus/loginstatus_test.go index 41d77cd..1713075 100644 --- a/pkg/loginstatus/loginstatus_test.go +++ b/pkg/loginstatus/loginstatus_test.go @@ -13,7 +13,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/loginstatus" ) @@ -27,18 +26,18 @@ func TestClient_ExchangeToken(t *testing.T) { client := loginstatus.NewClient(cfg, httpclient) for _, test := range []struct { - token *jwt.AccessToken + token string err error }{ { - token: jwt.NewAccessToken("valid-token", nil), + token: "valid-token", }, { - token: jwt.NewAccessToken("invalid-token", nil), + token: "invalid-token", err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"), }, { - token: jwt.NewAccessToken("internal-server-error", nil), + token: "internal-server-error", err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"), }, } { diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index 8e59350..62d3808 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -9,7 +9,6 @@ import ( "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/provider" ) @@ -17,8 +16,7 @@ import ( type LoginCallback interface { IdentityProviderError() error StateMismatchError() error - ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) - ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) + RedeemTokens(ctx context.Context) (*openid.Tokens, error) } type loginCallback struct { @@ -68,7 +66,7 @@ func (in loginCallback) StateMismatchError() error { return nil } -func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) { +func (in loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) { clientAssertion, err := in.client.MakeAssertion(time.Second * 30) if err != nil { return nil, fmt.Errorf("creating client assertion: %w", err) @@ -81,21 +79,17 @@ func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, er } code := in.requestParams.Get("code") - tokens, err := in.client.AuthCodeGrant(ctx, code, opts) + rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts) if err != nil { return nil, fmt.Errorf("exchanging authorization code for token: %w", err) } - return tokens, nil -} - -func (in loginCallback) ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) { jwkSet, err := in.provider.GetPublicJwkSet(ctx) if err != nil { return nil, fmt.Errorf("getting jwks: %w", err) } - tokens, err := jwt.ParseOauth2Token(rawTokens, *jwkSet) + tokens, err := openid.NewTokens(rawTokens, *jwkSet) if err != nil { // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt _, _ = in.provider.RefreshPublicJwkSet(ctx) diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index 36a0baa..1c94578 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -43,20 +43,20 @@ func TestLoginCallback_IdentityProviderError(t *testing.T) { assert.Error(t, err) } -func TestLoginCallback_ExchangeAuthCode(t *testing.T) { - t.Run("valid code", func(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" +func TestLoginCallback_RedeemTokens(t *testing.T) { + url := "http://wonderwall/oauth2/callback?code=some-code" + t.Run("happy path", func(t *testing.T) { idp, lc := newLoginCallback(t, url) defer idp.Close() - tokens, err := lc.ExchangeAuthCode(context.Background()) + tokens, err := lc.RedeemTokens(context.Background()) assert.NoError(t, err) assert.NotNil(t, tokens) assert.NotEmpty(t, tokens.AccessToken) assert.NotEmpty(t, tokens.RefreshToken) - assert.NotEmpty(t, tokens.Extra("id_token")) + assert.NotEmpty(t, tokens.IDToken.GetSerialized()) assert.NotEmpty(t, tokens.TokenType) assert.NotEmpty(t, tokens.Expiry) @@ -67,8 +67,6 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) { }) t.Run("invalid code", func(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" - idp, lc := newLoginCallback(t, url) defer idp.Close() idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{ @@ -76,38 +74,17 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) { "another-code": {}, } - tokens, err := lc.ExchangeAuthCode(context.Background()) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) -} - -func TestLoginCallback_ProcessTokens(t *testing.T) { - url := "http://wonderwall/oauth2/callback?code=some-code" - - t.Run("happy path", func(t *testing.T) { - idp, lc := newLoginCallback(t, url) - defer idp.Close() - - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) - assert.NoError(t, err) - assert.NotNil(t, tokens) - }) t.Run("nonce mismatch", func(t *testing.T) { idp, lc := newLoginCallback(t, url) defer idp.Close() idp.ProviderHandler.Codes["some-code"].Nonce = "some-other-nonce" - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) @@ -117,11 +94,7 @@ func TestLoginCallback_ProcessTokens(t *testing.T) { defer idp.Close() idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id" - rawTokens, err := lc.ExchangeAuthCode(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, rawTokens) - - tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + tokens, err := lc.RedeemTokens(context.Background()) assert.Error(t, err) assert.Nil(t, tokens) }) diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index 81b707f..2c1a946 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -7,7 +7,6 @@ import ( log "github.com/sirupsen/logrus" wonderwallconfig "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/scopes" "github.com/nais/wonderwall/pkg/router/paths" ) @@ -97,12 +96,12 @@ func NewClientConfig(cfg *wonderwallconfig.Config) (Client, error) { return nil, fmt.Errorf("missing required config %s", wonderwallconfig.Ingress) } - callbackURI, err := openid.RedirectURI(ingress, paths.Callback) + callbackURI, err := RedirectURI(ingress, paths.Callback) if err != nil { return nil, fmt.Errorf("creating callback URI from ingress: %w", err) } - logoutCallbackURI, err := openid.RedirectURI(ingress, paths.LogoutCallback) + logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback) if err != nil { return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err) } diff --git a/pkg/openid/redirect_uri.go b/pkg/openid/config/redirect_uri.go similarity index 96% rename from pkg/openid/redirect_uri.go rename to pkg/openid/config/redirect_uri.go index 4334971..6b1fcc3 100644 --- a/pkg/openid/redirect_uri.go +++ b/pkg/openid/config/redirect_uri.go @@ -1,4 +1,4 @@ -package openid +package config import ( "fmt" diff --git a/pkg/openid/redirect_uri_test.go b/pkg/openid/config/redirect_uri_test.go similarity index 89% rename from pkg/openid/redirect_uri_test.go rename to pkg/openid/config/redirect_uri_test.go index d4a150b..96f246a 100644 --- a/pkg/openid/redirect_uri_test.go +++ b/pkg/openid/config/redirect_uri_test.go @@ -1,4 +1,4 @@ -package openid_test +package config_test import ( "fmt" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/router/paths" ) @@ -47,7 +47,7 @@ func TestRedirectURI(t *testing.T) { err: fmt.Errorf("ingress cannot be empty"), }, } { - actual, err := openid.RedirectURI(test.input, test.path) + actual, err := config.RedirectURI(test.input, test.path) if test.err != nil { assert.EqualError(t, err, test.err.Error()) } else { diff --git a/pkg/openid/tokens.go b/pkg/openid/tokens.go new file mode 100644 index 0000000..5882b4c --- /dev/null +++ b/pkg/openid/tokens.go @@ -0,0 +1,90 @@ +package openid + +import ( + "fmt" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + jwtlib "github.com/lestrrat-go/jwx/v2/jwt" + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/jwt" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" +) + +type Tokens struct { + AccessToken string + Expiry time.Time + IDToken *IDToken + RefreshToken string + TokenType string +} + +func NewTokens(src *oauth2.Token, jwks jwk.Set) (*Tokens, error) { + idToken, err := ParseIDTokenFrom(src, jwks) + if err != nil { + return nil, fmt.Errorf("parsing id_token: %w", err) + } + + return &Tokens{ + AccessToken: src.AccessToken, + Expiry: src.Expiry, + IDToken: idToken, + RefreshToken: src.RefreshToken, + TokenType: src.TokenType, + }, nil +} + +type IDToken struct { + jwt.Token +} + +func (in *IDToken) GetSidClaim() (string, error) { + return in.GetStringClaim(jwt.SidClaim) +} + +func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error { + openIDconfig := cfg.Provider() + clientConfig := cfg.Client() + + opts := []jwtlib.ValidateOption{ + jwtlib.WithAudience(clientConfig.GetClientID()), + jwtlib.WithClaimValue("nonce", nonce), + jwtlib.WithIssuer(openIDconfig.Issuer), + jwtlib.WithAcceptableSkew(5 * time.Second), + } + + if openIDconfig.SidClaimRequired() { + opts = append(opts, jwtlib.WithRequiredClaim("sid")) + } + + if len(clientConfig.GetACRValues()) > 0 { + opts = append(opts, jwtlib.WithRequiredClaim("acr")) + } + + return jwtlib.Validate(in.GetToken(), opts...) +} + +func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken { + return &IDToken{ + jwt.NewToken(raw, jwtToken), + } +} + +func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) { + idToken, err := jwt.Parse(raw, jwks) + if err != nil { + return nil, err + } + + return NewIDToken(raw, idToken), nil +} + +func ParseIDTokenFrom(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) { + idToken, ok := tokens.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("missing id_token in token response") + } + + return ParseIDToken(idToken, jwks) +} diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 6cfdb3a..3d4fa45 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -9,10 +9,9 @@ import ( "github.com/sethvargo/go-retry" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/loginstatus" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" logentry "github.com/nais/wonderwall/pkg/router/middleware" ) @@ -52,19 +51,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - rawTokens, err := h.exchangeAuthCode(r.Context(), loginCallback) + tokens, err := h.redeemValidTokens(r.Context(), loginCallback) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: %w", err)) return } - tokens, err := loginCallback.ProcessTokens(r.Context(), rawTokens) - if err != nil { - h.InternalError(w, r, fmt.Errorf("callback: %w", err)) - return - } - - err = h.createSession(w, r, tokens, rawTokens) + err = h.createSession(w, r, tokens) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) return @@ -85,12 +78,12 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect) } -func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.LoginCallback) (*oauth2.Token, error) { - var tokens *oauth2.Token +func (h *Handler) redeemValidTokens(ctx context.Context, loginCallback client.LoginCallback) (*openid.Tokens, error) { + var tokens *openid.Tokens var err error retryable := func(ctx context.Context) error { - tokens, err = loginCallback.ExchangeAuthCode(ctx) + tokens, err = loginCallback.RedeemTokens(ctx) if err != nil { log.Warnf("callback: retrying: %+v", err) return retry.RetryableError(err) @@ -107,7 +100,7 @@ func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.Log return tokens, nil } -func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (*loginstatus.TokenResponse, error) { +func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) { var tokenResponse *loginstatus.TokenResponse err := retry.Do(ctx, backoff(), func(ctx context.Context) error { @@ -128,10 +121,10 @@ func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) ( return tokenResponse, nil } -func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) { +func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) { fields := map[string]interface{}{ "redirect_to": referer, - "claims": tokens.Claims(), + "jti": tokens.IDToken.GetJwtID(), } logger := logentry.LogEntryWithFields(r.Context(), fields) diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/router/handler_frontchannellogout.go index dc93c35..9f64e14 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/router/handler_frontchannellogout.go @@ -36,7 +36,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { if err != nil { log.Errorf("front-channel logout: destroying session: %+v", err) } else if sessionData != nil { - log.WithField("claims", sessionData.Claims).Infof("front-channel logout: successful logout") + log.WithField("jti", sessionData.IDTokenJwtID).Infof("front-channel logout: successful logout") } w.WriteHeader(http.StatusOK) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index e8e5661..0d3fc26 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -25,7 +25,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { } fields := map[string]interface{}{ - "claims": sessionData.Claims, + "jti": sessionData.IDTokenJwtID, } logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("logout: successful local logout") diff --git a/pkg/router/session.go b/pkg/router/session.go index 54e8822..2c67680 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -9,10 +9,9 @@ import ( "github.com/go-redis/redis/v8" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -77,7 +76,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration { return defaultSessionLifetime } -func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token) error { +func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *openid.Tokens) error { params := r.URL.Query() externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params) @@ -85,7 +84,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * return fmt.Errorf("generating session ID: %w", err) } - sessionLifetime := h.getSessionLifetime(rawTokens.Expiry) + sessionLifetime := h.getSessionLifetime(tokens.Expiry) opts := h.CookieOptions.WithExpiresIn(sessionLifetime) sessionID := h.localSessionID(externalSessionID) @@ -95,7 +94,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * } sessionMetadata := session.NewMetadata(time.Now().Add(sessionLifetime)) - sessionData := session.NewData(externalSessionID, tokens, rawTokens.RefreshToken, sessionMetadata) + sessionData := session.NewData(externalSessionID, tokens, sessionMetadata) encryptedSessionData, err := sessionData.Encrypt(h.Crypter) if err != nil { diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index 52f6598..7be1dd7 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -13,8 +13,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/jwt" "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" ) @@ -40,10 +40,9 @@ func TestHandler_GetSessionFallback(t *testing.T) { sessionData, err := rpHandler.GetSessionFallback(w, r) assert.NoError(t, err) assert.Equal(t, "sid", sessionData.ExternalSessionID) - assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken) + assert.Equal(t, tokens.AccessToken, sessionData.AccessToken) assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken) - assert.Equal(t, "id-token-jti", sessionData.Claims.IDTokenJti) - assert.Equal(t, "access-token-jti", sessionData.Claims.AccessTokenJti) + assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID) assert.Empty(t, sessionData.RefreshToken) }) } @@ -59,7 +58,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { // request should set session cookies in response writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", tokens, "", nil) + data := session.NewData("sid", tokens, nil) err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn) assert.NoError(t, err) @@ -79,7 +78,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { }, { cookieName: "wonderwall-3", - want: tokens.AccessToken.GetSerialized(), + want: tokens.AccessToken, }, } { assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies) @@ -118,10 +117,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) { }) } -func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *jwt.Tokens) *http.Request { +func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request { writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", tokens, "", nil) + data := session.NewData("sid", tokens, nil) err := h.SetSessionFallback(writer, nil, data, expiresIn) assert.NoError(t, err) @@ -163,7 +162,7 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal assert.Equal(t, expectedValue, string(plainbytes)) } -func makeTokens(provider mock.TestProvider) *jwt.Tokens { +func makeTokens(provider mock.TestProvider) *openid.Tokens { jwks := *provider.PrivateJwkSet() jwksPublic, err := provider.GetPublicJwkSet(context.TODO()) if err != nil { @@ -188,20 +187,10 @@ func makeTokens(provider mock.TestProvider) *jwt.Tokens { log.Fatalf("parsing signed id_token: %+v", err) } - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") + accessToken := "some-access-token" - signedAccessToken, err := jwtlib.Sign(accessToken, jwtlib.WithKey(jwa.RS256, signer)) - if err != nil { - log.Fatalf("signing access_token: %+v", err) - } - parsedAccessToken, err := jwtlib.Parse(signedAccessToken, jwtlib.WithKeySet(*jwksPublic)) - if err != nil { - log.Fatalf("parsing signed access_token: %+v", err) - } - - return &jwt.Tokens{ - IDToken: jwt.NewIDToken(string(signedIdToken), parsedIdToken), - AccessToken: jwt.NewAccessToken(string(signedAccessToken), parsedAccessToken), + return &openid.Tokens{ + IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken), + AccessToken: accessToken, } } diff --git a/pkg/session/cookie.go b/pkg/session/cookie.go index deb39aa..d35ab2a 100644 --- a/pkg/session/cookie.go +++ b/pkg/session/cookie.go @@ -7,9 +7,12 @@ import ( "net/http" "time" + "golang.org/x/oauth2" + "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/provider" ) @@ -87,15 +90,34 @@ func (c cookieSessionStore) Read(ctx context.Context) (*Data, error) { return nil, fmt.Errorf("callback: getting jwks: %w", err) } - tokens, err := jwt.ParseTokensFromStrings(idToken, accessToken, *jwkSet) + // TODO: currently a placeholder fallback value, should fetch from metadata cookie + expiry := time.Now().Add(time.Hour) + + // attempt to get expiry from access_token if it is a JWT + parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet) + if err == nil { + expiry = parsedAccessToken.Expiration() + } + + // TODO: set refresh token and metadata + rawTokens := &oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + RefreshToken: "", + Expiry: expiry, + } + rawTokens = rawTokens.WithExtra(map[string]interface{}{ + "id_token": idToken, + }) + + tokens, err := openid.NewTokens(rawTokens, *jwkSet) if err != nil { // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt _, _ = c.provider.RefreshPublicJwkSet(ctx) return nil, fmt.Errorf("parsing tokens: %w", err) } - // TODO: set refresh token and metadata - return NewData(externalSessionID, tokens, "", nil), nil + return NewData(externalSessionID, tokens, nil), nil } func (c cookieSessionStore) Delete() { diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index e7d62ab..cbef10a 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -22,16 +22,16 @@ func TestMemory(t *testing.T) { idToken := jwtlib.New() idToken.Set("jti", "id-token-jti") - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") - - tokens := &jwt.Tokens{ - IDToken: jwt.NewIDToken("id_token", idToken), - AccessToken: jwt.NewAccessToken("access_token", accessToken), - } + accessToken := "some-access-token" refreshToken := "some-refresh-token" + + tokens := &openid.Tokens{ + AccessToken: accessToken, + IDToken: openid.NewIDToken("id_token", idToken), + RefreshToken: refreshToken, + } metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, refreshToken, metadata) + data := session.NewData("myid", tokens, metadata) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index d2dea90..4728280 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) @@ -24,16 +24,16 @@ func TestRedis(t *testing.T) { idToken := jwtlib.New() idToken.Set("jti", "id-token-jti") - accessToken := jwtlib.New() - accessToken.Set("jti", "access-token-jti") - - tokens := &jwt.Tokens{ - IDToken: jwt.NewIDToken("id_token", idToken), - AccessToken: jwt.NewAccessToken("access_token", accessToken), - } + accessToken := "some-access-token" refreshToken := "some-refresh-token" + + tokens := &openid.Tokens{ + AccessToken: accessToken, + IDToken: openid.NewIDToken("id_token", idToken), + RefreshToken: refreshToken, + } metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, refreshToken, metadata) + data := session.NewData("myid", tokens, metadata) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/session_data.go b/pkg/session/session_data.go index 130e081..5010401 100644 --- a/pkg/session/session_data.go +++ b/pkg/session/session_data.go @@ -7,7 +7,7 @@ import ( "time" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" ) type EncryptedData struct { @@ -46,21 +46,21 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) { } type Data struct { - ExternalSessionID string `json:"external_session_id"` - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - Claims jwt.Claims `json:"claims"` - Metadata Metadata `json:"metadata"` + ExternalSessionID string `json:"external_session_id"` + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + IDTokenJwtID string `json:"id_token_jwt_id"` + Metadata Metadata `json:"metadata"` } -func NewData(externalSessionID string, tokens *jwt.Tokens, refreshToken string, metadata *Metadata) *Data { +func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data { data := &Data{ ExternalSessionID: externalSessionID, - AccessToken: tokens.AccessToken.GetSerialized(), + AccessToken: tokens.AccessToken, IDToken: tokens.IDToken.GetSerialized(), - RefreshToken: refreshToken, - Claims: tokens.Claims(), + IDTokenJwtID: tokens.IDToken.GetJwtID(), + RefreshToken: tokens.RefreshToken, } if metadata != nil { diff --git a/pkg/session/session_id.go b/pkg/session/session_id.go index 20a91db..6a88bad 100644 --- a/pkg/session/session_id.go +++ b/pkg/session/session_id.go @@ -7,7 +7,7 @@ import ( "io" "net/url" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" ) @@ -15,7 +15,7 @@ const ( SessionStateParamKey = "session_state" ) -func NewSessionID(cfg *config.Provider, idToken *jwt.IDToken, params url.Values) (string, error) { +func NewSessionID(cfg *config.Provider, idToken *openid.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token sessionID, err := idToken.GetSidClaim() if err == nil { diff --git a/pkg/session/session_id_test.go b/pkg/session/session_id_test.go index 772597a..56ca83a 100644 --- a/pkg/session/session_id_test.go +++ b/pkg/session/session_id_test.go @@ -8,7 +8,7 @@ import ( jwtlib "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/jwt" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/session" ) @@ -17,7 +17,7 @@ func TestSessionID(t *testing.T) { for _, test := range []struct { name string config *config.Provider - idToken *jwt.IDToken + idToken *openid.IDToken params url.Values want string exactMatch bool @@ -136,7 +136,7 @@ func params(key, value string) url.Values { return values } -func newIDToken(extraClaims map[string]string) *jwt.IDToken { +func newIDToken(extraClaims map[string]string) *openid.IDToken { idToken := jwtlib.New() idToken.Set("sub", "test") idToken.Set("iss", "test") @@ -155,15 +155,15 @@ func newIDToken(extraClaims map[string]string) *jwt.IDToken { panic(err) } - return jwt.NewIDToken(string(serialized), idToken) + return openid.NewIDToken(string(serialized), idToken) } -func idTokenWithSid(sid string) *jwt.IDToken { +func idTokenWithSid(sid string) *openid.IDToken { return newIDToken(map[string]string{ "sid": sid, }) } -func idToken() *jwt.IDToken { +func idToken() *openid.IDToken { return newIDToken(nil) }