mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 15:22:58 +00:00
refactor: get or generate session ID with fallbacks
Turns out that Azure AD doesn't support the `check_session_iframe` property. However it still returns the session ID in the `session_state` parameter during callbacks, and optionally can be configured to return the `sid` claim in id_tokens. This commit changes the behaviour of the SessionID method to get the session ID if found, with the order of preference being: 1. from the `sid` claim in the id_token, 2. from the `session_state` parameter provided by the OP during callbacks If neither are found, and the OP's configuration does not indicate that either should be (e.g. no support for front-channel logout and/or session management), we fall back to generating our own session ID.
This commit is contained in:
@@ -15,22 +15,31 @@ const (
|
||||
)
|
||||
|
||||
func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) {
|
||||
var sessionID string
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case cfg.SidClaimRequired():
|
||||
sessionID, err = idToken.GetStringClaim("sid")
|
||||
case cfg.SessionStateRequired():
|
||||
sessionID, err = getSessionStateFrom(params)
|
||||
default:
|
||||
sessionID, err = generateSessionID()
|
||||
// 1. check for 'sid' claim in id_token
|
||||
sessionID, err := idToken.GetStringClaim("sid")
|
||||
if err == nil {
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// 1a. error if sid claim is required according to openid config
|
||||
if err != nil && cfg.SidClaimRequired() {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 2. check for session_state in callback params
|
||||
sessionID, err = getSessionStateFrom(params)
|
||||
if err == nil {
|
||||
return sessionID, nil
|
||||
}
|
||||
// 2a. error if session_state is required according to openid config
|
||||
if err != nil && cfg.SessionStateRequired() {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 3. generate ID if all else fails
|
||||
sessionID, err = generateSessionID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package router_test
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -28,29 +29,72 @@ func TestSessionID(t *testing.T) {
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for front channel session with required sid claim and session_state in param",
|
||||
config: sidRequired(),
|
||||
params: params("session_state", "some-session-state"),
|
||||
idToken: idTokenWithSid("some-sid"),
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for front channel session without required sid claim",
|
||||
config: sidRequired(),
|
||||
idToken: idTokenWithSid(""),
|
||||
idToken: idToken(),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Support for session management with required param",
|
||||
config: sessionStateRequired(),
|
||||
idToken: idToken(),
|
||||
params: params("session_state", "some-session"),
|
||||
want: "some-session",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for session management with required param and sid in id_token",
|
||||
config: sessionStateRequired(),
|
||||
idToken: idTokenWithSid("some-sid"),
|
||||
params: params("session_state", "some-session"),
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for session management with missing required param",
|
||||
config: sessionStateRequired(),
|
||||
idToken: idToken(),
|
||||
params: params("not_session_state", "some-session"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "No support for front-channel logout nor session management",
|
||||
config: &openid.Configuration{},
|
||||
want: "some-session",
|
||||
name: "No support for front-channel logout nor session management should generate session ID",
|
||||
config: &openid.Configuration{},
|
||||
idToken: idToken(),
|
||||
want: "some-generated-id",
|
||||
exactMatch: false,
|
||||
},
|
||||
{
|
||||
name: "No support for front-channel logout nor session management, sid in id_token",
|
||||
config: &openid.Configuration{},
|
||||
idToken: idTokenWithSid("some-sid"),
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "No support for front-channel logout nor session management, session_state in param",
|
||||
config: &openid.Configuration{},
|
||||
idToken: idToken(),
|
||||
params: params("session_state", "some-session-state"),
|
||||
want: "some-session-state",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "No support for front-channel logout nor session management, sid in id_token and session_state in param, sid should take precedence",
|
||||
config: &openid.Configuration{},
|
||||
idToken: idTokenWithSid("some-sid"),
|
||||
params: params("session_state", "some-session-state"),
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
} {
|
||||
actual, err := router.SessionID(test.config, test.idToken, test.params)
|
||||
@@ -91,11 +135,20 @@ func params(key, value string) url.Values {
|
||||
return values
|
||||
}
|
||||
|
||||
func idTokenWithSid(sid string) *openid.IDToken {
|
||||
func newIDToken(extraClaims map[string]string) *openid.IDToken {
|
||||
idToken := jwt.New()
|
||||
if len(sid) > 0 {
|
||||
idToken.Set("sid", sid)
|
||||
idToken.Set("sub", "test")
|
||||
idToken.Set("iss", "test")
|
||||
idToken.Set("aud", "test")
|
||||
idToken.Set("iat", time.Now().Unix())
|
||||
idToken.Set("exp", time.Now().Add(time.Hour).Unix())
|
||||
|
||||
for claim, value := range extraClaims {
|
||||
if len(claim) > 0 {
|
||||
idToken.Set(claim, value)
|
||||
}
|
||||
}
|
||||
|
||||
serialized, err := jwt.NewSerializer().Serialize(idToken)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -106,3 +159,13 @@ func idTokenWithSid(sid string) *openid.IDToken {
|
||||
Token: idToken,
|
||||
}
|
||||
}
|
||||
|
||||
func idTokenWithSid(sid string) *openid.IDToken {
|
||||
return newIDToken(map[string]string{
|
||||
"sid": sid,
|
||||
})
|
||||
}
|
||||
|
||||
func idToken() *openid.IDToken {
|
||||
return newIDToken(nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user