diff --git a/pkg/router/session_id.go b/pkg/router/session_id.go index cfa3a81..b0a7d1f 100644 --- a/pkg/router/session_id.go +++ b/pkg/router/session_id.go @@ -15,22 +15,31 @@ const ( ) func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) { - var sessionID string - var err error - - switch { - case cfg.SidClaimRequired(): - sessionID, err = idToken.GetStringClaim("sid") - case cfg.SessionStateRequired(): - sessionID, err = getSessionStateFrom(params) - default: - sessionID, err = generateSessionID() + // 1. check for 'sid' claim in id_token + sessionID, err := idToken.GetStringClaim("sid") + if err == nil { + return sessionID, nil } - - if err != nil { + // 1a. error if sid claim is required according to openid config + if err != nil && cfg.SidClaimRequired() { return "", err } + // 2. check for session_state in callback params + sessionID, err = getSessionStateFrom(params) + if err == nil { + return sessionID, nil + } + // 2a. error if session_state is required according to openid config + if err != nil && cfg.SessionStateRequired() { + return "", err + } + + // 3. generate ID if all else fails + sessionID, err = generateSessionID() + if err != nil { + return "", err + } return sessionID, nil } diff --git a/pkg/router/session_id_test.go b/pkg/router/session_id_test.go index 9125e10..956bcee 100644 --- a/pkg/router/session_id_test.go +++ b/pkg/router/session_id_test.go @@ -3,6 +3,7 @@ package router_test import ( "net/url" "testing" + "time" "github.com/lestrrat-go/jwx/jwt" "github.com/stretchr/testify/assert" @@ -28,29 +29,72 @@ func TestSessionID(t *testing.T) { want: "some-sid", exactMatch: true, }, + { + name: "Support for front channel session with required sid claim and session_state in param", + config: sidRequired(), + params: params("session_state", "some-session-state"), + idToken: idTokenWithSid("some-sid"), + want: "some-sid", + exactMatch: true, + }, { name: "Support for front channel session without required sid claim", config: sidRequired(), - idToken: idTokenWithSid(""), + idToken: idToken(), expectErr: true, }, { name: "Support for session management with required param", config: sessionStateRequired(), + idToken: idToken(), params: params("session_state", "some-session"), want: "some-session", exactMatch: true, }, + { + name: "Support for session management with required param and sid in id_token", + config: sessionStateRequired(), + idToken: idTokenWithSid("some-sid"), + params: params("session_state", "some-session"), + want: "some-sid", + exactMatch: true, + }, { name: "Support for session management with missing required param", config: sessionStateRequired(), + idToken: idToken(), params: params("not_session_state", "some-session"), expectErr: true, }, { - name: "No support for front-channel logout nor session management", - config: &openid.Configuration{}, - want: "some-session", + name: "No support for front-channel logout nor session management should generate session ID", + config: &openid.Configuration{}, + idToken: idToken(), + want: "some-generated-id", + exactMatch: false, + }, + { + name: "No support for front-channel logout nor session management, sid in id_token", + config: &openid.Configuration{}, + idToken: idTokenWithSid("some-sid"), + want: "some-sid", + exactMatch: true, + }, + { + name: "No support for front-channel logout nor session management, session_state in param", + config: &openid.Configuration{}, + idToken: idToken(), + params: params("session_state", "some-session-state"), + want: "some-session-state", + exactMatch: true, + }, + { + name: "No support for front-channel logout nor session management, sid in id_token and session_state in param, sid should take precedence", + config: &openid.Configuration{}, + idToken: idTokenWithSid("some-sid"), + params: params("session_state", "some-session-state"), + want: "some-sid", + exactMatch: true, }, } { actual, err := router.SessionID(test.config, test.idToken, test.params) @@ -91,11 +135,20 @@ func params(key, value string) url.Values { return values } -func idTokenWithSid(sid string) *openid.IDToken { +func newIDToken(extraClaims map[string]string) *openid.IDToken { idToken := jwt.New() - if len(sid) > 0 { - idToken.Set("sid", sid) + idToken.Set("sub", "test") + idToken.Set("iss", "test") + idToken.Set("aud", "test") + idToken.Set("iat", time.Now().Unix()) + idToken.Set("exp", time.Now().Add(time.Hour).Unix()) + + for claim, value := range extraClaims { + if len(claim) > 0 { + idToken.Set(claim, value) + } } + serialized, err := jwt.NewSerializer().Serialize(idToken) if err != nil { panic(err) @@ -106,3 +159,13 @@ func idTokenWithSid(sid string) *openid.IDToken { Token: idToken, } } + +func idTokenWithSid(sid string) *openid.IDToken { + return newIDToken(map[string]string{ + "sid": sid, + }) +} + +func idToken() *openid.IDToken { + return newIDToken(nil) +}