mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-10 02:16:59 +00:00
refactor(jwt): skip parsing access tokens
Access Tokens are not necessarily JWTs. We also don't have to validate them as we only pass it on as an opaque string. This also means that we don't log the JTI access tokens anymore. We also simplify handling of oidc callbacks.
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
package jwt
|
||||
|
||||
const (
|
||||
JtiClaim = "jti"
|
||||
SidClaim = "sid"
|
||||
UtiClaim = "uti"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
IDTokenJti string `json:"id_token_jti,omitempty"`
|
||||
IDTokenUti string `json:"id_token_uti,omitempty"`
|
||||
AccessTokenJti string `json:"access_token_jti,omitempty"`
|
||||
AccessTokenUti string `json:"access_token_uti,omitempty"`
|
||||
}
|
||||
@@ -11,15 +11,18 @@ import (
|
||||
|
||||
const (
|
||||
AcceptableClockSkew = 10 * time.Second
|
||||
|
||||
JtiClaim = "jti"
|
||||
SidClaim = "sid"
|
||||
UtiClaim = "uti"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
GetExpiration() time.Time
|
||||
GetJtiClaim() string
|
||||
GetJwtID() string
|
||||
GetSerialized() string
|
||||
GetStringClaim(claim string) (string, error)
|
||||
GetToken() jwt.Token
|
||||
GetUtiClaim() string
|
||||
}
|
||||
|
||||
type token struct {
|
||||
@@ -31,8 +34,17 @@ func (in *token) GetExpiration() time.Time {
|
||||
return in.token.Expiration()
|
||||
}
|
||||
|
||||
func (in *token) GetJtiClaim() string {
|
||||
return in.GetStringClaimOrEmpty(JtiClaim)
|
||||
func (in *token) GetJwtID() string {
|
||||
jti := in.GetStringClaimOrEmpty(JtiClaim)
|
||||
uti := in.GetStringClaimOrEmpty(UtiClaim)
|
||||
|
||||
// jti is the standard JWT ID claim
|
||||
if len(jti) > 0 {
|
||||
return jti
|
||||
}
|
||||
|
||||
// else, try to return uti - which seems to be Azure AD's variant
|
||||
return uti
|
||||
}
|
||||
|
||||
func (in *token) GetSerialized() string {
|
||||
@@ -70,10 +82,6 @@ func (in *token) GetToken() jwt.Token {
|
||||
return in.token
|
||||
}
|
||||
|
||||
func (in *token) GetUtiClaim() string {
|
||||
return in.GetStringClaimOrEmpty(UtiClaim)
|
||||
}
|
||||
|
||||
func NewToken(raw string, jwtToken jwt.Token) Token {
|
||||
return &token{
|
||||
serialized: raw,
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token
|
||||
}
|
||||
|
||||
func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken {
|
||||
return &AccessToken{
|
||||
NewToken(raw, jwtToken),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
|
||||
accessToken, err := Parse(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAccessToken(raw, accessToken), nil
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
)
|
||||
|
||||
type IDToken struct {
|
||||
Token
|
||||
}
|
||||
|
||||
func (in *IDToken) GetSidClaim() (string, error) {
|
||||
return in.GetStringClaim(SidClaim)
|
||||
}
|
||||
|
||||
func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error {
|
||||
openIDconfig := cfg.Provider()
|
||||
clientConfig := cfg.Client()
|
||||
|
||||
opts := []jwt.ValidateOption{
|
||||
jwt.WithAudience(clientConfig.GetClientID()),
|
||||
jwt.WithClaimValue("nonce", nonce),
|
||||
jwt.WithIssuer(openIDconfig.Issuer),
|
||||
jwt.WithAcceptableSkew(5 * time.Second),
|
||||
}
|
||||
|
||||
if openIDconfig.SidClaimRequired() {
|
||||
opts = append(opts, jwt.WithRequiredClaim("sid"))
|
||||
}
|
||||
|
||||
if len(clientConfig.GetACRValues()) > 0 {
|
||||
opts = append(opts, jwt.WithRequiredClaim("acr"))
|
||||
}
|
||||
|
||||
return jwt.Validate(in.GetToken(), opts...)
|
||||
}
|
||||
|
||||
func NewIDToken(raw string, jwtToken jwt.Token) *IDToken {
|
||||
return &IDToken{
|
||||
NewToken(raw, jwtToken),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
|
||||
idToken, err := Parse(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewIDToken(raw, idToken), nil
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type Tokens struct {
|
||||
IDToken *IDToken
|
||||
AccessToken *AccessToken
|
||||
}
|
||||
|
||||
func (in *Tokens) Claims() Claims {
|
||||
return Claims{
|
||||
IDTokenJti: in.IDToken.GetJtiClaim(),
|
||||
IDTokenUti: in.IDToken.GetUtiClaim(),
|
||||
AccessTokenJti: in.AccessToken.GetJtiClaim(),
|
||||
AccessTokenUti: in.AccessToken.GetUtiClaim(),
|
||||
}
|
||||
}
|
||||
|
||||
func NewTokens(idToken *IDToken, accessToken *AccessToken) *Tokens {
|
||||
return &Tokens{
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseOauth2Token(tokens *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
|
||||
idToken, ok := tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing id_token in token response")
|
||||
}
|
||||
|
||||
return ParseTokensFromStrings(idToken, tokens.AccessToken, jwks)
|
||||
}
|
||||
|
||||
func ParseTokensFromStrings(idToken, accessToken string, jwks jwk.Set) (*Tokens, error) {
|
||||
parsedIdToken, err := ParseIDToken(idToken, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("id_token: %w", err)
|
||||
}
|
||||
|
||||
parsedAccessToken, err := ParseAccessToken(accessToken, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access_token: %w", err)
|
||||
}
|
||||
|
||||
return NewTokens(parsedIdToken, parsedAccessToken), nil
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,7 +18,7 @@ const (
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error)
|
||||
ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error)
|
||||
SetCookie(w http.ResponseWriter, token *TokenResponse, opts cookie.Options)
|
||||
HasCookie(r *http.Request) bool
|
||||
ClearCookie(w http.ResponseWriter, opts cookie.Options)
|
||||
@@ -48,7 +47,7 @@ type client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func (c client) ExchangeToken(ctx context.Context, accessToken *jwt.AccessToken) (*TokenResponse, error) {
|
||||
func (c client) ExchangeToken(ctx context.Context, accessToken string) (*TokenResponse, error) {
|
||||
req, err := request(ctx, c.config.TokenURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request %w", err)
|
||||
@@ -101,13 +100,13 @@ func (c client) CookieOptions(opts cookie.Options) cookie.Options {
|
||||
WithPath("/")
|
||||
}
|
||||
|
||||
func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Request, error) {
|
||||
func request(ctx context.Context, url string, token string) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized()))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/loginstatus"
|
||||
)
|
||||
|
||||
@@ -27,18 +26,18 @@ func TestClient_ExchangeToken(t *testing.T) {
|
||||
client := loginstatus.NewClient(cfg, httpclient)
|
||||
|
||||
for _, test := range []struct {
|
||||
token *jwt.AccessToken
|
||||
token string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
token: jwt.NewAccessToken("valid-token", nil),
|
||||
token: "valid-token",
|
||||
},
|
||||
{
|
||||
token: jwt.NewAccessToken("invalid-token", nil),
|
||||
token: "invalid-token",
|
||||
err: fmt.Errorf("client error: HTTP: %d: %s: %s", http.StatusUnauthorized, "access_denied", "No new and shiny token for you!"),
|
||||
},
|
||||
{
|
||||
token: jwt.NewAccessToken("internal-server-error", nil),
|
||||
token: "internal-server-error",
|
||||
err: fmt.Errorf("server error: HTTP: %d: %s", http.StatusInternalServerError, "Oh no, it broke"),
|
||||
},
|
||||
} {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/provider"
|
||||
)
|
||||
@@ -17,8 +16,7 @@ import (
|
||||
type LoginCallback interface {
|
||||
IdentityProviderError() error
|
||||
StateMismatchError() error
|
||||
ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error)
|
||||
ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error)
|
||||
RedeemTokens(ctx context.Context) (*openid.Tokens, error)
|
||||
}
|
||||
|
||||
type loginCallback struct {
|
||||
@@ -68,7 +66,7 @@ func (in loginCallback) StateMismatchError() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, error) {
|
||||
func (in loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
|
||||
clientAssertion, err := in.client.MakeAssertion(time.Second * 30)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating client assertion: %w", err)
|
||||
@@ -81,21 +79,17 @@ func (in loginCallback) ExchangeAuthCode(ctx context.Context) (*oauth2.Token, er
|
||||
}
|
||||
|
||||
code := in.requestParams.Get("code")
|
||||
tokens, err := in.client.AuthCodeGrant(ctx, code, opts)
|
||||
rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (in loginCallback) ProcessTokens(ctx context.Context, rawTokens *oauth2.Token) (*jwt.Tokens, error) {
|
||||
jwkSet, err := in.provider.GetPublicJwkSet(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting jwks: %w", err)
|
||||
}
|
||||
|
||||
tokens, err := jwt.ParseOauth2Token(rawTokens, *jwkSet)
|
||||
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
|
||||
if err != nil {
|
||||
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
|
||||
_, _ = in.provider.RefreshPublicJwkSet(ctx)
|
||||
|
||||
@@ -43,20 +43,20 @@ func TestLoginCallback_IdentityProviderError(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
|
||||
t.Run("valid code", func(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
func TestLoginCallback_RedeemTokens(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url)
|
||||
defer idp.Close()
|
||||
|
||||
tokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
tokens, err := lc.RedeemTokens(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokens)
|
||||
|
||||
assert.NotEmpty(t, tokens.AccessToken)
|
||||
assert.NotEmpty(t, tokens.RefreshToken)
|
||||
assert.NotEmpty(t, tokens.Extra("id_token"))
|
||||
assert.NotEmpty(t, tokens.IDToken.GetSerialized())
|
||||
assert.NotEmpty(t, tokens.TokenType)
|
||||
assert.NotEmpty(t, tokens.Expiry)
|
||||
|
||||
@@ -67,8 +67,6 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("invalid code", func(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
idp, lc := newLoginCallback(t, url)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{
|
||||
@@ -76,38 +74,17 @@ func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
|
||||
"another-code": {},
|
||||
}
|
||||
|
||||
tokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
tokens, err := lc.RedeemTokens(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginCallback_ProcessTokens(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url)
|
||||
defer idp.Close()
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokens)
|
||||
})
|
||||
|
||||
t.Run("nonce mismatch", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes["some-code"].Nonce = "some-other-nonce"
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
tokens, err := lc.RedeemTokens(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
@@ -117,11 +94,7 @@ func TestLoginCallback_ProcessTokens(t *testing.T) {
|
||||
defer idp.Close()
|
||||
idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id"
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
tokens, err := lc.RedeemTokens(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
wonderwallconfig "github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/scopes"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
@@ -97,12 +96,12 @@ func NewClientConfig(cfg *wonderwallconfig.Config) (Client, error) {
|
||||
return nil, fmt.Errorf("missing required config %s", wonderwallconfig.Ingress)
|
||||
}
|
||||
|
||||
callbackURI, err := openid.RedirectURI(ingress, paths.Callback)
|
||||
callbackURI, err := RedirectURI(ingress, paths.Callback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating callback URI from ingress: %w", err)
|
||||
}
|
||||
|
||||
logoutCallbackURI, err := openid.RedirectURI(ingress, paths.LogoutCallback)
|
||||
logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package openid
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package openid_test
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/config"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestRedirectURI(t *testing.T) {
|
||||
err: fmt.Errorf("ingress cannot be empty"),
|
||||
},
|
||||
} {
|
||||
actual, err := openid.RedirectURI(test.input, test.path)
|
||||
actual, err := config.RedirectURI(test.input, test.path)
|
||||
if test.err != nil {
|
||||
assert.EqualError(t, err, test.err.Error())
|
||||
} else {
|
||||
90
pkg/openid/tokens.go
Normal file
90
pkg/openid/tokens.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package openid
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
)
|
||||
|
||||
type Tokens struct {
|
||||
AccessToken string
|
||||
Expiry time.Time
|
||||
IDToken *IDToken
|
||||
RefreshToken string
|
||||
TokenType string
|
||||
}
|
||||
|
||||
func NewTokens(src *oauth2.Token, jwks jwk.Set) (*Tokens, error) {
|
||||
idToken, err := ParseIDTokenFrom(src, jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing id_token: %w", err)
|
||||
}
|
||||
|
||||
return &Tokens{
|
||||
AccessToken: src.AccessToken,
|
||||
Expiry: src.Expiry,
|
||||
IDToken: idToken,
|
||||
RefreshToken: src.RefreshToken,
|
||||
TokenType: src.TokenType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type IDToken struct {
|
||||
jwt.Token
|
||||
}
|
||||
|
||||
func (in *IDToken) GetSidClaim() (string, error) {
|
||||
return in.GetStringClaim(jwt.SidClaim)
|
||||
}
|
||||
|
||||
func (in *IDToken) Validate(cfg openidconfig.Config, nonce string) error {
|
||||
openIDconfig := cfg.Provider()
|
||||
clientConfig := cfg.Client()
|
||||
|
||||
opts := []jwtlib.ValidateOption{
|
||||
jwtlib.WithAudience(clientConfig.GetClientID()),
|
||||
jwtlib.WithClaimValue("nonce", nonce),
|
||||
jwtlib.WithIssuer(openIDconfig.Issuer),
|
||||
jwtlib.WithAcceptableSkew(5 * time.Second),
|
||||
}
|
||||
|
||||
if openIDconfig.SidClaimRequired() {
|
||||
opts = append(opts, jwtlib.WithRequiredClaim("sid"))
|
||||
}
|
||||
|
||||
if len(clientConfig.GetACRValues()) > 0 {
|
||||
opts = append(opts, jwtlib.WithRequiredClaim("acr"))
|
||||
}
|
||||
|
||||
return jwtlib.Validate(in.GetToken(), opts...)
|
||||
}
|
||||
|
||||
func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken {
|
||||
return &IDToken{
|
||||
jwt.NewToken(raw, jwtToken),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseIDToken(raw string, jwks jwk.Set) (*IDToken, error) {
|
||||
idToken, err := jwt.Parse(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewIDToken(raw, idToken), nil
|
||||
}
|
||||
|
||||
func ParseIDTokenFrom(tokens *oauth2.Token, jwks jwk.Set) (*IDToken, error) {
|
||||
idToken, ok := tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing id_token in token response")
|
||||
}
|
||||
|
||||
return ParseIDToken(idToken, jwks)
|
||||
}
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/loginstatus"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
logentry "github.com/nais/wonderwall/pkg/router/middleware"
|
||||
)
|
||||
@@ -52,19 +51,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rawTokens, err := h.exchangeAuthCode(r.Context(), loginCallback)
|
||||
tokens, err := h.redeemValidTokens(r.Context(), loginCallback)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := loginCallback.ProcessTokens(r.Context(), rawTokens)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = h.createSession(w, r, tokens, rawTokens)
|
||||
err = h.createSession(w, r, tokens)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
@@ -85,12 +78,12 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.LoginCallback) (*oauth2.Token, error) {
|
||||
var tokens *oauth2.Token
|
||||
func (h *Handler) redeemValidTokens(ctx context.Context, loginCallback client.LoginCallback) (*openid.Tokens, error) {
|
||||
var tokens *openid.Tokens
|
||||
var err error
|
||||
|
||||
retryable := func(ctx context.Context) error {
|
||||
tokens, err = loginCallback.ExchangeAuthCode(ctx)
|
||||
tokens, err = loginCallback.RedeemTokens(ctx)
|
||||
if err != nil {
|
||||
log.Warnf("callback: retrying: %+v", err)
|
||||
return retry.RetryableError(err)
|
||||
@@ -107,7 +100,7 @@ func (h *Handler) exchangeAuthCode(ctx context.Context, loginCallback client.Log
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (*loginstatus.TokenResponse, error) {
|
||||
func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
|
||||
var tokenResponse *loginstatus.TokenResponse
|
||||
|
||||
err := retry.Do(ctx, backoff(), func(ctx context.Context) error {
|
||||
@@ -128,10 +121,10 @@ func (h *Handler) getLoginstatusToken(ctx context.Context, tokens *jwt.Tokens) (
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) {
|
||||
func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) {
|
||||
fields := map[string]interface{}{
|
||||
"redirect_to": referer,
|
||||
"claims": tokens.Claims(),
|
||||
"jti": tokens.IDToken.GetJwtID(),
|
||||
}
|
||||
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
|
||||
@@ -36,7 +36,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Errorf("front-channel logout: destroying session: %+v", err)
|
||||
} else if sessionData != nil {
|
||||
log.WithField("claims", sessionData.Claims).Infof("front-channel logout: successful logout")
|
||||
log.WithField("jti", sessionData.IDTokenJwtID).Infof("front-channel logout: successful logout")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
@@ -25,7 +25,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"claims": sessionData.Claims,
|
||||
"jti": sessionData.IDTokenJwtID,
|
||||
}
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
logger.Info().Msg("logout: successful local logout")
|
||||
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -77,7 +76,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
|
||||
return defaultSessionLifetime
|
||||
}
|
||||
|
||||
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token) error {
|
||||
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *openid.Tokens) error {
|
||||
params := r.URL.Query()
|
||||
|
||||
externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params)
|
||||
@@ -85,7 +84,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
return fmt.Errorf("generating session ID: %w", err)
|
||||
}
|
||||
|
||||
sessionLifetime := h.getSessionLifetime(rawTokens.Expiry)
|
||||
sessionLifetime := h.getSessionLifetime(tokens.Expiry)
|
||||
opts := h.CookieOptions.WithExpiresIn(sessionLifetime)
|
||||
|
||||
sessionID := h.localSessionID(externalSessionID)
|
||||
@@ -95,7 +94,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
}
|
||||
|
||||
sessionMetadata := session.NewMetadata(time.Now().Add(sessionLifetime))
|
||||
sessionData := session.NewData(externalSessionID, tokens, rawTokens.RefreshToken, sessionMetadata)
|
||||
sessionData := session.NewData(externalSessionID, tokens, sessionMetadata)
|
||||
|
||||
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
@@ -40,10 +40,9 @@ func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
sessionData, err := rpHandler.GetSessionFallback(w, r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sid", sessionData.ExternalSessionID)
|
||||
assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken)
|
||||
assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
|
||||
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
|
||||
assert.Equal(t, "id-token-jti", sessionData.Claims.IDTokenJti)
|
||||
assert.Equal(t, "access-token-jti", sessionData.Claims.AccessTokenJti)
|
||||
assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
|
||||
assert.Empty(t, sessionData.RefreshToken)
|
||||
})
|
||||
}
|
||||
@@ -59,7 +58,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
// request should set session cookies in response
|
||||
writer := httptest.NewRecorder()
|
||||
expiresIn := time.Minute
|
||||
data := session.NewData("sid", tokens, "", nil)
|
||||
data := session.NewData("sid", tokens, nil)
|
||||
err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -79,7 +78,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
},
|
||||
{
|
||||
cookieName: "wonderwall-3",
|
||||
want: tokens.AccessToken.GetSerialized(),
|
||||
want: tokens.AccessToken,
|
||||
},
|
||||
} {
|
||||
assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies)
|
||||
@@ -118,10 +117,10 @@ func TestHandler_DeleteSessionFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *jwt.Tokens) *http.Request {
|
||||
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request {
|
||||
writer := httptest.NewRecorder()
|
||||
expiresIn := time.Minute
|
||||
data := session.NewData("sid", tokens, "", nil)
|
||||
data := session.NewData("sid", tokens, nil)
|
||||
err := h.SetSessionFallback(writer, nil, data, expiresIn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -163,7 +162,7 @@ func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedVal
|
||||
assert.Equal(t, expectedValue, string(plainbytes))
|
||||
}
|
||||
|
||||
func makeTokens(provider mock.TestProvider) *jwt.Tokens {
|
||||
func makeTokens(provider mock.TestProvider) *openid.Tokens {
|
||||
jwks := *provider.PrivateJwkSet()
|
||||
jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
|
||||
if err != nil {
|
||||
@@ -188,20 +187,10 @@ func makeTokens(provider mock.TestProvider) *jwt.Tokens {
|
||||
log.Fatalf("parsing signed id_token: %+v", err)
|
||||
}
|
||||
|
||||
accessToken := jwtlib.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
accessToken := "some-access-token"
|
||||
|
||||
signedAccessToken, err := jwtlib.Sign(accessToken, jwtlib.WithKey(jwa.RS256, signer))
|
||||
if err != nil {
|
||||
log.Fatalf("signing access_token: %+v", err)
|
||||
}
|
||||
parsedAccessToken, err := jwtlib.Parse(signedAccessToken, jwtlib.WithKeySet(*jwksPublic))
|
||||
if err != nil {
|
||||
log.Fatalf("parsing signed access_token: %+v", err)
|
||||
}
|
||||
|
||||
return &jwt.Tokens{
|
||||
IDToken: jwt.NewIDToken(string(signedIdToken), parsedIdToken),
|
||||
AccessToken: jwt.NewAccessToken(string(signedAccessToken), parsedAccessToken),
|
||||
return &openid.Tokens{
|
||||
IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken),
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,12 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/provider"
|
||||
)
|
||||
|
||||
@@ -87,15 +90,34 @@ func (c cookieSessionStore) Read(ctx context.Context) (*Data, error) {
|
||||
return nil, fmt.Errorf("callback: getting jwks: %w", err)
|
||||
}
|
||||
|
||||
tokens, err := jwt.ParseTokensFromStrings(idToken, accessToken, *jwkSet)
|
||||
// TODO: currently a placeholder fallback value, should fetch from metadata cookie
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
// attempt to get expiry from access_token if it is a JWT
|
||||
parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet)
|
||||
if err == nil {
|
||||
expiry = parsedAccessToken.Expiration()
|
||||
}
|
||||
|
||||
// TODO: set refresh token and metadata
|
||||
rawTokens := &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: "",
|
||||
Expiry: expiry,
|
||||
}
|
||||
rawTokens = rawTokens.WithExtra(map[string]interface{}{
|
||||
"id_token": idToken,
|
||||
})
|
||||
|
||||
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
|
||||
if err != nil {
|
||||
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
|
||||
_, _ = c.provider.RefreshPublicJwkSet(ctx)
|
||||
return nil, fmt.Errorf("parsing tokens: %w", err)
|
||||
}
|
||||
|
||||
// TODO: set refresh token and metadata
|
||||
return NewData(externalSessionID, tokens, "", nil), nil
|
||||
return NewData(externalSessionID, tokens, nil), nil
|
||||
}
|
||||
|
||||
func (c cookieSessionStore) Delete() {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -22,16 +22,16 @@ func TestMemory(t *testing.T) {
|
||||
idToken := jwtlib.New()
|
||||
idToken.Set("jti", "id-token-jti")
|
||||
|
||||
accessToken := jwtlib.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
|
||||
tokens := &jwt.Tokens{
|
||||
IDToken: jwt.NewIDToken("id_token", idToken),
|
||||
AccessToken: jwt.NewAccessToken("access_token", accessToken),
|
||||
}
|
||||
accessToken := "some-access-token"
|
||||
refreshToken := "some-refresh-token"
|
||||
|
||||
tokens := &openid.Tokens{
|
||||
AccessToken: accessToken,
|
||||
IDToken: openid.NewIDToken("id_token", idToken),
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
metadata := session.NewMetadata(time.Now().Add(time.Hour))
|
||||
data := session.NewData("myid", tokens, refreshToken, metadata)
|
||||
data := session.NewData("myid", tokens, metadata)
|
||||
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -24,16 +24,16 @@ func TestRedis(t *testing.T) {
|
||||
idToken := jwtlib.New()
|
||||
idToken.Set("jti", "id-token-jti")
|
||||
|
||||
accessToken := jwtlib.New()
|
||||
accessToken.Set("jti", "access-token-jti")
|
||||
|
||||
tokens := &jwt.Tokens{
|
||||
IDToken: jwt.NewIDToken("id_token", idToken),
|
||||
AccessToken: jwt.NewAccessToken("access_token", accessToken),
|
||||
}
|
||||
accessToken := "some-access-token"
|
||||
refreshToken := "some-refresh-token"
|
||||
|
||||
tokens := &openid.Tokens{
|
||||
AccessToken: accessToken,
|
||||
IDToken: openid.NewIDToken("id_token", idToken),
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
metadata := session.NewMetadata(time.Now().Add(time.Hour))
|
||||
data := session.NewData("myid", tokens, refreshToken, metadata)
|
||||
data := session.NewData("myid", tokens, metadata)
|
||||
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
|
||||
type EncryptedData struct {
|
||||
@@ -46,21 +46,21 @@ func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) {
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
ExternalSessionID string `json:"external_session_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Claims jwt.Claims `json:"claims"`
|
||||
Metadata Metadata `json:"metadata"`
|
||||
ExternalSessionID string `json:"external_session_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDTokenJwtID string `json:"id_token_jwt_id"`
|
||||
Metadata Metadata `json:"metadata"`
|
||||
}
|
||||
|
||||
func NewData(externalSessionID string, tokens *jwt.Tokens, refreshToken string, metadata *Metadata) *Data {
|
||||
func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data {
|
||||
data := &Data{
|
||||
ExternalSessionID: externalSessionID,
|
||||
AccessToken: tokens.AccessToken.GetSerialized(),
|
||||
AccessToken: tokens.AccessToken,
|
||||
IDToken: tokens.IDToken.GetSerialized(),
|
||||
RefreshToken: refreshToken,
|
||||
Claims: tokens.Claims(),
|
||||
IDTokenJwtID: tokens.IDToken.GetJwtID(),
|
||||
RefreshToken: tokens.RefreshToken,
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"io"
|
||||
"net/url"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/config"
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
SessionStateParamKey = "session_state"
|
||||
)
|
||||
|
||||
func NewSessionID(cfg *config.Provider, idToken *jwt.IDToken, params url.Values) (string, error) {
|
||||
func NewSessionID(cfg *config.Provider, idToken *openid.IDToken, params url.Values) (string, error) {
|
||||
// 1. check for 'sid' claim in id_token
|
||||
sessionID, err := idToken.GetSidClaim()
|
||||
if err == nil {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwt"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/config"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
@@ -17,7 +17,7 @@ func TestSessionID(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
config *config.Provider
|
||||
idToken *jwt.IDToken
|
||||
idToken *openid.IDToken
|
||||
params url.Values
|
||||
want string
|
||||
exactMatch bool
|
||||
@@ -136,7 +136,7 @@ func params(key, value string) url.Values {
|
||||
return values
|
||||
}
|
||||
|
||||
func newIDToken(extraClaims map[string]string) *jwt.IDToken {
|
||||
func newIDToken(extraClaims map[string]string) *openid.IDToken {
|
||||
idToken := jwtlib.New()
|
||||
idToken.Set("sub", "test")
|
||||
idToken.Set("iss", "test")
|
||||
@@ -155,15 +155,15 @@ func newIDToken(extraClaims map[string]string) *jwt.IDToken {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return jwt.NewIDToken(string(serialized), idToken)
|
||||
return openid.NewIDToken(string(serialized), idToken)
|
||||
}
|
||||
|
||||
func idTokenWithSid(sid string) *jwt.IDToken {
|
||||
func idTokenWithSid(sid string) *openid.IDToken {
|
||||
return newIDToken(map[string]string{
|
||||
"sid": sid,
|
||||
})
|
||||
}
|
||||
|
||||
func idToken() *jwt.IDToken {
|
||||
func idToken() *openid.IDToken {
|
||||
return newIDToken(nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user