diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/router/handler_frontchannellogout.go index f2d564a..b43a303 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/router/handler_frontchannellogout.go @@ -26,11 +26,17 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { } sessionID := h.localSessionID(sid) - - err := h.destroySession(w, r, sessionID) + sessionData, err := h.getSession(r.Context(), sessionID) if err != nil { - log.Error(err) + log.Errorf("get session: %+v", err) + } + + err = h.destroySession(w, r, sessionID) + if err != nil { + log.Errorf("destroying session: %+v", err) // Session is already destroyed at the OP and is highly unlikely to be used again. + } else if sessionData != nil { + log.WithField("jti", sessionData.JwtIDs).Infof("successful front-channel logout") } w.WriteHeader(http.StatusOK) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 1ddcd11..eaeba3e 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -5,6 +5,8 @@ import ( "net/http" "net/url" + log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/router/request" ) @@ -18,14 +20,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { var idToken string - sess, err := h.getSessionFromCookie(w, r) - if err == nil && sess != nil { - idToken = sess.IDToken - err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID)) + sessionData, err := h.getSessionFromCookie(w, r) + if err == nil && sessionData != nil { + idToken = sessionData.IDToken + err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID)) if err != nil { h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) return } + + log.WithField("jti", sessionData.JwtIDs).Infof("successful logout") } h.deleteCookie(w, SessionCookieName, h.CookieOptions) diff --git a/pkg/router/session.go b/pkg/router/session.go index 5a27494..d3dcc7b 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -1,6 +1,7 @@ package router import ( + "context" "errors" "fmt" "net/http" @@ -29,13 +30,8 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return nil, fmt.Errorf("no session cookie: %w", err) } - encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID) + sessionData, err := h.getSession(r.Context(), sessionID) if err == nil { - sessionData, err := encryptedSessionData.Decrypt(h.Crypter) - if err != nil { - return nil, fmt.Errorf("decrypting session data: %w", err) - } - h.DeleteSessionFallback(w, r) return sessionData, nil } @@ -54,6 +50,20 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return fallbackSessionData, nil } +func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) { + encryptedSessionData, err := h.Sessions.Read(ctx, sessionID) + if err != nil { + return nil, fmt.Errorf("reading session data from store: %w", err) + } + + sessionData, err := encryptedSessionData.Decrypt(h.Crypter) + if err != nil { + return nil, fmt.Errorf("decrypting session data: %w", err) + } + + return sessionData, nil +} + func (h *Handler) getSessionLifetime(accessToken *token.AccessToken) time.Duration { defaultSessionLifetime := h.Config.SessionMaxLifetime @@ -81,7 +91,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * return fmt.Errorf("setting session cookie: %w", err) } - sessionData := session.NewData(externalSessionID, tokens.AccessToken.Raw, tokens.IDToken.Raw) + sessionData := session.NewData(externalSessionID, tokens) encryptedSessionData, err := sessionData.Encrypt(h.Crypter) if err != nil { diff --git a/pkg/router/session_fallback.go b/pkg/router/session_fallback.go index 101478c..3a8ce40 100644 --- a/pkg/router/session_fallback.go +++ b/pkg/router/session_fallback.go @@ -7,6 +7,7 @@ import ( "time" "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) func (h *Handler) SessionFallbackExternalIDCookieName() string { @@ -58,7 +59,13 @@ func (h *Handler) GetSessionFallback(r *http.Request) (*session.Data, error) { return nil, fmt.Errorf("reading access_token from fallback cookie: %w", err) } - return session.NewData(externalSessionID, accessToken, idToken), nil + jwkSet := h.Provider.GetPublicJwkSet() + tokens, err := token.ParseTokensFromStrings(idToken, accessToken, *jwkSet) + if err != nil { + return nil, fmt.Errorf("parsing tokens: %w", err) + } + + return session.NewData(externalSessionID, tokens), nil } func (h *Handler) DeleteSessionFallback(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index c08ef04..1d3e6a6 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -7,15 +7,21 @@ import ( "testing" "time" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) func TestHandler_GetSessionFallback(t *testing.T) { - h := newHandler(mock.NewTestProvider()) + p := mock.NewTestProvider() + h := newHandler(p) + tokens := makeTokens(p) t.Run("request without fallback session cookies", func(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -24,22 +30,26 @@ func TestHandler_GetSessionFallback(t *testing.T) { }) t.Run("request with fallback session cookies", func(t *testing.T) { - r := makeRequestWithFallbackCookies(t) + r := makeRequestWithFallbackCookies(t, h, tokens) sessionData, err := h.GetSessionFallback(r) assert.NoError(t, err) assert.Equal(t, "sid", sessionData.ExternalSessionID) - assert.Equal(t, "access_token", sessionData.AccessToken) - assert.Equal(t, "id_token", sessionData.IDToken) + assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken) + assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken) + assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken) + assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken) }) } func TestHandler_SetSessionFallback(t *testing.T) { - h := newHandler(mock.NewTestProvider()) + provider := mock.NewTestProvider() + h := newHandler(provider) // request should set session cookies in response writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", "access_token", "id_token") + tokens := makeTokens(provider) + data := session.NewData("sid", tokens) err := h.SetSessionFallback(writer, data, expiresIn) assert.NoError(t, err) @@ -55,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) { }, { cookieName: h.SessionFallbackIDTokenCookieName(), - want: "id_token", + want: tokens.IDToken.Raw, }, { cookieName: h.SessionFallbackAccessTokenCookieName(), - want: "access_token", + want: tokens.AccessToken.Raw, }, } { assertCookieExists(t, h, test.cookieName, test.want, cookies) @@ -67,10 +77,12 @@ func TestHandler_SetSessionFallback(t *testing.T) { } func TestHandler_DeleteSessionFallback(t *testing.T) { - h := newHandler(mock.NewTestProvider()) + p := mock.NewTestProvider() + h := newHandler(p) + tokens := makeTokens(p) t.Run("expire cookies if they are set", func(t *testing.T) { - r := makeRequestWithFallbackCookies(t) + r := makeRequestWithFallbackCookies(t, h, tokens) writer := httptest.NewRecorder() h.DeleteSessionFallback(writer, r) cookies := writer.Result().Cookies() @@ -93,11 +105,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) { }) } -func makeRequestWithFallbackCookies(t *testing.T) *http.Request { - h := newHandler(mock.NewTestProvider()) +func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *token.Tokens) *http.Request { writer := httptest.NewRecorder() expiresIn := time.Minute - data := session.NewData("sid", "access_token", "id_token") + data := session.NewData("sid", tokens) err := h.SetSessionFallback(writer, data, expiresIn) assert.NoError(t, err) @@ -138,3 +149,39 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal assert.NoError(t, err) assert.Equal(t, expectedValue, string(plainbytes)) } + +func makeTokens(provider mock.TestProvider) *token.Tokens { + jwks := *provider.PrivateJwkSet() + + signer, ok := jwks.Get(0) + if !ok { + log.Fatalf("getting signer") + } + + idToken := jwt.New() + idToken.Set("jti", "id-token-jti") + signedIdToken, err := jwt.Sign(idToken, jwa.RS256, signer) + if err != nil { + log.Fatalf("signing id_token: %+v", err) + } + parsedIdToken, err := jwt.Parse(signedIdToken) + if err != nil { + log.Fatalf("parsing signed id_token: %+v", err) + } + + accessToken := jwt.New() + accessToken.Set("jti", "access-token-jti") + signedAccessToken, err := jwt.Sign(accessToken, jwa.RS256, signer) + if err != nil { + log.Fatalf("signing access_token: %+v", err) + } + parsedAccessToken, err := jwt.Parse(signedAccessToken) + if err != nil { + log.Fatalf("parsing signed access_token: %+v", err) + } + + return &token.Tokens{ + IDToken: token.NewIDToken(string(signedIdToken), parsedIdToken), + AccessToken: token.NewAccessToken(string(signedAccessToken), parsedAccessToken), + } +} diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index 39d4753..9a33d07 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -5,11 +5,13 @@ import ( "testing" "time" + "github.com/lestrrat-go/jwx/jwt" "github.com/nais/liberator/pkg/keygen" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) func TestMemory(t *testing.T) { @@ -17,7 +19,17 @@ func TestMemory(t *testing.T) { assert.NoError(t, err) crypter := crypto.NewCrypter(key) - data := session.NewData("myid", "accesstoken", "idtoken") + idToken := jwt.New() + idToken.Set("jti", "id-token-jti") + + accessToken := jwt.New() + accessToken.Set("jti", "access-token-jti") + + tokens := &token.Tokens{ + IDToken: token.NewIDToken("id_token", idToken), + AccessToken: token.NewAccessToken("access_token", accessToken), + } + data := session.NewData("myid", tokens) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index 09a71b3..ac721d1 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -7,11 +7,13 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/go-redis/redis/v8" + "github.com/lestrrat-go/jwx/jwt" "github.com/nais/liberator/pkg/keygen" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) func TestRedis(t *testing.T) { @@ -19,7 +21,17 @@ func TestRedis(t *testing.T) { assert.NoError(t, err) crypter := crypto.NewCrypter(key) - data := session.NewData("myid", "accesstoken", "idtoken") + idToken := jwt.New() + idToken.Set("jti", "id-token-jti") + + accessToken := jwt.New() + accessToken.Set("jti", "access-token-jti") + + tokens := &token.Tokens{ + IDToken: token.NewIDToken("id_token", idToken), + AccessToken: token.NewAccessToken("access_token", accessToken), + } + data := session.NewData("myid", tokens) encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) diff --git a/pkg/session/session.go b/pkg/session/session.go index 0557010..23943c2 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -11,6 +11,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/token" ) type Store interface { @@ -79,16 +80,18 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) { } type Data struct { - ExternalSessionID string `json:"external_session_id"` - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` + ExternalSessionID string `json:"external_session_id"` + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + JwtIDs token.JwtIDs `json:"jti"` } -func NewData(externalSessionID, accessToken, idToken string) *Data { +func NewData(externalSessionID string, tokens *token.Tokens) *Data { return &Data{ ExternalSessionID: externalSessionID, - AccessToken: accessToken, - IDToken: idToken, + AccessToken: tokens.AccessToken.Raw, + IDToken: tokens.IDToken.Raw, + JwtIDs: tokens.JwtIDs(), } } diff --git a/pkg/token/access_token.go b/pkg/token/access_token.go index 0ab104e..2cefd84 100644 --- a/pkg/token/access_token.go +++ b/pkg/token/access_token.go @@ -3,7 +3,6 @@ package token import ( "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" - "golang.org/x/oauth2" ) type AccessToken struct { @@ -28,11 +27,11 @@ func NewAccessToken(raw string, token jwt.Token) *AccessToken { } } -func ParseAccessToken(tokens *oauth2.Token, jwks jwk.Set) (*AccessToken, error) { - accessToken, err := ParseJwt(tokens.AccessToken, jwks) +func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) { + accessToken, err := ParseJwt(raw, jwks) if err != nil { return nil, err } - return NewAccessToken(tokens.AccessToken, accessToken), nil + return NewAccessToken(raw, accessToken), nil } diff --git a/pkg/token/id_token.go b/pkg/token/id_token.go index 3156cf9..90ffd95 100644 --- a/pkg/token/id_token.go +++ b/pkg/token/id_token.go @@ -1,12 +1,10 @@ package token import ( - "fmt" "time" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" - "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/openid" ) @@ -59,12 +57,7 @@ func NewIDToken(raw string, token jwt.Token) *IDToken { } } -func ParseIDToken(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) { - raw, ok := tokens.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("missing id_token in token response") - } - +func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) { idToken, err := ParseJwt(raw, jwks) if err != nil { return nil, err diff --git a/pkg/token/token.go b/pkg/token/token.go index f83f71c..5f09aa0 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -25,28 +25,6 @@ type Tokens struct { AccessToken *AccessToken } -type JwtIDs struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` -} - -func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) { - idToken, err := ParseIDToken(tokens, jwks) - if err != nil { - return nil, fmt.Errorf("id_token: %w", err) - } - - accessToken, err := ParseAccessToken(tokens, jwks) - if err != nil { - return nil, fmt.Errorf("access_token: %w", err) - } - - return &Tokens{ - IDToken: idToken, - AccessToken: accessToken, - }, nil -} - func (in *Tokens) JwtIDs() JwtIDs { return JwtIDs{ IDToken: in.IDToken.GetJtiClaim(), @@ -54,6 +32,41 @@ func (in *Tokens) JwtIDs() JwtIDs { } } +func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens { + return &Tokens{ + IDToken: idToken, + AccessToken: accessToken, + } +} + +type JwtIDs struct { + IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token,omitempty"` +} + +func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) { + idToken, ok := tokens.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("missing id_token in token response") + } + + return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks) +} + +func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) { + parsedIdToken, err := ParseIDToken(idToken, jwks) + if err != nil { + return nil, fmt.Errorf("id_token: %w", err) + } + + parsedAccessToken, err := ParseAccessToken(accessToken, jwks) + if err != nil { + return nil, fmt.Errorf("access_token: %w", err) + } + + return NewTokens(parsedIdToken, parsedAccessToken), nil +} + func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) { parseOpts := []jwt.ParseOption{ jwt.WithKeySet(jwks), @@ -68,6 +81,10 @@ func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) { } func GetStringClaim(token jwt.Token, claim string) (string, error) { + if token == nil { + return "", fmt.Errorf("token is nil") + } + gotClaim, ok := token.Get(claim) if !ok { return "", fmt.Errorf("missing required '%s' claim in id_token", claim)