refactor(jwt): skip parsing access tokens

Access Tokens are not necessarily JWTs. We also don't
have to validate them as we only pass it on as an opaque
string.

This also means that we don't log the JTI access tokens
anymore.

We also simplify handling of oidc callbacks.
This commit is contained in:
Trong Huu Nguyen
2022-07-14 12:14:22 +02:00
parent 6469c527a7
commit aab249d78a
24 changed files with 221 additions and 302 deletions

View File

@@ -1,14 +0,0 @@
package jwt
const (
JtiClaim = "jti"
SidClaim = "sid"
UtiClaim = "uti"
)
type Claims struct {
IDTokenJti string `json:"id_token_jti,omitempty"`
IDTokenUti string `json:"id_token_uti,omitempty"`
AccessTokenJti string `json:"access_token_jti,omitempty"`
AccessTokenUti string `json:"access_token_uti,omitempty"`
}

View File

@@ -11,15 +11,18 @@ import (
const (
AcceptableClockSkew = 10 * time.Second
JtiClaim = "jti"
SidClaim = "sid"
UtiClaim = "uti"
)
type Token interface {
GetExpiration() time.Time
GetJtiClaim() string
GetJwtID() string
GetSerialized() string
GetStringClaim(claim string) (string, error)
GetToken() jwt.Token
GetUtiClaim() string
}
type token struct {
@@ -31,8 +34,17 @@ func (in *token) GetExpiration() time.Time {
return in.token.Expiration()
}
func (in *token) GetJtiClaim() string {
return in.GetStringClaimOrEmpty(JtiClaim)
func (in *token) GetJwtID() string {
jti := in.GetStringClaimOrEmpty(JtiClaim)
uti := in.GetStringClaimOrEmpty(UtiClaim)
// jti is the standard JWT ID claim
if len(jti) > 0 {
return jti
}
// else, try to return uti - which seems to be Azure AD's variant
return uti
}
func (in *token) GetSerialized() string {
@@ -70,10 +82,6 @@ func (in *token) GetToken() jwt.Token {
return in.token
}
func (in *token) GetUtiClaim() string {
return in.GetStringClaimOrEmpty(UtiClaim)
}
func NewToken(raw string, jwtToken jwt.Token) Token {
return &token{
serialized: raw,

View File

@@ -1,25 +0,0 @@
package jwt
import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)
type AccessToken struct {
Token
}
func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken {
return &AccessToken{
NewToken(raw, jwtToken),
}
}
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := Parse(raw, jwks)
if err != nil {
return nil, err
}
return NewAccessToken(raw, accessToken), nil
}

View File

@@ -1,55 +0,0 @@
package jwt
import (
"time"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
)
type IDToken struct {
Token
}
func (in *IDToken) GetSidClaim() (string, error) {
return in.GetStringClaim(SidClaim)
}
func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error {
openIDconfig := cfg.Provider()
clientConfig := cfg.Client()
opts := []jwt.ValidateOption{
jwt.WithAudience(clientConfig.GetClientID()),
jwt.WithClaimValue("nonce", nonce),
jwt.WithIssuer(openIDconfig.Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
}
if openIDconfig.SidClaimRequired() {
opts = append(opts, jwt.WithRequiredClaim("sid"))
}
if len(clientConfig.GetACRValues()) > 0 {
opts = append(opts, jwt.WithRequiredClaim("acr"))
}
return jwt.Validate(in.GetToken(), opts...)
}
func NewIDToken(raw string, jwtToken jwt.Token) *IDToken {
return &IDToken{
NewToken(raw, jwtToken),
}
}
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
idToken, err := Parse(raw, jwks)
if err != nil {
return nil, err
}
return NewIDToken(raw, idToken), nil
}

View File

@@ -1,52 +0,0 @@
package jwt
import (
"fmt"
"github.com/lestrrat-go/jwx/v2/jwk"
"golang.org/x/oauth2"
)
type Tokens struct {
IDToken *IDToken
AccessToken *AccessToken
}
func (in *Tokens) Claims() Claims {
return Claims{
IDTokenJti: in.IDToken.GetJtiClaim(),
IDTokenUti: in.IDToken.GetUtiClaim(),
AccessTokenJti: in.AccessToken.GetJtiClaim(),
AccessTokenUti: in.AccessToken.GetUtiClaim(),
}
}
func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens {
return &Tokens{
IDToken: idToken,
AccessToken: accessToken,
}
}
func ParseOauth2Token(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
idToken, ok := tokens.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks)
}
func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) {
parsedIdToken, err := ParseIDToken(idToken, jwks)
if err != nil {
return nil, fmt.Errorf("id_token: %w", err)
}
parsedAccessToken, err := ParseAccessToken(accessToken, jwks)
if err != nil {
return nil, fmt.Errorf("access_token: %w", err)
}
return NewTokens(parsedIdToken, parsedAccessToken), nil
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/jwt"
)
const (
@@ -19,7 +18,7 @@ const (
)
type Client interface {
ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error)
ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error)
SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options)
HasCookie(r *http.Request) bool
ClearCookie(w http.ResponseWriter, opts cookie.Options)
@@ -48,7 +47,7 @@ type client struct {
httpClient *http.Client
}
func (c client) ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) {
func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) {
req, err := request(ctx, c.config.TokenURL, accessToken)
if err != nil {
return nil, fmt.Errorf("creating request %w", err)
@@ -101,13 +100,13 @@ func (c client) CookieOptions(opts cookie.Options) cookie.Options {
WithPath("/")
}
func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Request, error) {
func request(ctx context.Context, url string, token string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized()))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Accept", "application/json")
return req, nil

View File

@@ -13,7 +13,6 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/loginstatus"
)
@@ -27,18 +26,18 @@ func TestClient_ExchangeToken(t *testing.T) {
client := loginstatus.NewClient(cfg, httpclient)
for _, test := range []struct {
token *jwt.AccessToken
token string
err error
}{
{
token: jwt.NewAccessToken("valid-token", nil),
token: "valid-token",
},
{
token: jwt.NewAccessToken("invalid-token", nil),
token: "invalid-token",
err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"),
},
{
token: jwt.NewAccessToken("internal-server-error", nil),
token: "internal-server-error",
err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"),
},
} {

View File

@@ -9,7 +9,6 @@ import (
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/provider"
)
@@ -17,8 +16,7 @@ import (
type LoginCallback interface {
IdentityProviderError() error
StateMismatchError() error
ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error)
ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error)
RedeemTokens(ctx context.Context) (*openid.Tokens, error)
}
type loginCallback struct {
@@ -68,7 +66,7 @@ func (in loginCallback) StateMismatchError() error {
return nil
}
func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) {
func (in loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
clientAssertion, err := in.client.MakeAssertion(time.Second * 30)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
@@ -81,21 +79,17 @@ func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, er
}
code := in.requestParams.Get("code")
tokens, err := in.client.AuthCodeGrant(ctx, code, opts)
rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts)
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
}
return tokens, nil
}
func (in loginCallback) ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) {
jwkSet, err := in.provider.GetPublicJwkSet(ctx)
if err != nil {
return nil, fmt.Errorf("getting jwks: %w", err)
}
tokens, err := jwt.ParseOauth2Token(rawTokens, *jwkSet)
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
if err != nil {
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
_, _ = in.provider.RefreshPublicJwkSet(ctx)

View File

@@ -43,20 +43,20 @@ func TestLoginCallback_IdentityProviderError(t *testing.T) {
assert.Error(t, err)
}
func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
t.Run("valid code", func(t *testing.T) {
url := "http://wonderwall/oauth2/callback?code=some-code"
func TestLoginCallback_RedeemTokens(t *testing.T) {
url := "http://wonderwall/oauth2/callback?code=some-code"
t.Run("happy path", func(t *testing.T) {
idp, lc := newLoginCallback(t, url)
defer idp.Close()
tokens, err := lc.ExchangeAuthCode(context.Background())
tokens, err := lc.RedeemTokens(context.Background())
assert.NoError(t, err)
assert.NotNil(t, tokens)
assert.NotEmpty(t, tokens.AccessToken)
assert.NotEmpty(t, tokens.RefreshToken)
assert.NotEmpty(t, tokens.Extra("id_token"))
assert.NotEmpty(t, tokens.IDToken.GetSerialized())
assert.NotEmpty(t, tokens.TokenType)
assert.NotEmpty(t, tokens.Expiry)
@@ -67,8 +67,6 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
})
t.Run("invalid code", func(t *testing.T) {
url := "http://wonderwall/oauth2/callback?code=some-code"
idp, lc := newLoginCallback(t, url)
defer idp.Close()
idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{
@@ -76,38 +74,17 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
"another-code": {},
}
tokens, err := lc.ExchangeAuthCode(context.Background())
tokens, err := lc.RedeemTokens(context.Background())
assert.Error(t, err)
assert.Nil(t, tokens)
})
}
func TestLoginCallback_ProcessTokens(t *testing.T) {
url := "http://wonderwall/oauth2/callback?code=some-code"
t.Run("happy path", func(t *testing.T) {
idp, lc := newLoginCallback(t, url)
defer idp.Close()
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
assert.NoError(t, err)
assert.NotNil(t, tokens)
})
t.Run("nonce mismatch", func(t *testing.T) {
idp, lc := newLoginCallback(t, url)
defer idp.Close()
idp.ProviderHandler.Codes["some-code"].Nonce = "some-other-nonce"
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
tokens, err := lc.RedeemTokens(context.Background())
assert.Error(t, err)
assert.Nil(t, tokens)
})
@@ -117,11 +94,7 @@ func TestLoginCallback_ProcessTokens(t *testing.T) {
defer idp.Close()
idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id"
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
tokens, err := lc.RedeemTokens(context.Background())
assert.Error(t, err)
assert.Nil(t, tokens)
})

View File

@@ -7,7 +7,6 @@ import (
log "github.com/sirupsen/logrus"
wonderwallconfig "github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/scopes"
"github.com/nais/wonderwall/pkg/router/paths"
)
@@ -97,12 +96,12 @@ func NewClientConfig(cfg *wonderwallconfig.Config) (Client, error) {
return nil, fmt.Errorf("missing required config %s", wonderwallconfig.Ingress)
}
callbackURI, err := openid.RedirectURI(ingress, paths.Callback)
callbackURI, err := RedirectURI(ingress, paths.Callback)
if err != nil {
return nil, fmt.Errorf("creating callback URI from ingress: %w", err)
}
logoutCallbackURI, err := openid.RedirectURI(ingress, paths.LogoutCallback)
logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback)
if err != nil {
return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err)
}

View File

@@ -1,4 +1,4 @@
package openid
package config
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package openid_test
package config_test
import (
"fmt"
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/router/paths"
)
@@ -47,7 +47,7 @@ func TestRedirectURI(t *testing.T) {
err: fmt.Errorf("ingress cannot be empty"),
},
} {
actual, err := openid.RedirectURI(test.input, test.path)
actual, err := config.RedirectURI(test.input, test.path)
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {

90
pkg/openid/tokens.go Normal file
View File

@@ -0,0 +1,90 @@
package openid
import (
"fmt"
"time"
"github.com/lestrrat-go/jwx/v2/jwk"
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/jwt"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
)
type Tokens struct {
AccessToken string
Expiry time.Time
IDToken *IDToken
RefreshToken string
TokenType string
}
func NewTokens(src *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
idToken, err := ParseIDTokenFrom(src, jwks)
if err != nil {
return nil, fmt.Errorf("parsing id_token: %w", err)
}
return &Tokens{
AccessToken: src.AccessToken,
Expiry: src.Expiry,
IDToken: idToken,
RefreshToken: src.RefreshToken,
TokenType: src.TokenType,
}, nil
}
type IDToken struct {
jwt.Token
}
func (in *IDToken) GetSidClaim() (string, error) {
return in.GetStringClaim(jwt.SidClaim)
}
func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error {
openIDconfig := cfg.Provider()
clientConfig := cfg.Client()
opts := []jwtlib.ValidateOption{
jwtlib.WithAudience(clientConfig.GetClientID()),
jwtlib.WithClaimValue("nonce", nonce),
jwtlib.WithIssuer(openIDconfig.Issuer),
jwtlib.WithAcceptableSkew(5 * time.Second),
}
if openIDconfig.SidClaimRequired() {
opts = append(opts, jwtlib.WithRequiredClaim("sid"))
}
if len(clientConfig.GetACRValues()) > 0 {
opts = append(opts, jwtlib.WithRequiredClaim("acr"))
}
return jwtlib.Validate(in.GetToken(), opts...)
}
func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken {
return &IDToken{
jwt.NewToken(raw, jwtToken),
}
}
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
idToken, err := jwt.Parse(raw, jwks)
if err != nil {
return nil, err
}
return NewIDToken(raw, idToken), nil
}
func ParseIDTokenFrom(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) {
idToken, ok := tokens.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
}
return ParseIDToken(idToken, jwks)
}

View File

@@ -9,10 +9,9 @@ import (
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/loginstatus"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
@@ -52,19 +51,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
rawTokens, err := h.exchangeAuthCode(r.Context(), loginCallback)
tokens, err := h.redeemValidTokens(r.Context(), loginCallback)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: %w", err))
return
}
tokens, err := loginCallback.ProcessTokens(r.Context(), rawTokens)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: %w", err))
return
}
err = h.createSession(w, r, tokens, rawTokens)
err = h.createSession(w, r, tokens)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
@@ -85,12 +78,12 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect)
}
func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.LoginCallback) (*oauth2.Token, error) {
var tokens *oauth2.Token
func (h *Handler) redeemValidTokens(ctx context.Context, loginCallback client.LoginCallback) (*openid.Tokens, error) {
var tokens *openid.Tokens
var err error
retryable := func(ctx context.Context) error {
tokens, err = loginCallback.ExchangeAuthCode(ctx)
tokens, err = loginCallback.RedeemTokens(ctx)
if err != nil {
log.Warnf("callback: retrying: %+v", err)
return retry.RetryableError(err)
@@ -107,7 +100,7 @@ func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.Log
return tokens, nil
}
func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (*loginstatus.TokenResponse, error) {
func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
var tokenResponse *loginstatus.TokenResponse
err := retry.Do(ctx, backoff(), func(ctx context.Context) error {
@@ -128,10 +121,10 @@ func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (
return tokenResponse, nil
}
func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) {
func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) {
fields := map[string]interface{}{
"redirect_to": referer,
"claims": tokens.Claims(),
"jti": tokens.IDToken.GetJwtID(),
}
logger := logentry.LogEntryWithFields(r.Context(), fields)

View File

@@ -36,7 +36,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
if err != nil {
log.Errorf("front-channel logout: destroying session: %+v", err)
} else if sessionData != nil {
log.WithField("claims", sessionData.Claims).Infof("front-channel logout: successful logout")
log.WithField("jti", sessionData.IDTokenJwtID).Infof("front-channel logout: successful logout")
}
w.WriteHeader(http.StatusOK)

View File

@@ -25,7 +25,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
}
fields := map[string]interface{}{
"claims": sessionData.Claims,
"jti": sessionData.IDTokenJwtID,
}
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger.Info().Msg("logout: successful local logout")

View File

@@ -9,10 +9,9 @@ import (
"github.com/go-redis/redis/v8"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
)
@@ -77,7 +76,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
return defaultSessionLifetime
}
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token) error {
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *openid.Tokens) error {
params := r.URL.Query()
externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params)
@@ -85,7 +84,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
return fmt.Errorf("generating session ID: %w", err)
}
sessionLifetime := h.getSessionLifetime(rawTokens.Expiry)
sessionLifetime := h.getSessionLifetime(tokens.Expiry)
opts := h.CookieOptions.WithExpiresIn(sessionLifetime)
sessionID := h.localSessionID(externalSessionID)
@@ -95,7 +94,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
}
sessionMetadata := session.NewMetadata(time.Now().Add(sessionLifetime))
sessionData := session.NewData(externalSessionID, tokens, rawTokens.RefreshToken, sessionMetadata)
sessionData := session.NewData(externalSessionID, tokens, sessionMetadata)
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
if err != nil {

View File

@@ -13,8 +13,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
)
@@ -40,10 +40,9 @@ func TestHandler_GetSessionFallback(t *testing.T) {
sessionData, err := rpHandler.GetSessionFallback(w, r)
assert.NoError(t, err)
assert.Equal(t, "sid", sessionData.ExternalSessionID)
assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken)
assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
assert.Equal(t, "id-token-jti", sessionData.Claims.IDTokenJti)
assert.Equal(t, "access-token-jti", sessionData.Claims.AccessTokenJti)
assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
assert.Empty(t, sessionData.RefreshToken)
})
}
@@ -59,7 +58,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
// request should set session cookies in response
writer := httptest.NewRecorder()
expiresIn := time.Minute
data := session.NewData("sid", tokens, "", nil)
data := session.NewData("sid", tokens, nil)
err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn)
assert.NoError(t, err)
@@ -79,7 +78,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
},
{
cookieName: "wonderwall-3",
want: tokens.AccessToken.GetSerialized(),
want: tokens.AccessToken,
},
} {
assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies)
@@ -118,10 +117,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) {
})
}
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *jwt.Tokens) *http.Request {
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request {
writer := httptest.NewRecorder()
expiresIn := time.Minute
data := session.NewData("sid", tokens, "", nil)
data := session.NewData("sid", tokens, nil)
err := h.SetSessionFallback(writer, nil, data, expiresIn)
assert.NoError(t, err)
@@ -163,7 +162,7 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal
assert.Equal(t, expectedValue, string(plainbytes))
}
func makeTokens(provider mock.TestProvider) *jwt.Tokens {
func makeTokens(provider mock.TestProvider) *openid.Tokens {
jwks := *provider.PrivateJwkSet()
jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
if err != nil {
@@ -188,20 +187,10 @@ func makeTokens(provider mock.TestProvider) *jwt.Tokens {
log.Fatalf("parsing signed id_token: %+v", err)
}
accessToken := jwtlib.New()
accessToken.Set("jti", "access-token-jti")
accessToken := "some-access-token"
signedAccessToken, err := jwtlib.Sign(accessToken, jwtlib.WithKey(jwa.RS256, signer))
if err != nil {
log.Fatalf("signing access_token: %+v", err)
}
parsedAccessToken, err := jwtlib.Parse(signedAccessToken, jwtlib.WithKeySet(*jwksPublic))
if err != nil {
log.Fatalf("parsing signed access_token: %+v", err)
}
return &jwt.Tokens{
IDToken: jwt.NewIDToken(string(signedIdToken), parsedIdToken),
AccessToken: jwt.NewAccessToken(string(signedAccessToken), parsedAccessToken),
return &openid.Tokens{
IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken),
AccessToken: accessToken,
}
}

View File

@@ -7,9 +7,12 @@ import (
"net/http"
"time"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/provider"
)
@@ -87,15 +90,34 @@ func (c cookieSessionStore) Read(ctx context.Context) (*Data, error) {
return nil, fmt.Errorf("callback: getting jwks: %w", err)
}
tokens, err := jwt.ParseTokensFromStrings(idToken, accessToken, *jwkSet)
// TODO: currently a placeholder fallback value, should fetch from metadata cookie
expiry := time.Now().Add(time.Hour)
// attempt to get expiry from access_token if it is a JWT
parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet)
if err == nil {
expiry = parsedAccessToken.Expiration()
}
// TODO: set refresh token and metadata
rawTokens := &oauth2.Token{
AccessToken: accessToken,
TokenType: "Bearer",
RefreshToken: "",
Expiry: expiry,
}
rawTokens = rawTokens.WithExtra(map[string]interface{}{
"id_token": idToken,
})
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
if err != nil {
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
_, _ = c.provider.RefreshPublicJwkSet(ctx)
return nil, fmt.Errorf("parsing tokens: %w", err)
}
// TODO: set refresh token and metadata
return NewData(externalSessionID, tokens, "", nil), nil
return NewData(externalSessionID, tokens, nil), nil
}
func (c cookieSessionStore) Delete() {

View File

@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
)
@@ -22,16 +22,16 @@ func TestMemory(t *testing.T) {
idToken := jwtlib.New()
idToken.Set("jti", "id-token-jti")
accessToken := jwtlib.New()
accessToken.Set("jti", "access-token-jti")
tokens := &jwt.Tokens{
IDToken: jwt.NewIDToken("id_token", idToken),
AccessToken: jwt.NewAccessToken("access_token", accessToken),
}
accessToken := "some-access-token"
refreshToken := "some-refresh-token"
tokens := &openid.Tokens{
AccessToken: accessToken,
IDToken: openid.NewIDToken("id_token", idToken),
RefreshToken: refreshToken,
}
metadata := session.NewMetadata(time.Now().Add(time.Hour))
data := session.NewData("myid", tokens, refreshToken, metadata)
data := session.NewData("myid", tokens, metadata)
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)

View File

@@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
)
@@ -24,16 +24,16 @@ func TestRedis(t *testing.T) {
idToken := jwtlib.New()
idToken.Set("jti", "id-token-jti")
accessToken := jwtlib.New()
accessToken.Set("jti", "access-token-jti")
tokens := &jwt.Tokens{
IDToken: jwt.NewIDToken("id_token", idToken),
AccessToken: jwt.NewAccessToken("access_token", accessToken),
}
accessToken := "some-access-token"
refreshToken := "some-refresh-token"
tokens := &openid.Tokens{
AccessToken: accessToken,
IDToken: openid.NewIDToken("id_token", idToken),
RefreshToken: refreshToken,
}
metadata := session.NewMetadata(time.Now().Add(time.Hour))
data := session.NewData("myid", tokens, refreshToken, metadata)
data := session.NewData("myid", tokens, metadata)
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)

