refactor: move token parsing to own package; prepare for audit logs

This commit is contained in:
Trong Huu Nguyen
2022-02-02 18:13:32 +01:00
parent 6a4a268e15
commit e4e95ef5c6
10 changed files with 251 additions and 131 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/token"
)
const (
@@ -18,7 +19,7 @@ const (
)
type Client interface {
ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error)
ExchangeToken(ctx context.Context, accessToken *token.AccessToken) (*TokenResponse, error)
SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options)
HasCookie(r *http.Request) bool
ClearCookie(w http.ResponseWriter, opts cookie.Options)
@@ -46,7 +47,7 @@ type client struct {
httpClient *http.Client
}
func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) {
func (c client) ExchangeToken(ctx context.Context, accessToken *token.AccessToken) (*TokenResponse, error) {
req, err := request(ctx, c.config.TokenURL, accessToken)
if err != nil {
return nil, fmt.Errorf("creating request %w", err)
@@ -98,13 +99,13 @@ func (c client) cookieOptions(opts cookie.Options) cookie.Options {
WithSameSite(SameSiteMode)
}
func request(ctx context.Context, url, token string) (*http.Request, error) {
func request(ctx context.Context, url string, token *token.AccessToken) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
req.Header.Set("Accept", "application/json")
return req, nil

View File

@@ -14,6 +14,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/loginstatus"
"github.com/nais/wonderwall/pkg/token"
)
func TestClient_ExchangeToken(t *testing.T) {
@@ -24,18 +25,18 @@ func TestClient_ExchangeToken(t *testing.T) {
client := loginstatus.NewClient(cfg, httpclient)
for _, test := range []struct {
token string
token *token.AccessToken
err error
}{
{
token: "valid-token",
token: token.NewAccessToken("valid-token", nil),
},
{
token: "invalid-token",
token: token.NewAccessToken("invalid-token", nil),
err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"),
},
{
token: "internal-server-error",
token: token.NewAccessToken("internal-server-error", nil),
err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"),
},
} {

View File

@@ -1,60 +0,0 @@
package openid
import (
"fmt"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
)
type IDToken struct {
Raw string
Token jwt.Token
}
func (in *IDToken) Validate(opts ...jwt.ValidateOption) error {
err := jwt.Validate(in.Token, opts...)
if err != nil {
return fmt.Errorf("validating id_token: %w", err)
}
return nil
}
func (in *IDToken) GetStringClaim(claim string) (string, error) {
gotClaim, ok := in.Token.Get(claim)
if !ok {
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
}
claimString, ok := gotClaim.(string)
if !ok {
return "", fmt.Errorf("'%s' claim is not a string", claim)
}
return claimString, nil
}
func ParseIDToken(jwks jwk.Set, token *oauth2.Token) (*IDToken, error) {
raw, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
parseOpts := []jwt.ParseOption{
jwt.WithKeySet(jwks),
jwt.InferAlgorithmFromKey(true),
}
idToken, err := jwt.Parse([]byte(raw), parseOpts...)
if err != nil {
return nil, fmt.Errorf("parsing jwt: %w", err)
}
result := &IDToken{
Raw: raw,
Token: idToken,
}
return result, nil
}

View File

@@ -4,13 +4,13 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"time"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/token"
)
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
@@ -33,49 +33,44 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
tokens, err := h.codeExchangeForToken(r.Context(), loginCookie, params.Get("code"))
rawTokens, err := h.codeExchangeForToken(r.Context(), loginCookie, params.Get("code"))
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: exchanging code: %w", err))
return
}
jwkSet := h.Provider.GetPublicJwkSet()
idToken, err := openid.ParseIDToken(*jwkSet, tokens)
tokens, err := token.ParseTokens(rawTokens, *jwkSet)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: parsing id_token: %w", err))
h.InternalError(w, r, fmt.Errorf("callback: parsing tokens: %w", err))
return
}
err = h.validateIDToken(idToken, loginCookie, params)
err = tokens.IDToken.Validate(h.Provider, loginCookie.Nonce)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err))
return
}
sessionID, err := SessionID(h.Provider.GetOpenIDConfiguration(), idToken, params)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: generating session ID: %w", err))
return
}
err = h.createSession(w, r, sessionID, tokens, idToken)
err = h.createSession(w, r, tokens, params)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
}
if h.Config.Features.Loginstatus.Enabled {
token, err := h.Loginstatus.ExchangeToken(r.Context(), tokens.AccessToken)
loginstatusToken, err := h.Loginstatus.ExchangeToken(r.Context(), tokens.AccessToken)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: exchanging loginstatus token: %w", err))
return
}
h.Loginstatus.SetCookie(w, token, h.CookieOptions)
h.Loginstatus.SetCookie(w, loginstatusToken, h.CookieOptions)
}
h.clearLoginCookies(w)
logSuccessfulLogin(tokens, loginCookie.Referer)
http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect)
}
@@ -99,24 +94,11 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid.
return tokens, nil
}
func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) error {
openIDconfig := h.Provider.GetOpenIDConfiguration()
clientConfig := h.Provider.GetClientConfiguration()
validateOpts := []jwt.ValidateOption{
jwt.WithAudience(clientConfig.GetClientID()),
jwt.WithClaimValue("nonce", loginCookie.Nonce),
jwt.WithIssuer(openIDconfig.Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
func logSuccessfulLogin(tokens *token.Tokens, referer string) {
fields := log.Fields{
"redirect_to": referer,
"jti": tokens.JwtIDs(),
}
if openIDconfig.SidClaimRequired() {
validateOpts = append(validateOpts, jwt.WithRequiredClaim("sid"))
}
if len(clientConfig.GetACRValues()) > 0 {
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
}
return idToken.Validate(validateOpts...)
log.WithFields(fields).Info("successful login")
}

View File

@@ -4,15 +4,14 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"time"
"github.com/go-redis/redis/v8"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
// localSessionID prefixes the given `sid` or `session_state` with the given client ID to prevent key collisions.
@@ -55,39 +54,34 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return fallbackSessionData, nil
}
func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) {
func (h *Handler) getSessionLifetime(accessToken *token.AccessToken) time.Duration {
defaultSessionLifetime := h.Config.SessionMaxLifetime
tok, err := jwt.Parse([]byte(accessToken))
if err != nil {
return 0, err
}
tokenDuration := tok.Expiration().Sub(time.Now())
tokenDuration := accessToken.Token.Expiration().Sub(time.Now())
if tokenDuration <= defaultSessionLifetime {
return tokenDuration, nil
return tokenDuration
}
return defaultSessionLifetime, nil
return defaultSessionLifetime
}
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, externalSessionID string, tokens *oauth2.Token, idToken *openid.IDToken) error {
sessionID := h.localSessionID(externalSessionID)
sessionLifetime, err := h.getSessionLifetime(tokens.AccessToken)
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *token.Tokens, params url.Values) error {
externalSessionID, err := NewSessionID(h.Provider.GetOpenIDConfiguration(), tokens.IDToken, params)
if err != nil {
return fmt.Errorf("getting access token lifetime: %w", err)
return fmt.Errorf("generating session ID: %w", err)
}
sessionLifetime := h.getSessionLifetime(tokens.AccessToken)
opts := h.CookieOptions.WithExpiresIn(sessionLifetime)
sessionID := h.localSessionID(externalSessionID)
err = h.setEncryptedCookie(w, SessionCookieName, sessionID, opts)
if err != nil {
return fmt.Errorf("setting session cookie: %w", err)
}
sessionData := session.NewData(externalSessionID, tokens.AccessToken, idToken.Raw)
sessionData := session.NewData(externalSessionID, tokens.AccessToken.Raw, tokens.IDToken.Raw)
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
if err != nil {

View File

@@ -8,15 +8,16 @@ import (
"net/url"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/token"
)
const (
SessionStateParamKey = "session_state"
)
func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) {
func NewSessionID(cfg *openid.Configuration, idToken *token.IDToken, params url.Values) (string, error) {
// 1. check for 'sid' claim in id_token
sessionID, err := idToken.GetStringClaim("sid")
sessionID, err := idToken.GetSidClaim()
if err == nil {
return sessionID, nil
}

View File

@@ -10,13 +10,14 @@ import (
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/token"
)
func TestSessionID(t *testing.T) {
for _, test := range []struct {
name string
config *openid.Configuration
idToken *openid.IDToken
idToken *token.IDToken
params url.Values
want string
exactMatch bool
@@ -97,7 +98,7 @@ func TestSessionID(t *testing.T) {
exactMatch: true,
},
} {
actual, err := router.SessionID(test.config, test.idToken, test.params)
actual, err := router.NewSessionID(test.config, test.idToken, test.params)
t.Run(test.name, func(t *testing.T) {
if test.expectErr {
@@ -135,7 +136,7 @@ func params(key, value string) url.Values {
return values
}
func newIDToken(extraClaims map[string]string) *openid.IDToken {
func newIDToken(extraClaims map[string]string) *token.IDToken {
idToken := jwt.New()
idToken.Set("sub", "test")
idToken.Set("iss", "test")
@@ -154,18 +155,15 @@ func newIDToken(extraClaims map[string]string) *openid.IDToken {
panic(err)
}
return &openid.IDToken{
Raw: string(serialized),
Token: idToken,
}
return token.NewIDToken(string(serialized), idToken)
}
func idTokenWithSid(sid string) *openid.IDToken {
func idTokenWithSid(sid string) *token.IDToken {
return newIDToken(map[string]string{
"sid": sid,
})
}
func idToken() *openid.IDToken {
func idToken() *token.IDToken {
return newIDToken(nil)
}

38
pkg/token/access_token.go Normal file
View File

@@ -0,0 +1,38 @@
package token
import (
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
)
type AccessToken struct {
Raw string
Token jwt.Token
Type Type
}
func (in *AccessToken) GetJtiClaim() string {
return GetStringClaimOrEmpty(in.Token, JtiClaim)
}
func (in *AccessToken) GetStringClaim(claim string) (string, error) {
return GetStringClaim(in.Token, claim)
}
func NewAccessToken(raw string, token jwt.Token) *AccessToken {
return &AccessToken{
Raw: raw,
Token: token,
Type: TypeAccessToken,
}
}
func ParseAccessToken(tokens *oauth2.Token, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := ParseJwt(tokens.AccessToken, jwks)
if err != nil {
return nil, err
}
return NewAccessToken(tokens.AccessToken, accessToken), nil
}

74
pkg/token/id_token.go Normal file
View File

@@ -0,0 +1,74 @@
package token
import (
"fmt"
"time"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/openid"
)
type IDToken struct {
Raw string
Token jwt.Token
Type Type
}
func (in *IDToken) GetJtiClaim() string {
return GetStringClaimOrEmpty(in.Token, JtiClaim)
}
func (in *IDToken) GetSidClaim() (string, error) {
return in.GetStringClaim(SidClaim)
}
func (in *IDToken) GetStringClaim(claim string) (string, error) {
return GetStringClaim(in.Token, claim)
}
func (in *IDToken) Validate(provider openid.Provider, nonce string) error {
openIDconfig := provider.GetOpenIDConfiguration()
clientConfig := provider.GetClientConfiguration()
opts := []jwt.ValidateOption{
jwt.WithAudience(clientConfig.GetClientID()),
jwt.WithClaimValue("nonce", nonce),
jwt.WithIssuer(openIDconfig.Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
}
if openIDconfig.SidClaimRequired() {
opts = append(opts, jwt.WithRequiredClaim("sid"))
}
if len(clientConfig.GetACRValues()) > 0 {
opts = append(opts, jwt.WithRequiredClaim("acr"))
}
return jwt.Validate(in.Token, opts...)
}
func NewIDToken(raw string, token jwt.Token) *IDToken {
return &IDToken{
Raw: raw,
Token: token,
Type: TypeIDToken,
}
}
func ParseIDToken(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) {
raw, ok := tokens.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
idToken, err := ParseJwt(raw, jwks)
if err != nil {
return nil, err
}
return NewIDToken(raw, idToken), nil
}

91
pkg/token/token.go Normal file
View File

@@ -0,0 +1,91 @@
package token
import (
"fmt"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
)
type Type int
const (
TypeIDToken Type = iota
TypeAccessToken
)
const (
JtiClaim = "jti"
SidClaim = "sid"
)
type Tokens struct {
IDToken *IDToken
AccessToken *AccessToken
}
type JwtIDs struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
}
func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
idToken, err := ParseIDToken(tokens, jwks)
if err != nil {
return nil, fmt.Errorf("id_token: %w", err)
}
accessToken, err := ParseAccessToken(tokens, jwks)
if err != nil {
return nil, fmt.Errorf("access_token: %w", err)
}
return &Tokens{
IDToken: idToken,
AccessToken: accessToken,
}, nil
}
func (in *Tokens) JwtIDs() JwtIDs {
return JwtIDs{
IDToken: in.IDToken.GetJtiClaim(),
AccessToken: in.AccessToken.GetJtiClaim(),
}
}
func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) {
parseOpts := []jwt.ParseOption{
jwt.WithKeySet(jwks),
jwt.InferAlgorithmFromKey(true),
}
token, err := jwt.ParseString(raw, parseOpts...)
if err != nil {
return nil, fmt.Errorf("parsing jwt: %w", err)
}
return token, nil
}
func GetStringClaim(token jwt.Token, claim string) (string, error) {
gotClaim, ok := token.Get(claim)
if !ok {
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
}
claimString, ok := gotClaim.(string)
if !ok {
return "", fmt.Errorf("'%s' claim is not a string", claim)
}
return claimString, nil
}
func GetStringClaimOrEmpty(token jwt.Token, claim string) string {
str, err := GetStringClaim(token, claim)
if err != nil {
return ""
}
return str
}