refactor(jwt): clean up and deduplicate

This commit is contained in:
Trong Huu Nguyen
2022-02-03 09:33:18 +01:00
parent 59532eab0f
commit 3828437dc5
9 changed files with 101 additions and 96 deletions

View File

@@ -1,37 +0,0 @@
package jwt
import (
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)
type AccessToken struct {
Raw string
Token jwt.Token
Type Type
}
func (in *AccessToken) GetJtiClaim() string {
return GetStringClaimOrEmpty(in.Token, JtiClaim)
}
func (in *AccessToken) GetStringClaim(claim string) (string, error) {
return GetStringClaim(in.Token, claim)
}
func NewAccessToken(raw string, token jwt.Token) *AccessToken {
return &AccessToken{
Raw: raw,
Token: token,
Type: TypeAccessToken,
}
}
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := Parse(raw, jwks)
if err != nil {
return nil, err
}
return NewAccessToken(raw, accessToken), nil
}

View File

@@ -2,11 +2,75 @@ package jwt
import (
"fmt"
"time"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)
type Token interface {
GetExpiration() time.Time
GetJtiClaim() string
GetSerialized() string
GetStringClaim(claim string) (string, error)
GetToken() jwt.Token
}
type token struct {
serialized string
token jwt.Token
}
func (in *token) GetExpiration() time.Time {
return in.token.Expiration()
}
func (in *token) GetJtiClaim() string {
return in.GetStringClaimOrEmpty(JtiClaim)
}
func (in *token) GetSerialized() string {
return in.serialized
}
func (in *token) GetStringClaim(claim string) (string, error) {
if in.token == nil {
return "", fmt.Errorf("token is nil")
}
gotClaim, ok := in.token.Get(claim)
if !ok {
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
}
claimString, ok := gotClaim.(string)
if !ok {
return "", fmt.Errorf("'%s' claim is not a string", claim)
}
return claimString, nil
}
func (in *token) GetStringClaimOrEmpty(claim string) string {
str, err := in.GetStringClaim(claim)
if err != nil {
return ""
}
return str
}
func (in *token) GetToken() jwt.Token {
return in.token
}
func NewToken(raw string, jwtToken jwt.Token) Token {
return &token{
serialized: raw,
token: jwtToken,
}
}
func Parse(raw string, jwks jwk.Set) (jwt.Token, error) {
parseOpts := []jwt.ParseOption{
jwt.WithKeySet(jwks),
@@ -19,30 +83,3 @@ func Parse(raw string, jwks jwk.Set) (jwt.Token, error) {
return token, nil
}
func GetStringClaim(token jwt.Token, claim string) (string, error) {
if token == nil {
return "", fmt.Errorf("token is nil")
}
gotClaim, ok := token.Get(claim)
if !ok {
return "", fmt.Errorf("missing required '%s' claim in id_token", claim)
}
claimString, ok := gotClaim.(string)
if !ok {
return "", fmt.Errorf("'%s' claim is not a string", claim)
}
return claimString, nil
}
func GetStringClaimOrEmpty(token jwt.Token, claim string) string {
str, err := GetStringClaim(token, claim)
if err != nil {
return ""
}
return str
}

View File

@@ -0,0 +1,25 @@
package jwt
import (
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)
type AccessToken struct {
Token
}
func NewAccessToken(raw string, jwtToken jwt.Token) *AccessToken {
return &AccessToken{
NewToken(raw, jwtToken),
}
}
func ParseAccessToken(raw string, jwks jwk.Set) (*AccessToken, error) {
accessToken, err := Parse(raw, jwks)
if err != nil {
return nil, err
}
return NewAccessToken(raw, accessToken), nil
}

View File

@@ -10,23 +10,13 @@ import (
)
type IDToken struct {
Raw string
Token jwt.Token
Type Type
}
func (in *IDToken) GetJtiClaim() string {
return GetStringClaimOrEmpty(in.Token, JtiClaim)
Token
}
func (in *IDToken) GetSidClaim() (string, error) {
return in.GetStringClaim(SidClaim)
}
func (in *IDToken) GetStringClaim(claim string) (string, error) {
return GetStringClaim(in.Token, claim)
}
func (in *IDToken) Validate(provider openid.Provider, nonce string) error {
openIDconfig := provider.GetOpenIDConfiguration()
clientConfig := provider.GetClientConfiguration()
@@ -46,14 +36,12 @@ func (in *IDToken) Validate(provider openid.Provider, nonce string) error {
opts = append(opts, jwt.WithRequiredClaim("acr"))
}
return jwt.Validate(in.Token, opts...)
return jwt.Validate(in.GetToken(), opts...)
}
func NewIDToken(raw string, token jwt.Token) *IDToken {
func NewIDToken(raw string, jwtToken jwt.Token) *IDToken {
return &IDToken{
Raw: raw,
Token: token,
Type: TypeIDToken,
NewToken(raw, jwtToken),
}
}

View File

@@ -1,8 +0,0 @@
package jwt
type Type int
const (
TypeIDToken Type = iota
TypeAccessToken
)

View File

@@ -105,7 +105,7 @@ func request(ctx context.Context, url string, token *jwt.AccessToken) (*http.Req
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.GetSerialized()))
req.Header.Set("Accept", "application/json")
return req, nil

View File

@@ -67,7 +67,7 @@ func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Da
func (h *Handler) getSessionLifetime(accessToken *jwt.AccessToken) time.Duration {
defaultSessionLifetime := h.Config.SessionMaxLifetime
tokenDuration := accessToken.Token.Expiration().Sub(time.Now())
tokenDuration := accessToken.GetExpiration().Sub(time.Now())
if tokenDuration <= defaultSessionLifetime {
return tokenDuration

View File

@@ -34,8 +34,8 @@ func TestHandler_GetSessionFallback(t *testing.T) {
sessionData, err := h.GetSessionFallback(r)
assert.NoError(t, err)
assert.Equal(t, "sid", sessionData.ExternalSessionID)
assert.Equal(t, tokens.AccessToken.Raw, sessionData.AccessToken)
assert.Equal(t, tokens.IDToken.Raw, sessionData.IDToken)
assert.Equal(t, tokens.AccessToken.GetSerialized(), sessionData.AccessToken)
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
assert.Equal(t, "id-token-jti", sessionData.JwtIDs.IDToken)
assert.Equal(t, "access-token-jti", sessionData.JwtIDs.AccessToken)
})
@@ -65,11 +65,11 @@ func TestHandler_SetSessionFallback(t *testing.T) {
},
{
cookieName: h.SessionFallbackIDTokenCookieName(),
want: tokens.IDToken.Raw,
want: tokens.IDToken.GetSerialized(),
},
{
cookieName: h.SessionFallbackAccessTokenCookieName(),
want: tokens.AccessToken.Raw,
want: tokens.AccessToken.GetSerialized(),
},
} {
assertCookieExists(t, h, test.cookieName, test.want, cookies)

View File

@@ -89,8 +89,8 @@ type Data struct {
func NewData(externalSessionID string, tokens *jwt.Tokens) *Data {
return &Data{
ExternalSessionID: externalSessionID,
AccessToken: tokens.AccessToken.Raw,
IDToken: tokens.IDToken.Raw,
AccessToken: tokens.AccessToken.GetSerialized(),
IDToken: tokens.IDToken.GetSerialized(),
JwtIDs: tokens.JwtIDs(),
}
}