diff --git a/pkg/mock/handler.go b/pkg/mock/handler.go index 7f5baab..3e2d8fe 100644 --- a/pkg/mock/handler.go +++ b/pkg/mock/handler.go @@ -94,7 +94,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ v := url.Values{} v.Set("code", code) v.Set("state", state) - if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() { + if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() { v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host))) } @@ -172,7 +172,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) v := url.Values{} v.Set("error", "Unauthenticated") v.Set("error_description", "invalid client assertion") - if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() { + if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() { v.Set("session_state", ip.SessionStates[clientID]) } v.Encode() @@ -206,7 +206,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) idToken.Set("exp", time.Now().Unix()+expires) // If the sid claim should be in token and in active session - if !ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() { + if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() { idToken.Set("sid", sid) ip.Sessions[sid] = clientID } @@ -225,7 +225,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) ExpiresIn: expires, } - if ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() { + if ip.Provider.OpenIDConfiguration.SessionStateRequired() { sessionState := ip.SessionStates[clientID] token.SessionState = sessionState ip.Sessions[sessionState] = clientID diff --git a/pkg/openid/configuration.go b/pkg/openid/configuration.go index 9a4c822..c86c16c 100644 --- a/pkg/openid/configuration.go +++ b/pkg/openid/configuration.go @@ -73,7 +73,7 @@ func (c *Configuration) FetchJwkSet(ctx context.Context) (*jwk.Set, error) { return &jwkSet, nil } -func (c *Configuration) GetCheckSessionIframe() bool { +func (c *Configuration) SessionStateRequired() bool { return len(c.CheckSessionIframe) > 0 } diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index ac7072a..58e7c08 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -2,10 +2,7 @@ package router import ( "context" - "crypto/rand" - "encoding/base64" "fmt" - "io" "net/http" "net/url" "time" @@ -49,12 +46,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - sessionID, err := h.validateIDToken(idToken, loginCookie, params) + err = h.validateIDToken(idToken, loginCookie, params) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err)) return } + sessionID, err := SessionID(h.Provider.GetOpenIDConfiguration(), idToken, params) + if err != nil { + h.InternalError(w, r, fmt.Errorf("callback: generating session ID: %w", err)) + return + } + err = h.createSession(w, r, sessionID, tokens, idToken) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) @@ -86,7 +89,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid. return tokens, nil } -func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) (string, error) { +func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) error { openIDconfig := h.Provider.GetOpenIDConfiguration() clientConfig := h.Provider.GetClientConfiguration() @@ -105,56 +108,5 @@ func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.L validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr")) } - err := idToken.Validate(validateOpts...) - if err != nil { - return "", err - } - - sessionID, err := h.SessionId(idToken, params) - if err != nil { - return "", fmt.Errorf("getting external session ID from id_token: %w", err) - } - - return sessionID, nil -} - -func (h *Handler) SessionId(idToken *openid.IDToken, params url.Values) (string, error) { - var openIDconfig = h.Provider.GetOpenIDConfiguration() - var sessionID string - var err error - - switch { - case openIDconfig.SidClaimRequired(): - sessionID, err = idToken.GetStringClaim("sid") - case openIDconfig.GetCheckSessionIframe(): - sessionID, err = getSessionStateFrom(params) - default: - sessionID, err = h.GenerateSessionID() - } - - if err != nil { - return "", err - } - - return sessionID, nil -} - -func getSessionStateFrom(params url.Values) (string, error) { - var sessionStateKey = "session_state" - sessionState := params.Get(sessionStateKey) - if sessionState == "" { - return "", fmt.Errorf("missing required '%s' in params", sessionStateKey) - } - return sessionState, nil -} - -func (h *Handler) GenerateSessionID() (string, error) { - rawID := make([]byte, 64) - - _, err := io.ReadFull(rand.Reader, rawID) - if err != nil { - return "", fmt.Errorf("generating session ID: %w", err) - } - - return base64.RawURLEncoding.EncodeToString(rawID), nil + return idToken.Validate(validateOpts...) } diff --git a/pkg/router/session_id.go b/pkg/router/session_id.go new file mode 100644 index 0000000..cfa3a81 --- /dev/null +++ b/pkg/router/session_id.go @@ -0,0 +1,54 @@ +package router + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "net/url" + + "github.com/nais/wonderwall/pkg/openid" +) + +const ( + SessionStateParamKey = "session_state" +) + +func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) { + var sessionID string + var err error + + switch { + case cfg.SidClaimRequired(): + sessionID, err = idToken.GetStringClaim("sid") + case cfg.SessionStateRequired(): + sessionID, err = getSessionStateFrom(params) + default: + sessionID, err = generateSessionID() + } + + if err != nil { + return "", err + } + + return sessionID, nil +} + +func getSessionStateFrom(params url.Values) (string, error) { + sessionState := params.Get(SessionStateParamKey) + if len(sessionState) == 0 { + return "", fmt.Errorf("missing required '%s' in params", SessionStateParamKey) + } + return sessionState, nil +} + +func generateSessionID() (string, error) { + rawID := make([]byte, 64) + + _, err := io.ReadFull(rand.Reader, rawID) + if err != nil { + return "", fmt.Errorf("generating session ID: %w", err) + } + + return base64.RawURLEncoding.EncodeToString(rawID), nil +} diff --git a/pkg/router/session_id_test.go b/pkg/router/session_id_test.go new file mode 100644 index 0000000..9125e10 --- /dev/null +++ b/pkg/router/session_id_test.go @@ -0,0 +1,108 @@ +package router_test + +import ( + "net/url" + "testing" + + "github.com/lestrrat-go/jwx/jwt" + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router" +) + +func TestSessionID(t *testing.T) { + for _, test := range []struct { + name string + config *openid.Configuration + idToken *openid.IDToken + params url.Values + want string + exactMatch bool + expectErr bool + }{ + { + name: "Support for front channel session with required sid claim", + config: sidRequired(), + idToken: idTokenWithSid("some-sid"), + want: "some-sid", + exactMatch: true, + }, + { + name: "Support for front channel session without required sid claim", + config: sidRequired(), + idToken: idTokenWithSid(""), + expectErr: true, + }, + { + name: "Support for session management with required param", + config: sessionStateRequired(), + params: params("session_state", "some-session"), + want: "some-session", + exactMatch: true, + }, + { + name: "Support for session management with missing required param", + config: sessionStateRequired(), + params: params("not_session_state", "some-session"), + expectErr: true, + }, + { + name: "No support for front-channel logout nor session management", + config: &openid.Configuration{}, + want: "some-session", + }, + } { + actual, err := router.SessionID(test.config, test.idToken, test.params) + + t.Run(test.name, func(t *testing.T) { + if test.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if test.exactMatch { + assert.Equal(t, test.want, actual) + } + + assert.NotEmpty(t, actual) + } + }) + } +} + +func sidRequired() *openid.Configuration { + return &openid.Configuration{ + FrontchannelLogoutSessionSupported: true, + FrontchannelLogoutSupported: true, + } +} + +func sessionStateRequired() *openid.Configuration { + return &openid.Configuration{ + CheckSessionIframe: "https://some-provider/some-endpoint", + } +} + +func params(key, value string) url.Values { + values := url.Values{} + if len(key) > 0 && len(value) > 0 { + values.Add(key, value) + } + return values +} + +func idTokenWithSid(sid string) *openid.IDToken { + idToken := jwt.New() + if len(sid) > 0 { + idToken.Set("sid", sid) + } + serialized, err := jwt.NewSerializer().Serialize(idToken) + if err != nil { + panic(err) + } + + return &openid.IDToken{ + Raw: string(serialized), + Token: idToken, + } +}