feat: log jwt IDs for tracability

This commit is contained in:
Trong Huu Nguyen
2022-02-02 20:01:43 +01:00
parent e4e95ef5c6
commit eeccebc5dd
11 changed files with 180 additions and 70 deletions

View File

@@ -26,11 +26,17 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
}
sessionID := h.localSessionID(sid)
err := h.destroySession(w, r, sessionID)
sessionData, err := h.getSession(r.Context(), sessionID)
if err != nil {
log.Error(err)
log.Errorf("get session: %+v", err)
}
err = h.destroySession(w, r, sessionID)
if err != nil {
log.Errorf("destroying session: %+v", err)
// Session is already destroyed at the OP and is highly unlikely to be used again.
} else if sessionData != nil {
log.WithField("jti", sessionData.JwtIDs).Infof("successful front-channel logout")
}
w.WriteHeader(http.StatusOK)

View File

@@ -5,6 +5,8 @@ import (
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/router/request"
)
@@ -18,14 +20,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
var idToken string
sess, err := h.getSessionFromCookie(w, r)
if err == nil && sess != nil {
idToken = sess.IDToken
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))
sessionData, err := h.getSessionFromCookie(w, r)
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}
log.WithField("jti", sessionData.JwtIDs).Infof("successful logout")
}
h.deleteCookie(w, SessionCookieName, h.CookieOptions)

View File

