refactor: separate parsing and validation of id_token

This commit is contained in:
Trong Huu Nguyen
2021-09-07 21:30:38 +02:00
parent 09bbc35df7
commit 55002e3cfe
2 changed files with 24 additions and 26 deletions

View File

@@ -268,15 +268,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
parseOpts := []jwt.ParseOption{
jwt.WithRequiredClaim("sid"),
}
if h.Config.SecurityLevel.Enabled {
parseOpts = append(parseOpts, jwt.WithRequiredClaim("acr"))
}
idToken, err := token.ParseIDToken(r.Context(), h.jwkSet, tokens, parseOpts...)
idToken, err := token.ParseIDToken(r.Context(), h.jwkSet, tokens)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusUnauthorized)
@@ -288,6 +280,11 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
jwt.WithClaimValue("nonce", cookies.Nonce),
jwt.WithIssuer(h.Config.WellKnown.Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
jwt.WithRequiredClaim("sid"),
}
if h.Config.SecurityLevel.Enabled {
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
}
err = idToken.Validate(validateOpts...)
@@ -297,7 +294,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
sessionID := h.localSessionID(idToken.ExternalSessionID)
externalSessionID, ok := idToken.GetSID()
if !ok {
log.Error("missing required 'sid' claim")
w.WriteHeader(http.StatusUnauthorized)
return
}
sessionID := h.localSessionID(externalSessionID)
err = h.setEncryptedCookie(w, SessionCookieName, sessionID, h.Config.SessionMaxLifetime)
if err != nil {
@@ -307,7 +310,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
}
err = h.Sessions.Write(r.Context(), sessionID, &session.Data{
ExternalSessionID: idToken.ExternalSessionID,
ExternalSessionID: externalSessionID,
OAuth2Token: tokens,
IDTokenSerialized: idToken.Raw,
}, h.Config.SessionMaxLifetime)

View File

@@ -23,9 +23,8 @@ type JWTTokenRequest struct {
}
type IDToken struct {
Raw string
ExternalSessionID string
Token jwt.Token
Raw string
Token jwt.Token
}
func (in *IDToken) Validate(opts ...jwt.ValidateOption) error {
@@ -37,7 +36,12 @@ func (in *IDToken) Validate(opts ...jwt.ValidateOption) error {
return nil
}
func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token, opts ...jwt.ParseOption) (*IDToken, error) {
func (in *IDToken) GetSID() (string, bool) {
sid, ok := in.Token.Get("sid")
return sid.(string), ok
}
func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token) (*IDToken, error) {
raw, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token in token response")
@@ -50,24 +54,15 @@ func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token, opts .
parseOpts := []jwt.ParseOption{
jwt.WithKeySet(jwks),
jwt.WithValidate(true),
}
parseOpts = append(parseOpts, opts...)
idToken, err := jwt.Parse([]byte(raw), parseOpts...)
if err != nil {
return nil, fmt.Errorf("parsing jwt: %w", err)
}
sid, ok := idToken.Get("sid")
if !ok {
return nil, fmt.Errorf("missing 'sid' claim in id_token")
}
result := &IDToken{
Raw: raw,
ExternalSessionID: sid.(string),
Token: idToken,
Raw: raw,
Token: idToken,
}
return result, nil