mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-16 05:16:37 +00:00
refactor(jwt): clean up and deduplicate
This commit is contained in:
@@ -1,37 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Raw string
|
||||
Token jwt.Token
|
||||
Type Type
|
||||
}
|
||||
|
||||
func (in *AccessToken) GetJtiClaim() string {
|
||||
return GetStringClaimOrEmpty(in.Token, JtiClaim)
|
||||
}
|
||||
|
||||
func (in *AccessToken) GetStringClaim(claim string) (string, error) {
|
||||
return GetStringClaim(in.Token, claim)
|
||||
}
|
||||
|
||||
func NewAccessToken(raw string, token jwt.Token) *AccessToken {
|
||||
return &AccessToken{
|
||||
Raw: raw,
|
||||
Token: token,
|
||||
Type: TypeAccessToken,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
|
||||
accessToken, err := Parse(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAccessToken(raw, accessToken), nil
|
||||
}
|
||||
@@ -2,11 +2,75 @@ package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
GetExpiration() time.Time
|
||||
GetJtiClaim() string
|
||||
GetSerialized() string
|
||||
GetStringClaim(claim string) (string, error)
|
||||
GetToken() jwt.Token
|
||||
}
|
||||
|
||||
type token struct {
|
||||
serialized string
|
||||
token jwt.Token
|
||||
}
|
||||
|
||||
func (in *token) GetExpiration() time.Time {
|
||||
return in.token.Expiration()
|
||||
}
|
||||
|
||||
func (in *token) GetJtiClaim() string {
|
||||
return in.GetStringClaimOrEmpty(JtiClaim)
|
||||
}
|
||||
|
||||
func (in *token) GetSerialized() string {
|
||||
return in.serialized
|
||||
}
|
||||
|
||||
func (in *token) GetStringClaim(claim string) (string, error) {
|
||||
if in.token == nil {
|
||||
return "", fmt.Errorf("token is nil")
|
||||
}
|
||||
|
||||
gotClaim, ok := in.token.Get(claim)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
|
||||
}
|
||||
|
||||
claimString, ok := gotClaim.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("'%s' claim is not a string", claim)
|
||||
}
|
||||
|
||||
return claimString, nil
|
||||
}
|
||||
|
||||
func (in *token) GetStringClaimOrEmpty(claim string) string {
|
||||
str, err := in.GetStringClaim(claim)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
func (in *token) GetToken() jwt.Token {
|
||||
return in.token
|
||||
}
|
||||
|
||||
func NewToken(raw string, jwtToken jwt.Token) Token {
|
||||
return &token{
|
||||
serialized: raw,
|
||||
token: jwtToken,
|
||||
}
|
||||
}
|
||||
|
||||
func Parse(raw string, jwks jwk.Set) (jwt.Token, error) {
|
||||
parseOpts := []jwt.ParseOption{
|
||||
jwt.WithKeySet(jwks),
|
||||
@@ -19,30 +83,3 @@ func Parse(raw string, jwks jwk.Set) (jwt.Token, error) {
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func GetStringClaim(token jwt.Token, claim string) (string, error) {
|
||||
if token == nil {
|
||||
return "", fmt.Errorf("token is nil")
|
||||
}
|
||||
|
||||
gotClaim, ok := token.Get(claim)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
|
||||
}
|
||||
|
||||
claimString, ok := gotClaim.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("'%s' claim is not a string", claim)
|
||||
}
|
||||
|
||||
return claimString, nil
|
||||
}
|
||||
|
||||
func GetStringClaimOrEmpty(token jwt.Token, claim string) string {
|
||||
str, err := GetStringClaim(token, claim)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
25
pkg/jwt/jwt_access_token.go
Normal file
25
pkg/jwt/jwt_access_token.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Token
|
||||
}
|
||||
|
||||
func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken {
|
||||
return &AccessToken{
|
||||
NewToken(raw, jwtToken),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
|
||||
accessToken, err := Parse(raw, jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAccessToken(raw, accessToken), nil
|
||||
}
|
||||
@@ -10,23 +10,13 @@ import (
|
||||
)
|
||||
|
||||
type IDToken struct {
|
||||
Raw string
|
||||
Token jwt.Token
|
||||
Type Type
|
||||
}
|
||||
|
||||
func (in *IDToken) GetJtiClaim() string {
|
||||
return GetStringClaimOrEmpty(in.Token, JtiClaim)
|
||||
Token
|
||||
}
|
||||
|
||||
func (in *IDToken) GetSidClaim() (string, error) {
|
||||
return in.GetStringClaim(SidClaim)
|
||||
}
|
||||
|
||||
func (in *IDToken) GetStringClaim(claim string) (string, error) {
|
||||
return GetStringClaim(in.Token, claim)
|
||||
}
|
||||
|
||||
func (in *IDToken) Validate(provider openid.Provider, nonce string) error {
|
||||
openIDconfig := provider.GetOpenIDConfiguration()
|
||||
clientConfig := provider.GetClientConfiguration()
|
||||
@@ -46,14 +36,12 @@ func (in *IDToken) Validate(provider openid.Provider, nonce string) error {
|
||||
opts = append(opts, jwt.WithRequiredClaim("acr"))
|
||||
}
|
||||
|
||||
return jwt.Validate(in.Token, opts...)
|
||||
return jwt.Validate(in.GetToken(), opts...)
|
||||
}
|
||||
|
||||
func NewIDToken(raw string, token jwt.Token) *IDToken {
|
||||
func NewIDToken(raw string, jwtToken jwt.Token) *IDToken {
|
||||
return &IDToken{
|
||||
Raw: raw,
|
||||
Token: token,
|
||||
Type: TypeIDToken,
|
||||
NewToken(raw, jwtToken),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
package jwt
|
||||
|
||||
type Type int
|
||||
|
||||
const (
|
||||
TypeIDToken Type = iota
|
||||
TypeAccessToken
|
||||
)
|
||||
@@ -105,7 +105,7 @@ func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Req
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized()))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -67,7 +67,7 @@ func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Da
|
||||
func (h *Handler) getSessionLifetime(accessToken *jwt.AccessToken) time.Duration {
|
||||
defaultSessionLifetime := h.Config.SessionMaxLifetime
|
||||
|
||||
tokenDuration := accessToken.Token.Expiration().Sub(time.Now())
|
||||
tokenDuration := accessToken.GetExpiration().Sub(time.Now())
|
||||
|
||||
if tokenDuration <= defaultSessionLifetime {
|
||||
return tokenDuration
|
||||
|
||||
@@ -34,8 +34,8 @@ func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
sessionData, err := h.GetSessionFallback(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sid", sessionData.ExternalSessionID)
|
||||
assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken)
|
||||
assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken)
|
||||
assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken)
|
||||
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
|
||||
assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken)
|
||||
assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken)
|
||||
})
|
||||
@@ -65,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
},
|
||||
{
|
||||
cookieName: h.SessionFallbackIDTokenCookieName(),
|
||||
want: tokens.IDToken.Raw,
|
||||
want: tokens.IDToken.GetSerialized(),
|
||||
},
|
||||
{
|
||||
cookieName: h.SessionFallbackAccessTokenCookieName(),
|
||||
want: tokens.AccessToken.Raw,
|
||||
want: tokens.AccessToken.GetSerialized(),
|
||||
},
|
||||
} {
|
||||
assertCookieExists(t, h, test.cookieName, test.want, cookies)
|
||||
|
||||
@@ -89,8 +89,8 @@ type Data struct {
|
||||
func NewData(externalSessionID string, tokens *jwt.Tokens) *Data {
|
||||
return &Data{
|
||||
ExternalSessionID: externalSessionID,
|
||||
AccessToken: tokens.AccessToken.Raw,
|
||||
IDToken: tokens.IDToken.Raw,
|
||||
AccessToken: tokens.AccessToken.GetSerialized(),
|
||||
IDToken: tokens.IDToken.GetSerialized(),
|
||||
JwtIDs: tokens.JwtIDs(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user