From e4e95ef5c6f6c2851c94273b6a5ca35c4cdf3e35 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Wed, 2 Feb 2022 18:13:32 +0100 Subject: [PATCH] refactor: move token parsing to own package; prepare for audit logs --- pkg/loginstatus/loginstatus.go | 9 +-- pkg/loginstatus/loginstatus_test.go | 9 +-- pkg/openid/token.go | 60 ------------------- pkg/router/handler_callback.go | 50 +++++----------- pkg/router/session.go | 30 ++++------ pkg/router/session_id.go | 5 +- pkg/router/session_id_test.go | 16 +++-- pkg/token/access_token.go | 38 ++++++++++++ pkg/token/id_token.go | 74 +++++++++++++++++++++++ pkg/token/token.go | 91 +++++++++++++++++++++++++++++ 10 files changed, 251 insertions(+), 131 deletions(-) delete mode 100644 pkg/openid/token.go create mode 100644 pkg/token/access_token.go create mode 100644 pkg/token/id_token.go create mode 100644 pkg/token/token.go diff --git a/pkg/loginstatus/loginstatus.go b/pkg/loginstatus/loginstatus.go index 1376a3a..422d192 100644 --- a/pkg/loginstatus/loginstatus.go +++ b/pkg/loginstatus/loginstatus.go @@ -11,6 +11,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/token" ) const ( @@ -18,7 +19,7 @@ const ( ) type Client interface { - ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) + ExchangeToken(ctx context.Context, accessToken *token.AccessToken) (*TokenResponse, error) SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options) HasCookie(r *http.Request) bool ClearCookie(w http.ResponseWriter, opts cookie.Options) @@ -46,7 +47,7 @@ type client struct { httpClient *http.Client } -func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) { +func (c client) ExchangeToken(ctx context.Context, accessToken *token.AccessToken) (*TokenResponse, error) { req, err := request(ctx, c.config.TokenURL, accessToken) if err != nil { return nil, fmt.Errorf("creating request %w", err) @@ -98,13 +99,13 @@ func (c client) cookieOptions(opts cookie.Options) cookie.Options { WithSameSite(SameSiteMode) } -func request(ctx context.Context, url, token string) (*http.Request, error) { +func request(ctx context.Context, url string, token *token.AccessToken) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Accept", "application/json") return req, nil diff --git a/pkg/loginstatus/loginstatus_test.go b/pkg/loginstatus/loginstatus_test.go index f2f9801..60f5a24 100644 --- a/pkg/loginstatus/loginstatus_test.go +++ b/pkg/loginstatus/loginstatus_test.go @@ -14,6 +14,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/loginstatus" + "github.com/nais/wonderwall/pkg/token" ) func TestClient_ExchangeToken(t *testing.T) { @@ -24,18 +25,18 @@ func TestClient_ExchangeToken(t *testing.T) { client := loginstatus.NewClient(cfg, httpclient) for _, test := range []struct { - token string + token *token.AccessToken err error }{ { - token: "valid-token", + token: token.NewAccessToken("valid-token", nil), }, { - token: "invalid-token", + token: token.NewAccessToken("invalid-token", nil), err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"), }, { - token: "internal-server-error", + token: token.NewAccessToken("internal-server-error", nil), err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"), }, } { diff --git a/pkg/openid/token.go b/pkg/openid/token.go deleted file mode 100644 index 020a859..0000000 --- a/pkg/openid/token.go +++ /dev/null @@ -1,60 +0,0 @@ -package openid - -import ( - "fmt" - - "github.com/lestrrat-go/jwx/jwk" - "github.com/lestrrat-go/jwx/jwt" - "golang.org/x/oauth2" -) - -type IDToken struct { - Raw string - Token jwt.Token -} - -func (in *IDToken) Validate(opts ...jwt.ValidateOption) error { - err := jwt.Validate(in.Token, opts...) - if err != nil { - return fmt.Errorf("validating id_token: %w", err) - } - - return nil -} - -func (in *IDToken) GetStringClaim(claim string) (string, error) { - gotClaim, ok := in.Token.Get(claim) - if !ok { - return "", fmt.Errorf("missing required '%s' claim in id_token", claim) - } - - claimString, ok := gotClaim.(string) - if !ok { - return "", fmt.Errorf("'%s' claim is not a string", claim) - } - - return claimString, nil -} - -func ParseIDToken(jwks jwk.Set, token *oauth2.Token) (*IDToken, error) { - raw, ok := token.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("missing id_token in token response") - } - - parseOpts := []jwt.ParseOption{ - jwt.WithKeySet(jwks), - jwt.InferAlgorithmFromKey(true), - } - idToken, err := jwt.Parse([]byte(raw), parseOpts...) - if err != nil { - return nil, fmt.Errorf("parsing jwt: %w", err) - } - - result := &IDToken{ - Raw: raw, - Token: idToken, - } - - return result, nil -} diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 46e2edf..0e424f7 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -4,13 +4,13 @@ import ( "context" "fmt" "net/http" - "net/url" "time" - "github.com/lestrrat-go/jwx/jwt" + log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/token" ) func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { @@ -33,49 +33,44 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - tokens, err := h.codeExchangeForToken(r.Context(), loginCookie, params.Get("code")) + rawTokens, err := h.codeExchangeForToken(r.Context(), loginCookie, params.Get("code")) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: exchanging code: %w", err)) return } jwkSet := h.Provider.GetPublicJwkSet() - idToken, err := openid.ParseIDToken(*jwkSet, tokens) + + tokens, err := token.ParseTokens(rawTokens, *jwkSet) if err != nil { - h.InternalError(w, r, fmt.Errorf("callback: parsing id_token: %w", err)) + h.InternalError(w, r, fmt.Errorf("callback: parsing tokens: %w", err)) return } - err = h.validateIDToken(idToken, loginCookie, params) + err = tokens.IDToken.Validate(h.Provider, loginCookie.Nonce) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err)) return } - sessionID, err := SessionID(h.Provider.GetOpenIDConfiguration(), idToken, params) - if err != nil { - h.InternalError(w, r, fmt.Errorf("callback: generating session ID: %w", err)) - return - } - - err = h.createSession(w, r, sessionID, tokens, idToken) + err = h.createSession(w, r, tokens, params) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) return } if h.Config.Features.Loginstatus.Enabled { - token, err := h.Loginstatus.ExchangeToken(r.Context(), tokens.AccessToken) + loginstatusToken, err := h.Loginstatus.ExchangeToken(r.Context(), tokens.AccessToken) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: exchanging loginstatus token: %w", err)) return } - h.Loginstatus.SetCookie(w, token, h.CookieOptions) + h.Loginstatus.SetCookie(w, loginstatusToken, h.CookieOptions) } h.clearLoginCookies(w) - + logSuccessfulLogin(tokens, loginCookie.Referer) http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect) } @@ -99,24 +94,11 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid. return tokens, nil } -func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) error { - openIDconfig := h.Provider.GetOpenIDConfiguration() - clientConfig := h.Provider.GetClientConfiguration() - - validateOpts := []jwt.ValidateOption{ - jwt.WithAudience(clientConfig.GetClientID()), - jwt.WithClaimValue("nonce", loginCookie.Nonce), - jwt.WithIssuer(openIDconfig.Issuer), - jwt.WithAcceptableSkew(5 * time.Second), +func logSuccessfulLogin(tokens *token.Tokens, referer string) { + fields := log.Fields{ + "redirect_to": referer, + "jti": tokens.JwtIDs(), } - if openIDconfig.SidClaimRequired() { - validateOpts = append(validateOpts, jwt.WithRequiredClaim("sid")) - } - - if len(clientConfig.GetACRValues()) > 0 { - validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr")) - } - - return idToken.Validate(validateOpts...) + log.WithFields(fields).Info("successful login") } diff --git a/pkg/router/session.go b/pkg/router/session.go index c09f51b..5a27494 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -4,15 +4,14 @@ import ( "errors" "fmt" "net/http" + "net/url" "time" "github.com/go-redis/redis/v8" - "github.com/lestrrat-go/jwx/jwt" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) // localSessionID prefixes the given `sid` or `session_state` with the given client ID to prevent key collisions. @@ -55,39 +54,34 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return fallbackSessionData, nil } -func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) { +func (h *Handler) getSessionLifetime(accessToken *token.AccessToken) time.Duration { defaultSessionLifetime := h.Config.SessionMaxLifetime - tok, err := jwt.Parse([]byte(accessToken)) - if err != nil { - return 0, err - } - - tokenDuration := tok.Expiration().Sub(time.Now()) + tokenDuration := accessToken.Token.Expiration().Sub(time.Now()) if tokenDuration <= defaultSessionLifetime { - return tokenDuration, nil + return tokenDuration } - return defaultSessionLifetime, nil + return defaultSessionLifetime } -func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, externalSessionID string, tokens *oauth2.Token, idToken *openid.IDToken) error { - sessionID := h.localSessionID(externalSessionID) - - sessionLifetime, err := h.getSessionLifetime(tokens.AccessToken) +func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *token.Tokens, params url.Values) error { + externalSessionID, err := NewSessionID(h.Provider.GetOpenIDConfiguration(), tokens.IDToken, params) if err != nil { - return fmt.Errorf("getting access token lifetime: %w", err) + return fmt.Errorf("generating session ID: %w", err) } + sessionLifetime := h.getSessionLifetime(tokens.AccessToken) opts := h.CookieOptions.WithExpiresIn(sessionLifetime) + sessionID := h.localSessionID(externalSessionID) err = h.setEncryptedCookie(w, SessionCookieName, sessionID, opts) if err != nil { return fmt.Errorf("setting session cookie: %w", err) } - sessionData := session.NewData(externalSessionID, tokens.AccessToken, idToken.Raw) + sessionData := session.NewData(externalSessionID, tokens.AccessToken.Raw, tokens.IDToken.Raw) encryptedSessionData, err := sessionData.Encrypt(h.Crypter) if err != nil { diff --git a/pkg/router/session_id.go b/pkg/router/session_id.go index b0a7d1f..cfd78d6 100644 --- a/pkg/router/session_id.go +++ b/pkg/router/session_id.go @@ -8,15 +8,16 @@ import ( "net/url" "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/token" ) const ( SessionStateParamKey = "session_state" ) -func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) { +func NewSessionID(cfg *openid.Configuration, idToken *token.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token - sessionID, err := idToken.GetStringClaim("sid") + sessionID, err := idToken.GetSidClaim() if err == nil { return sessionID, nil } diff --git a/pkg/router/session_id_test.go b/pkg/router/session_id_test.go index 956bcee..2e9c265 100644 --- a/pkg/router/session_id_test.go +++ b/pkg/router/session_id_test.go @@ -10,13 +10,14 @@ import ( "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router" + "github.com/nais/wonderwall/pkg/token" ) func TestSessionID(t *testing.T) { for _, test := range []struct { name string config *openid.Configuration - idToken *openid.IDToken + idToken *token.IDToken params url.Values want string exactMatch bool @@ -97,7 +98,7 @@ func TestSessionID(t *testing.T) { exactMatch: true, }, } { - actual, err := router.SessionID(test.config, test.idToken, test.params) + actual, err := router.NewSessionID(test.config, test.idToken, test.params) t.Run(test.name, func(t *testing.T) { if test.expectErr { @@ -135,7 +136,7 @@ func params(key, value string) url.Values { return values } -func newIDToken(extraClaims map[string]string) *openid.IDToken { +func newIDToken(extraClaims map[string]string) *token.IDToken { idToken := jwt.New() idToken.Set("sub", "test") idToken.Set("iss", "test") @@ -154,18 +155,15 @@ func newIDToken(extraClaims map[string]string) *openid.IDToken { panic(err) } - return &openid.IDToken{ - Raw: string(serialized), - Token: idToken, - } + return token.NewIDToken(string(serialized), idToken) } -func idTokenWithSid(sid string) *openid.IDToken { +func idTokenWithSid(sid string) *token.IDToken { return newIDToken(map[string]string{ "sid": sid, }) } -func idToken() *openid.IDToken { +func idToken() *token.IDToken { return newIDToken(nil) } diff --git a/pkg/token/access_token.go b/pkg/token/access_token.go new file mode 100644 index 0000000..0ab104e --- /dev/null +++ b/pkg/token/access_token.go @@ -0,0 +1,38 @@ +package token + +import ( + "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" + "golang.org/x/oauth2" +) + +type AccessToken struct { + Raw string + Token jwt.Token + Type Type +} + +func (in *AccessToken) GetJtiClaim() string { + return GetStringClaimOrEmpty(in.Token, JtiClaim) +} + +func (in *AccessToken) GetStringClaim(claim string) (string, error) { + return GetStringClaim(in.Token, claim) +} + +func NewAccessToken(raw string, token jwt.Token) *AccessToken { + return &AccessToken{ + Raw: raw, + Token: token, + Type: TypeAccessToken, + } +} + +func ParseAccessToken(tokens *oauth2.Token, jwks jwk.Set) (*AccessToken, error) { + accessToken, err := ParseJwt(tokens.AccessToken, jwks) + if err != nil { + return nil, err + } + + return NewAccessToken(tokens.AccessToken, accessToken), nil +} diff --git a/pkg/token/id_token.go b/pkg/token/id_token.go new file mode 100644 index 0000000..3156cf9 --- /dev/null +++ b/pkg/token/id_token.go @@ -0,0 +1,74 @@ +package token + +import ( + "fmt" + "time" + + "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/openid" +) + +type IDToken struct { + Raw string + Token jwt.Token + Type Type +} + +func (in *IDToken) GetJtiClaim() string { + return GetStringClaimOrEmpty(in.Token, JtiClaim) +} + +func (in *IDToken) GetSidClaim() (string, error) { + return in.GetStringClaim(SidClaim) +} + +func (in *IDToken) GetStringClaim(claim string) (string, error) { + return GetStringClaim(in.Token, claim) +} + +func (in *IDToken) Validate(provider openid.Provider, nonce string) error { + openIDconfig := provider.GetOpenIDConfiguration() + clientConfig := provider.GetClientConfiguration() + + opts := []jwt.ValidateOption{ + jwt.WithAudience(clientConfig.GetClientID()), + jwt.WithClaimValue("nonce", nonce), + jwt.WithIssuer(openIDconfig.Issuer), + jwt.WithAcceptableSkew(5 * time.Second), + } + + if openIDconfig.SidClaimRequired() { + opts = append(opts, jwt.WithRequiredClaim("sid")) + } + + if len(clientConfig.GetACRValues()) > 0 { + opts = append(opts, jwt.WithRequiredClaim("acr")) + } + + return jwt.Validate(in.Token, opts...) +} + +func NewIDToken(raw string, token jwt.Token) *IDToken { + return &IDToken{ + Raw: raw, + Token: token, + Type: TypeIDToken, + } +} + +func ParseIDToken(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) { + raw, ok := tokens.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("missing id_token in token response") + } + + idToken, err := ParseJwt(raw, jwks) + if err != nil { + return nil, err + } + + return NewIDToken(raw, idToken), nil +} diff --git a/pkg/token/token.go b/pkg/token/token.go new file mode 100644 index 0000000..f83f71c --- /dev/null +++ b/pkg/token/token.go @@ -0,0 +1,91 @@ +package token + +import ( + "fmt" + + "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" + "golang.org/x/oauth2" +) + +type Type int + +const ( + TypeIDToken Type = iota + TypeAccessToken +) + +const ( + JtiClaim = "jti" + SidClaim = "sid" +) + +type Tokens struct { + IDToken *IDToken + AccessToken *AccessToken +} + +type JwtIDs struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` +} + +func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) { + idToken, err := ParseIDToken(tokens, jwks) + if err != nil { + return nil, fmt.Errorf("id_token: %w", err) + } + + accessToken, err := ParseAccessToken(tokens, jwks) + if err != nil { + return nil, fmt.Errorf("access_token: %w", err) + } + + return &Tokens{ + IDToken: idToken, + AccessToken: accessToken, + }, nil +} + +func (in *Tokens) JwtIDs() JwtIDs { + return JwtIDs{ + IDToken: in.IDToken.GetJtiClaim(), + AccessToken: in.AccessToken.GetJtiClaim(), + } +} + +func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) { + parseOpts := []jwt.ParseOption{ + jwt.WithKeySet(jwks), + jwt.InferAlgorithmFromKey(true), + } + token, err := jwt.ParseString(raw, parseOpts...) + if err != nil { + return nil, fmt.Errorf("parsing jwt: %w", err) + } + + return token, nil +} + +func GetStringClaim(token jwt.Token, claim string) (string, error) { + gotClaim, ok := token.Get(claim) + if !ok { + return "", fmt.Errorf("missing required '%s' claim in id_token", claim) + } + + claimString, ok := gotClaim.(string) + if !ok { + return "", fmt.Errorf("'%s' claim is not a string", claim) + } + + return claimString, nil +} + +func GetStringClaimOrEmpty(token jwt.Token, claim string) string { + str, err := GetStringClaim(token, claim) + if err != nil { + return "" + } + + return str +}