mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-09 01:47:03 +00:00
feat: log jwt IDs for tracability
This commit is contained in:
@@ -26,11 +26,17 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sessionID := h.localSessionID(sid)
|
||||
|
||||
err := h.destroySession(w, r, sessionID)
|
||||
sessionData, err := h.getSession(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("get session: %+v", err)
|
||||
}
|
||||
|
||||
err = h.destroySession(w, r, sessionID)
|
||||
if err != nil {
|
||||
log.Errorf("destroying session: %+v", err)
|
||||
// Session is already destroyed at the OP and is highly unlikely to be used again.
|
||||
} else if sessionData != nil {
|
||||
log.WithField("jti", sessionData.JwtIDs).Infof("successful front-channel logout")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
)
|
||||
|
||||
@@ -18,14 +20,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var idToken string
|
||||
|
||||
sess, err := h.getSessionFromCookie(w, r)
|
||||
if err == nil && sess != nil {
|
||||
idToken = sess.IDToken
|
||||
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))
|
||||
sessionData, err := h.getSessionFromCookie(w, r)
|
||||
if err == nil && sessionData != nil {
|
||||
idToken = sessionData.IDToken
|
||||
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.WithField("jti", sessionData.JwtIDs).Infof("successful logout")
|
||||
}
|
||||
|
||||
h.deleteCookie(w, SessionCookieName, h.CookieOptions)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -29,13 +30,8 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
sessionData, err := h.getSession(r.Context(), sessionID)
|
||||
if err == nil {
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
h.DeleteSessionFallback(w, r)
|
||||
return sessionData, nil
|
||||
}
|
||||
@@ -54,6 +50,20 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
|
||||
return fallbackSessionData, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) {
|
||||
encryptedSessionData, err := h.Sessions.Read(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading session data from store: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionLifetime(accessToken *token.AccessToken) time.Duration {
|
||||
defaultSessionLifetime := h.Config.SessionMaxLifetime
|
||||
|
||||
@@ -81,7 +91,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
return fmt.Errorf("setting session cookie: %w", err)
|
||||
}
|
||||
|
||||
sessionData := session.NewData(externalSessionID, tokens.AccessToken.Raw, tokens.IDToken.Raw)
|
||||
sessionData := session.NewData(externalSessionID, tokens)
|
||||
|
||||
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
func (h *Handler) SessionFallbackExternalIDCookieName() string {
|
||||
@@ -58,7 +59,13 @@ func (h *Handler) GetSessionFallback(r *http.Request) (*session.Data, error) {
|
||||
return nil, fmt.Errorf("reading access_token from fallback cookie: %w", err)
|
||||
}
|
||||
|
||||
return session.NewData(externalSessionID, accessToken, idToken), nil
|
||||
jwkSet := h.Provider.GetPublicJwkSet()
|
||||
tokens, err := token.ParseTokensFromStrings(idToken, accessToken, *jwkSet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing tokens: %w", err)
|
||||
}
|
||||
|
||||
return session.NewData(externalSessionID, tokens), nil
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteSessionFallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -7,15 +7,21 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
h := newHandler(mock.NewTestProvider())
|
||||
p := mock.NewTestProvider()
|
||||
h := newHandler(p)
|
||||
tokens := makeTokens(p)
|
||||
|
||||
t.Run("request without fallback session cookies", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
@@ -24,22 +30,26 @@ func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("request with fallback session cookies", func(t *testing.T) {
|
||||
r := makeRequestWithFallbackCookies(t)
|
||||
r := makeRequestWithFallbackCookies(t, h, tokens)
|
||||
sessionData, err := h.GetSessionFallback(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sid", sessionData.ExternalSessionID)
|
||||
assert.Equal(t, "access_token", sessionData.AccessToken)
|
||||
assert.Equal(t, "id_token", sessionData.IDToken)
|
||||
assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken)
|
||||
assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken)
|
||||
assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken)
|
||||
assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
h := newHandler(mock.NewTestProvider())
|
||||
provider := mock.NewTestProvider()
|
||||
h := newHandler(provider)
|
||||
|
||||
// request should set session cookies in response
|
||||
writer := httptest.NewRecorder()
|
||||
expiresIn := time.Minute
|
||||
data := session.NewData("sid", "access_token", "id_token")
|
||||
tokens := makeTokens(provider)
|
||||
data := session.NewData("sid", tokens)
|
||||
err := h.SetSessionFallback(writer, data, expiresIn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -55,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
},
|
||||
{
|
||||
cookieName: h.SessionFallbackIDTokenCookieName(),
|
||||
want: "id_token",
|
||||
want: tokens.IDToken.Raw,
|
||||
},
|
||||
{
|
||||
cookieName: h.SessionFallbackAccessTokenCookieName(),
|
||||
want: "access_token",
|
||||
want: tokens.AccessToken.Raw,
|
||||
},
|
||||
} {
|
||||
assertCookieExists(t, h, test.cookieName, test.want, cookies)
|
||||
@@ -67,10 +77,12 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_DeleteSessionFallback(t *testing.T) {
|
||||
h := newHandler(mock.NewTestProvider())
|
||||
p := mock.NewTestProvider()
|
||||
h := newHandler(p)
|
||||
tokens := makeTokens(p)
|
||||
|
||||
t.Run("expire cookies if they are set", func(t *testing.T) {
|
||||
r := makeRequestWithFallbackCookies(t)
|
||||
r := makeRequestWithFallbackCookies(t, h, tokens)
|
||||
writer := httptest.NewRecorder()
|
||||
h.DeleteSessionFallback(writer, r)
|
||||
cookies := writer.Result().Cookies()
|
||||
@@ -93,11 +105,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func makeRequestWithFallbackCookies(t *testing.T) *http.Request {
|
||||
h := newHandler(mock.NewTestProvider())
|
||||
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *token.Tokens) *http.Request {
|
||||
writer := httptest.NewRecorder()
|
||||
expiresIn := time.Minute
|
||||
data := session.NewData("sid", "access_token", "id_token")
|
||||
data := session.NewData("sid", tokens)
|
||||
err := h.SetSessionFallback(writer, data, expiresIn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -138,3 +149,39 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, string(plainbytes))
|
||||
}
|
||||
|
||||
func makeTokens(provider mock.TestProvider) *token.Tokens {
|
||||
jwks := *provider.PrivateJwkSet()
|
||||
|
||||
signer, ok := jwks.Get(0)
|
||||
if !ok {
|
||||
log.Fatalf("getting signer")
|
||||
}
|
||||
|
||||
idToken := jwt.New()
|
||||
idToken.Set("jti", "id-token-jti")
|
||||
signedIdToken, err := jwt.Sign(idToken, jwa.RS256, signer)
|
||||
if err != nil {
|
||||
log.Fatalf("signing id_token: %+v", err)
|
||||
}
|
||||
parsedIdToken, err := jwt.Parse(signedIdToken)
|
||||
if err != nil {
|
||||
log.Fatalf("parsing signed id_token: %+v", err)
|
||||
}
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
signedAccessToken, err := jwt.Sign(accessToken, jwa.RS256, signer)
|
||||
if err != nil {
|
||||
log.Fatalf("signing access_token: %+v", err)
|
||||
}
|
||||
parsedAccessToken, err := jwt.Parse(signedAccessToken)
|
||||
if err != nil {
|
||||
log.Fatalf("parsing signed access_token: %+v", err)
|
||||
}
|
||||
|
||||
return &token.Tokens{
|
||||
IDToken: token.NewIDToken(string(signedIdToken), parsedIdToken),
|
||||
AccessToken: token.NewAccessToken(string(signedAccessToken), parsedAccessToken),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
@@ -17,7 +19,17 @@ func TestMemory(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
crypter := crypto.NewCrypter(key)
|
||||
|
||||
data := session.NewData("myid", "accesstoken", "idtoken")
|
||||
idToken := jwt.New()
|
||||
idToken.Set("jti", "id-token-jti")
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
|
||||
tokens := &token.Tokens{
|
||||
IDToken: token.NewIDToken("id_token", idToken),
|
||||
AccessToken: token.NewAccessToken("access_token", accessToken),
|
||||
}
|
||||
data := session.NewData("myid", tokens)
|
||||
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
func TestRedis(t *testing.T) {
|
||||
@@ -19,7 +21,17 @@ func TestRedis(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
crypter := crypto.NewCrypter(key)
|
||||
|
||||
data := session.NewData("myid", "accesstoken", "idtoken")
|
||||
idToken := jwt.New()
|
||||
idToken.Set("jti", "id-token-jti")
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
|
||||
tokens := &token.Tokens{
|
||||
IDToken: token.NewIDToken("id_token", idToken),
|
||||
AccessToken: token.NewAccessToken("access_token", accessToken),
|
||||
}
|
||||
data := session.NewData("myid", tokens)
|
||||
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
@@ -79,16 +80,18 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) {
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
ExternalSessionID string `json:"external_session_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExternalSessionID string `json:"external_session_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
JwtIDs token.JwtIDs `json:"jti"`
|
||||
}
|
||||
|
||||
func NewData(externalSessionID, accessToken, idToken string) *Data {
|
||||
func NewData(externalSessionID string, tokens *token.Tokens) *Data {
|
||||
return &Data{
|
||||
ExternalSessionID: externalSessionID,
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
AccessToken: tokens.AccessToken.Raw,
|
||||
IDToken: tokens.IDToken.Raw,
|
||||
JwtIDs: tokens.JwtIDs(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package token
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
@@ -28,11 +27,11 @@ func NewAccessToken(raw string, token jwt.Token) *AccessToken {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseAccessToken(tokens *oauth2.Token, jwks jwk.Set) (*AccessToken, error) {
|
||||
accessToken, err := ParseJwt(tokens.AccessToken, jwks)
|
||||
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
|
||||
accessToken, err := ParseJwt(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAccessToken(tokens.AccessToken, accessToken), nil
|
||||
return NewAccessToken(raw, accessToken), nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
@@ -59,12 +57,7 @@ func NewIDToken(raw string, token jwt.Token) *IDToken {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseIDToken(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) {
|
||||
raw, ok := tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing id_token in token response")
|
||||
}
|
||||
|
||||
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
|
||||
idToken, err := ParseJwt(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -25,28 +25,6 @@ type Tokens struct {
|
||||
AccessToken *AccessToken
|
||||
}
|
||||
|
||||
type JwtIDs struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
|
||||
idToken, err := ParseIDToken(tokens, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("id_token: %w", err)
|
||||
}
|
||||
|
||||
accessToken, err := ParseAccessToken(tokens, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access_token: %w", err)
|
||||
}
|
||||
|
||||
return &Tokens{
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (in *Tokens) JwtIDs() JwtIDs {
|
||||
return JwtIDs{
|
||||
IDToken: in.IDToken.GetJtiClaim(),
|
||||
@@ -54,6 +32,41 @@ func (in *Tokens) JwtIDs() JwtIDs {
|
||||
}
|
||||
}
|
||||
|
||||
func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens {
|
||||
return &Tokens{
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
}
|
||||
|
||||
type JwtIDs struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
}
|
||||
|
||||
func ParseTokens(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
|
||||
idToken, ok := tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing id_token in token response")
|
||||
}
|
||||
|
||||
return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks)
|
||||
}
|
||||
|
||||
func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) {
|
||||
parsedIdToken, err := ParseIDToken(idToken, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("id_token: %w", err)
|
||||
}
|
||||
|
||||
parsedAccessToken, err := ParseAccessToken(accessToken, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access_token: %w", err)
|
||||
}
|
||||
|
||||
return NewTokens(parsedIdToken, parsedAccessToken), nil
|
||||
}
|
||||
|
||||
func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) {
|
||||
parseOpts := []jwt.ParseOption{
|
||||
jwt.WithKeySet(jwks),
|
||||
@@ -68,6 +81,10 @@ func ParseJwt(raw string, jwks jwk.Set) (jwt.Token, error) {
|
||||
}
|
||||
|
||||
func GetStringClaim(token jwt.Token, claim string) (string, error) {
|
||||
if token == nil {
|
||||
return "", fmt.Errorf("token is nil")
|
||||
}
|
||||
|
||||
gotClaim, ok := token.Get(claim)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
|
||||
|
||||
Reference in New Issue
Block a user