diff --git a/pkg/jwt/access_token.go b/pkg/jwt/access_token.go deleted file mode 100644 index d3a44db..0000000 --- a/pkg/jwt/access_token.go +++ /dev/null @@ -1,37 +0,0 @@ -package jwt - -import ( - "github.com/lestrrat-go/jwx/jwk" - "github.com/lestrrat-go/jwx/jwt" -) - -type AccessToken struct { - Raw string - Token jwt.Token - Type Type -} - -func (in *AccessToken) GetJtiClaim() string { - return GetStringClaimOrEmpty(in.Token, JtiClaim) -} - -func (in *AccessToken) GetStringClaim(claim string) (string, error) { - return GetStringClaim(in.Token, claim) -} - -func NewAccessToken(raw string, token jwt.Token) *AccessToken { - return &AccessToken{ - Raw: raw, - Token: token, - Type: TypeAccessToken, - } -} - -func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) { - accessToken, err := Parse(raw, jwks) - if err != nil { - return nil, err - } - - return NewAccessToken(raw, accessToken), nil -} diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index 45c5161..3a00135 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -2,11 +2,75 @@ package jwt import ( "fmt" + "time" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" ) +type Token interface { + GetExpiration() time.Time + GetJtiClaim() string + GetSerialized() string + GetStringClaim(claim string) (string, error) + GetToken() jwt.Token +} + +type token struct { + serialized string + token jwt.Token +} + +func (in *token) GetExpiration() time.Time { + return in.token.Expiration() +} + +func (in *token) GetJtiClaim() string { + return in.GetStringClaimOrEmpty(JtiClaim) +} + +func (in *token) GetSerialized() string { + return in.serialized +} + +func (in *token) GetStringClaim(claim string) (string, error) { + if in.token == nil { + return "", fmt.Errorf("token is nil") + } + + gotClaim, ok := in.token.Get(claim) + if !ok { + return "", fmt.Errorf("missing required '%s' claim in id_token", claim) + } + + claimString, ok := gotClaim.(string) + if !ok { + return "", fmt.Errorf("'%s' claim is not a string", claim) + } + + return claimString, nil +} + +func (in *token) GetStringClaimOrEmpty(claim string) string { + str, err := in.GetStringClaim(claim) + if err != nil { + return "" + } + + return str +} + +func (in *token) GetToken() jwt.Token { + return in.token +} + +func NewToken(raw string, jwtToken jwt.Token) Token { + return &token{ + serialized: raw, + token: jwtToken, + } +} + func Parse(raw string, jwks jwk.Set) (jwt.Token, error) { parseOpts := []jwt.ParseOption{ jwt.WithKeySet(jwks), @@ -19,30 +83,3 @@ func Parse(raw string, jwks jwk.Set) (jwt.Token, error) { return token, nil } - -func GetStringClaim(token jwt.Token, claim string) (string, error) { - if token == nil { - return "", fmt.Errorf("token is nil") - } - - gotClaim, ok := token.Get(claim) - if !ok { - return "", fmt.Errorf("missing required '%s' claim in id_token", claim) - } - - claimString, ok := gotClaim.(string) - if !ok { - return "", fmt.Errorf("'%s' claim is not a string", claim) - } - - return claimString, nil -} - -func GetStringClaimOrEmpty(token jwt.Token, claim string) string { - str, err := GetStringClaim(token, claim) - if err != nil { - return "" - } - - return str -} diff --git a/pkg/jwt/jwt_access_token.go b/pkg/jwt/jwt_access_token.go new file mode 100644 index 0000000..748e298 --- /dev/null +++ b/pkg/jwt/jwt_access_token.go @@ -0,0 +1,25 @@ +package jwt + +import ( + "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" +) + +type AccessToken struct { + Token +} + +func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken { + return &AccessToken{ + NewToken(raw, jwtToken), + } +} + +func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) { + accessToken, err := Parse(raw, jwks) + if err != nil { + return nil, err + } + + return NewAccessToken(raw, accessToken), nil +} diff --git a/pkg/jwt/id_token.go b/pkg/jwt/jwt_id_token.go similarity index 72% rename from pkg/jwt/id_token.go rename to pkg/jwt/jwt_id_token.go index d45c516..770f219 100644 --- a/pkg/jwt/id_token.go +++ b/pkg/jwt/jwt_id_token.go @@ -10,23 +10,13 @@ import ( ) type IDToken struct { - Raw string - Token jwt.Token - Type Type -} - -func (in *IDToken) GetJtiClaim() string { - return GetStringClaimOrEmpty(in.Token, JtiClaim) + Token } func (in *IDToken) GetSidClaim() (string, error) { return in.GetStringClaim(SidClaim) } -func (in *IDToken) GetStringClaim(claim string) (string, error) { - return GetStringClaim(in.Token, claim) -} - func (in *IDToken) Validate(provider openid.Provider, nonce string) error { openIDconfig := provider.GetOpenIDConfiguration() clientConfig := provider.GetClientConfiguration() @@ -46,14 +36,12 @@ func (in *IDToken) Validate(provider openid.Provider, nonce string) error { opts = append(opts, jwt.WithRequiredClaim("acr")) } - return jwt.Validate(in.Token, opts...) + return jwt.Validate(in.GetToken(), opts...) } -func NewIDToken(raw string, token jwt.Token) *IDToken { +func NewIDToken(raw string, jwtToken jwt.Token) *IDToken { return &IDToken{ - Raw: raw, - Token: token, - Type: TypeIDToken, + NewToken(raw, jwtToken), } } diff --git a/pkg/jwt/type.go b/pkg/jwt/type.go deleted file mode 100644 index f313bfe..0000000 --- a/pkg/jwt/type.go +++ /dev/null @@ -1,8 +0,0 @@ -package jwt - -type Type int - -const ( - TypeIDToken Type = iota - TypeAccessToken -) diff --git a/pkg/loginstatus/loginstatus.go b/pkg/loginstatus/loginstatus.go index 484767d..77c5db1 100644 --- a/pkg/loginstatus/loginstatus.go +++ b/pkg/loginstatus/loginstatus.go @@ -105,7 +105,7 @@ func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Req return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized())) req.Header.Set("Accept", "application/json") return req, nil diff --git a/pkg/router/session.go b/pkg/router/session.go index 10014f7..bc192f6 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -67,7 +67,7 @@ func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Da func (h *Handler) getSessionLifetime(accessToken *jwt.AccessToken) time.Duration { defaultSessionLifetime := h.Config.SessionMaxLifetime - tokenDuration := accessToken.Token.Expiration().Sub(time.Now()) + tokenDuration := accessToken.GetExpiration().Sub(time.Now()) if tokenDuration <= defaultSessionLifetime { return tokenDuration diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index 4835623..533d88f 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -34,8 +34,8 @@ func TestHandler_GetSessionFallback(t *testing.T) { sessionData, err := h.GetSessionFallback(r) assert.NoError(t, err) assert.Equal(t, "sid", sessionData.ExternalSessionID) - assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken) - assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken) + assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken) + assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken) assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken) assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken) }) @@ -65,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) { }, { cookieName: h.SessionFallbackIDTokenCookieName(), - want: tokens.IDToken.Raw, + want: tokens.IDToken.GetSerialized(), }, { cookieName: h.SessionFallbackAccessTokenCookieName(), - want: tokens.AccessToken.Raw, + want: tokens.AccessToken.GetSerialized(), }, } { assertCookieExists(t, h, test.cookieName, test.want, cookies) diff --git a/pkg/session/session.go b/pkg/session/session.go index ddea52b..8c51fd6 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -89,8 +89,8 @@ type Data struct { func NewData(externalSessionID string, tokens *jwt.Tokens) *Data { return &Data{ ExternalSessionID: externalSessionID, - AccessToken: tokens.AccessToken.Raw, - IDToken: tokens.IDToken.Raw, + AccessToken: tokens.AccessToken.GetSerialized(), + IDToken: tokens.IDToken.GetSerialized(), JwtIDs: tokens.JwtIDs(), } }