@@ -1,6 +1,7 @@
package router
import (
"context"
"errors"
"fmt"
"net/http"
@@ -29,13 +30,8 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return nil, fmt.Errorf("no session cookie: %w", err)
}
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
sessionData, err := h.getSession(r.Context(), sessionID)
if err == nil {
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
h.DeleteSessionFallback(w, r)
return sessionData, nil
}
@@ -54,6 +50,20 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return fallbackSessionData, nil
}
func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) {
encryptedSessionData, err := h.Sessions.Read(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("reading session data from store: %w", err)
}
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return sessionData, nil
}
func (h *Handler) getSessionLifetime(accessToken *token.AccessToken) time.Duration {
defaultSessionLifetime := h.Config.SessionMaxLifetime
@@ -81,7 +91,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
return fmt.Errorf("setting session cookie: %w", err)
}
sessionData := session.NewData(externalSessionID, tokens.AccessToken.Raw, tokens.IDToken.Raw)
sessionData := session.NewData(externalSessionID, tokens)
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
func (h *Handler) SessionFallbackExternalIDCookieName() string {
@@ -58,7 +59,13 @@ func (h *Handler) GetSessionFallback(r *http.Request) (*session.Data, error) {
return nil, fmt.Errorf("reading access_token from fallback cookie: %w", err)
}
return session.NewData(externalSessionID, accessToken, idToken), nil
jwkSet := h.Provider.GetPublicJwkSet()
tokens, err := token.ParseTokensFromStrings(idToken, accessToken, *jwkSet)
if err != nil {
return nil, fmt.Errorf("parsing tokens: %w", err)
}
return session.NewData(externalSessionID, tokens), nil
}
func (h *Handler) DeleteSessionFallback(w http.ResponseWriter, r *http.Request) {

View File

@@ -7,15 +7,21 @@ import (
"testing"
"time"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
func TestHandler_GetSessionFallback(t *testing.T) {
h := newHandler(mock.NewTestProvider())
p := mock.NewTestProvider()
h := newHandler(p)
tokens := makeTokens(p)
t.Run("request without fallback session cookies", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -24,22 +30,26 @@ func TestHandler_GetSessionFallback(t *testing.T) {
})
t.Run("request with fallback session cookies", func(t *testing.T) {
r := makeRequestWithFallbackCookies(t)
r := makeRequestWithFallbackCookies(t, h, tokens)
sessionData, err := h.GetSessionFallback(r)
assert.NoError(t, err)
assert.Equal(t, "sid", sessionData.ExternalSessionID)
assert.Equal(t, "access_token", sessionData.AccessToken)
assert.Equal(t, "id_token", sessionData.IDToken)
assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken)
assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken)
assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken)
assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken)
})
}
func TestHandler_SetSessionFallback(t *testing.T) {
h := newHandler(mock.NewTestProvider())
provider := mock.NewTestProvider()
h := newHandler(provider)
// request should set session cookies in response
writer := httptest.NewRecorder()
expiresIn := time.Minute
data := session.NewData("sid", "access_token", "id_token")
tokens := makeTokens(provider)
data := session.NewData("sid", tokens)
err := h.SetSessionFallback(writer, data, expiresIn)
assert.NoError(t, err)
@@ -55,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) {
},
{
cookieName: h.SessionFallbackIDTokenCookieName(),
want: "id_token",
want: tokens.IDToken.Raw,
},
{
cookieName: h.SessionFallbackAccessTokenCookieName(),
want: "access_token",
want: tokens.AccessToken.Raw,
},
} {
assertCookieExists(t, h, test.cookieName, test.want, cookies)
@@ -67,10 +77,12 @@ func TestHandler_SetSessionFallback(t *testing.T) {
}
func TestHandler_DeleteSessionFallback(t *testing.T) {
h := newHandler(mock.NewTestProvider())
p := mock.NewTestProvider()
h := newHandler(p)
tokens := makeTokens(p)
t.Run("expire cookies if they are set", func(t *testing.T) {
r := makeRequestWithFallbackCookies(t)
r := makeRequestWithFallbackCookies(t, h, tokens)
writer := httptest.NewRecorder()
h.DeleteSessionFallback(writer, r)
cookies := writer.Result().Cookies()
@@ -93,11 +105,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) {
})
}
func makeRequestWithFallbackCookies(t *testing.T) *http.Request {
h := newHandler(mock.NewTestProvider())
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *token.Tokens) *http.Request {
writer := httptest.NewRecorder()
expiresIn := time.Minute
data := session.NewData("sid", "access_token", "id_token")
data := session.NewData("sid", tokens)
err := h.SetSessionFallback(writer, data, expiresIn)
assert.NoError(t, err)
@@ -138,3 +149,39 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal
assert.NoError(t, err)
assert.Equal(t, expectedValue, string(plainbytes))
}
func makeTokens(provider mock.TestProvider) *token.Tokens {
jwks := *provider.PrivateJwkSet()
signer, ok := jwks.Get(0)
if !ok {
log.Fatalf("getting signer")
}
idToken := jwt.New()
idToken.Set("jti", "id-token-jti")
signedIdToken, err := jwt.Sign(idToken, jwa.RS256, signer)
if err != nil {
log.Fatalf("signing id_token: %+v", err)
}
parsedIdToken, err := jwt.Parse(signedIdToken)
if err != nil {
log.Fatalf("parsing signed id_token: %+v", err)
}
accessToken := jwt.New()
accessToken.Set("jti", "access-token-jti")
signedAccessToken, err := jwt.Sign(accessToken, jwa.RS256, signer)
if err != nil {
log.Fatalf("signing access_token: %+v", err)
}
parsedAccessToken, err := jwt.Parse(signedAccessToken)
if err != nil {
log.Fatalf("parsing signed access_token: %+v", err)
}
return &token.Tokens{
IDToken: token.NewIDToken(string(signedIdToken), parsedIdToken),
AccessToken: token.NewAccessToken(string(signedAccessToken), parsedAccessToken),
}
}

View File

@@ -5,11 +5,13 @@ import (
"testing"
"time"
"github.com/lestrrat-go/jwx/jwt"
"github.com/nais/liberator/pkg/keygen"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
func TestMemory(t *testing.T) {
@@ -17,7 +19,17 @@ func TestMemory(t *testing.T) {
assert.NoError(t, err)
crypter := crypto.NewCrypter(key)
data := session.NewData("myid", "accesstoken", "idtoken")
idToken := jwt.New()
idToken.Set("jti", "id-token-jti")
accessToken := jwt.New()
accessToken.Set("jti", "access-token-jti")
tokens := &token.Tokens{
IDToken: token.NewIDToken("id_token", idToken),
AccessToken: token.NewAccessToken("access_token", accessToken),
}
data := session.NewData("myid", tokens)
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)

View File

@@ -7,11 +7,13 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
"github.com/lestrrat-go/jwx/jwt"
"github.com/nais/liberator/pkg/keygen"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
func TestRedis(t *testing.T) {
@@ -19,7 +21,17 @@ func TestRedis(t *testing.T) {
assert.NoError(t, err)
crypter := crypto.NewCrypter(key)
data := session.NewData("myid", "accesstoken", "idtoken")
idToken := jwt.New()
idToken.Set("jti", "id-token-jti")
accessToken := jwt.New()
accessToken.Set("jti", "access-token-jti")
tokens := &token.Tokens{
IDToken: token.NewIDToken("id_token", idToken),
AccessToken: token.NewAccessToken("access_token", accessToken),
}
data := session.NewData("myid", tokens)
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)

View File

@@ -11,6 +11,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/token"
)
type Store interface {
@@ -79,16 +80,18 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) {
}
type Data struct {
ExternalSessionID string `json:"external_session_id"`
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
ExternalSessionID string `json:"external_session_id"`
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
JwtIDs token.JwtIDs `json:"jti"`
}
func NewData(externalSessionID, accessToken, idToken string) *Data {
func NewData(externalSessionID string, tokens *token.Tokens) *Data {
return &Data{
ExternalSessionID: externalSessionID,
AccessToken: accessToken,
IDToken: idToken,
AccessToken: tokens.AccessToken.Raw,
IDToken: tokens.IDToken.Raw,
JwtIDs: tokens.JwtIDs(),
}
}

View File

@@ -3,7 +3,6 @@ package token
import (
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
)
type AccessToken struct {
@@ -28,11 +27,11 @@ func NewAccessToken(raw string, token jwt.Token) *AccessToken {
}
}
func ParseAccessToken(tokens *oauth2.Token, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := ParseJwt(tokens.AccessToken, jwks)
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := ParseJwt(raw, jwks)
if err != nil {
return nil, err
}
return NewAccessToken(tokens.AccessToken, accessToken), nil
return NewAccessToken(raw, accessToken), nil
}

View File

@@ -1,12 +1,10 @@
package token
import (
"fmt"
"time"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/openid"
)
@@ -59,12 +57,7 @@ func NewIDToken(raw string, token jwt.Token) *IDToken {
}
}
func ParseIDToken(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) {
raw, ok := tokens.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
idToken, err := ParseJwt(raw, jwks)
if err != nil {
return nil, err

View File

@@ -25,28 +25,6 @@ type Tokens struct {
AccessToken *AccessToken
}
type JwtIDs struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
}
func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
idToken, err := ParseIDToken(tokens, jwks)
if err != nil {
return nil, fmt.Errorf("id_token: %w", err)
}
accessToken, err := ParseAccessToken(tokens, jwks)
if err != nil {
return nil, fmt.Errorf("access_token: %w", err)
}
return &Tokens{
IDToken: idToken,
AccessToken: accessToken,
}, nil
}
func (in *Tokens) JwtIDs() JwtIDs {
return JwtIDs{
IDToken: in.IDToken.GetJtiClaim(),
@@ -54,6 +32,41 @@ func (in *Tokens) JwtIDs() JwtIDs {
}
}
func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens {
return &Tokens{
IDToken: idToken,
AccessToken: accessToken,
}
}
type JwtIDs struct {
IDToken string `json:"id_token,omitempty"`
AccessToken string `json:"access_token,omitempty"`
}
func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
idToken, ok := tokens.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks)
}
func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) {
parsedIdToken, err := ParseIDToken(idToken, jwks)
if err != nil {
return nil, fmt.Errorf("id_token: %w", err)
}
parsedAccessToken, err := ParseAccessToken(accessToken, jwks)
if err != nil {
return nil, fmt.Errorf("access_token: %w", err)
}
return NewTokens(parsedIdToken, parsedAccessToken), nil
}
func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) {
parseOpts := []jwt.ParseOption{
jwt.WithKeySet(jwks),
@@ -68,6 +81,10 @@ func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) {
}
func GetStringClaim(token jwt.Token, claim string) (string, error) {
if token == nil {
return "", fmt.Errorf("token is nil")
}
gotClaim, ok := token.Get(claim)
if !ok {
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)