From 55002e3cfed725322c5874d043a3934b91803e99 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 7 Sep 2021 21:30:38 +0200 Subject: [PATCH] refactor: separate parsing and validation of id_token --- pkg/router/router.go | 25 ++++++++++++++----------- pkg/token/token.go | 25 ++++++++++--------------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index f79097b..bf1b4ca 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -268,15 +268,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - parseOpts := []jwt.ParseOption{ - jwt.WithRequiredClaim("sid"), - } - - if h.Config.SecurityLevel.Enabled { - parseOpts = append(parseOpts, jwt.WithRequiredClaim("acr")) - } - - idToken, err := token.ParseIDToken(r.Context(), h.jwkSet, tokens, parseOpts...) + idToken, err := token.ParseIDToken(r.Context(), h.jwkSet, tokens) if err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) @@ -288,6 +280,11 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { jwt.WithClaimValue("nonce", cookies.Nonce), jwt.WithIssuer(h.Config.WellKnown.Issuer), jwt.WithAcceptableSkew(5 * time.Second), + jwt.WithRequiredClaim("sid"), + } + + if h.Config.SecurityLevel.Enabled { + validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr")) } err = idToken.Validate(validateOpts...) @@ -297,7 +294,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - sessionID := h.localSessionID(idToken.ExternalSessionID) + externalSessionID, ok := idToken.GetSID() + if !ok { + log.Error("missing required 'sid' claim") + w.WriteHeader(http.StatusUnauthorized) + return + } + sessionID := h.localSessionID(externalSessionID) err = h.setEncryptedCookie(w, SessionCookieName, sessionID, h.Config.SessionMaxLifetime) if err != nil { @@ -307,7 +310,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { } err = h.Sessions.Write(r.Context(), sessionID, &session.Data{ - ExternalSessionID: idToken.ExternalSessionID, + ExternalSessionID: externalSessionID, OAuth2Token: tokens, IDTokenSerialized: idToken.Raw, }, h.Config.SessionMaxLifetime) diff --git a/pkg/token/token.go b/pkg/token/token.go index 96967b5..d207f3f 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -23,9 +23,8 @@ type JWTTokenRequest struct { } type IDToken struct { - Raw string - ExternalSessionID string - Token jwt.Token + Raw string + Token jwt.Token } func (in *IDToken) Validate(opts ...jwt.ValidateOption) error { @@ -37,7 +36,12 @@ func (in *IDToken) Validate(opts ...jwt.ValidateOption) error { return nil } -func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token, opts ...jwt.ParseOption) (*IDToken, error) { +func (in *IDToken) GetSID() (string, bool) { + sid, ok := in.Token.Get("sid") + return sid.(string), ok +} + +func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token) (*IDToken, error) { raw, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("missing id_token in token response") @@ -50,24 +54,15 @@ func ParseIDToken(ctx context.Context, jwks jwk.Set, token *oauth2.Token, opts . parseOpts := []jwt.ParseOption{ jwt.WithKeySet(jwks), - jwt.WithValidate(true), } - parseOpts = append(parseOpts, opts...) - idToken, err := jwt.Parse([]byte(raw), parseOpts...) if err != nil { return nil, fmt.Errorf("parsing jwt: %w", err) } - sid, ok := idToken.Get("sid") - if !ok { - return nil, fmt.Errorf("missing 'sid' claim in id_token") - } - result := &IDToken{ - Raw: raw, - ExternalSessionID: sid.(string), - Token: idToken, + Raw: raw, + Token: idToken, } return result, nil