View File

@@ -7,7 +7,7 @@ import (
"time"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
)
type EncryptedData struct {
@@ -46,21 +46,21 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) {
}
type Data struct {
ExternalSessionID string `json:"external_session_id"`
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Claims jwt.Claims `json:"claims"`
Metadata Metadata `json:"metadata"`
ExternalSessionID string `json:"external_session_id"`
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
IDTokenJwtID string `json:"id_token_jwt_id"`
Metadata Metadata `json:"metadata"`
}
func NewData(externalSessionID string, tokens *jwt.Tokens, refreshToken string, metadata *Metadata) *Data {
func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data {
data := &Data{
ExternalSessionID: externalSessionID,
AccessToken: tokens.AccessToken.GetSerialized(),
AccessToken: tokens.AccessToken,
IDToken: tokens.IDToken.GetSerialized(),
RefreshToken: refreshToken,
Claims: tokens.Claims(),
IDTokenJwtID: tokens.IDToken.GetJwtID(),
RefreshToken: tokens.RefreshToken,
}
if metadata != nil {

View File

@@ -7,7 +7,7 @@ import (
"io"
"net/url"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
)
@@ -15,7 +15,7 @@ const (
SessionStateParamKey = "session_state"
)
func NewSessionID(cfg *config.Provider, idToken *jwt.IDToken, params url.Values) (string, error) {
func NewSessionID(cfg *config.Provider, idToken *openid.IDToken, params url.Values) (string, error) {
// 1. check for 'sid' claim in id_token
sessionID, err := idToken.GetSidClaim()
if err == nil {

View File

@@ -8,7 +8,7 @@ import (
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/session"
)
@@ -17,7 +17,7 @@ func TestSessionID(t *testing.T) {
for _, test := range []struct {
name string
config *config.Provider
idToken *jwt.IDToken
idToken *openid.IDToken
params url.Values
want string
exactMatch bool
@@ -136,7 +136,7 @@ func params(key, value string) url.Values {
return values
}
func newIDToken(extraClaims map[string]string) *jwt.IDToken {
func newIDToken(extraClaims map[string]string) *openid.IDToken {
idToken := jwtlib.New()
idToken.Set("sub", "test")
idToken.Set("iss", "test")
@@ -155,15 +155,15 @@ func newIDToken(extraClaims map[string]string) *jwt.IDToken {
panic(err)
}
return jwt.NewIDToken(string(serialized), idToken)
return openid.NewIDToken(string(serialized), idToken)
}
func idTokenWithSid(sid string) *jwt.IDToken {
func idTokenWithSid(sid string) *openid.IDToken {
return newIDToken(map[string]string{
"sid": sid,
})
}
func idToken() *jwt.IDToken {
func idToken() *openid.IDToken {
return newIDToken(nil)
}