mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-21 15:52:54 +00:00
refactor: split out session ID generation to own file, add tests
Co-Authored-By: Youssef Bel Mekki <youssef.bel.mekki@nav.no>
This commit is contained in:
@@ -94,7 +94,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
v := url.Values{}
|
||||
v.Set("code", code)
|
||||
v.Set("state", state)
|
||||
if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() {
|
||||
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
|
||||
v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host)))
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
v := url.Values{}
|
||||
v.Set("error", "Unauthenticated")
|
||||
v.Set("error_description", "invalid client assertion")
|
||||
if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() {
|
||||
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
|
||||
v.Set("session_state", ip.SessionStates[clientID])
|
||||
}
|
||||
v.Encode()
|
||||
@@ -206,7 +206,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
idToken.Set("exp", time.Now().Unix()+expires)
|
||||
|
||||
// If the sid claim should be in token and in active session
|
||||
if !ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
|
||||
if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
|
||||
idToken.Set("sid", sid)
|
||||
ip.Sessions[sid] = clientID
|
||||
}
|
||||
@@ -225,7 +225,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
ExpiresIn: expires,
|
||||
}
|
||||
|
||||
if ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() {
|
||||
if ip.Provider.OpenIDConfiguration.SessionStateRequired() {
|
||||
sessionState := ip.SessionStates[clientID]
|
||||
token.SessionState = sessionState
|
||||
ip.Sessions[sessionState] = clientID
|
||||
|
||||
@@ -73,7 +73,7 @@ func (c *Configuration) FetchJwkSet(ctx context.Context) (*jwk.Set, error) {
|
||||
return &jwkSet, nil
|
||||
}
|
||||
|
||||
func (c *Configuration) GetCheckSessionIframe() bool {
|
||||
func (c *Configuration) SessionStateRequired() bool {
|
||||
return len(c.CheckSessionIframe) > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@ package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -49,12 +46,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.validateIDToken(idToken, loginCookie, params)
|
||||
err = h.validateIDToken(idToken, loginCookie, params)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := SessionID(h.Provider.GetOpenIDConfiguration(), idToken, params)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: generating session ID: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = h.createSession(w, r, sessionID, tokens, idToken)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
@@ -86,7 +89,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid.
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) (string, error) {
|
||||
func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) error {
|
||||
openIDconfig := h.Provider.GetOpenIDConfiguration()
|
||||
clientConfig := h.Provider.GetClientConfiguration()
|
||||
|
||||
@@ -105,56 +108,5 @@ func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.L
|
||||
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
|
||||
}
|
||||
|
||||
err := idToken.Validate(validateOpts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionID, err := h.SessionId(idToken, params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting external session ID from id_token: %w", err)
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
func (h *Handler) SessionId(idToken *openid.IDToken, params url.Values) (string, error) {
|
||||
var openIDconfig = h.Provider.GetOpenIDConfiguration()
|
||||
var sessionID string
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case openIDconfig.SidClaimRequired():
|
||||
sessionID, err = idToken.GetStringClaim("sid")
|
||||
case openIDconfig.GetCheckSessionIframe():
|
||||
sessionID, err = getSessionStateFrom(params)
|
||||
default:
|
||||
sessionID, err = h.GenerateSessionID()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
func getSessionStateFrom(params url.Values) (string, error) {
|
||||
var sessionStateKey = "session_state"
|
||||
sessionState := params.Get(sessionStateKey)
|
||||
if sessionState == "" {
|
||||
return "", fmt.Errorf("missing required '%s' in params", sessionStateKey)
|
||||
}
|
||||
return sessionState, nil
|
||||
}
|
||||
|
||||
func (h *Handler) GenerateSessionID() (string, error) {
|
||||
rawID := make([]byte, 64)
|
||||
|
||||
_, err := io.ReadFull(rand.Reader, rawID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generating session ID: %w", err)
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(rawID), nil
|
||||
return idToken.Validate(validateOpts...)
|
||||
}
|
||||
|
||||
54
pkg/router/session_id.go
Normal file
54
pkg/router/session_id.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionStateParamKey = "session_state"
|
||||
)
|
||||
|
||||
func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) {
|
||||
var sessionID string
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case cfg.SidClaimRequired():
|
||||
sessionID, err = idToken.GetStringClaim("sid")
|
||||
case cfg.SessionStateRequired():
|
||||
sessionID, err = getSessionStateFrom(params)
|
||||
default:
|
||||
sessionID, err = generateSessionID()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
func getSessionStateFrom(params url.Values) (string, error) {
|
||||
sessionState := params.Get(SessionStateParamKey)
|
||||
if len(sessionState) == 0 {
|
||||
return "", fmt.Errorf("missing required '%s' in params", SessionStateParamKey)
|
||||
}
|
||||
return sessionState, nil
|
||||
}
|
||||
|
||||
func generateSessionID() (string, error) {
|
||||
rawID := make([]byte, 64)
|
||||
|
||||
_, err := io.ReadFull(rand.Reader, rawID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generating session ID: %w", err)
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(rawID), nil
|
||||
}
|
||||
108
pkg/router/session_id_test.go
Normal file
108
pkg/router/session_id_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
)
|
||||
|
||||
func TestSessionID(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
config *openid.Configuration
|
||||
idToken *openid.IDToken
|
||||
params url.Values
|
||||
want string
|
||||
exactMatch bool
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Support for front channel session with required sid claim",
|
||||
config: sidRequired(),
|
||||
idToken: idTokenWithSid("some-sid"),
|
||||
want: "some-sid",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for front channel session without required sid claim",
|
||||
config: sidRequired(),
|
||||
idToken: idTokenWithSid(""),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Support for session management with required param",
|
||||
config: sessionStateRequired(),
|
||||
params: params("session_state", "some-session"),
|
||||
want: "some-session",
|
||||
exactMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Support for session management with missing required param",
|
||||
config: sessionStateRequired(),
|
||||
params: params("not_session_state", "some-session"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "No support for front-channel logout nor session management",
|
||||
config: &openid.Configuration{},
|
||||
want: "some-session",
|
||||
},
|
||||
} {
|
||||
actual, err := router.SessionID(test.config, test.idToken, test.params)
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if test.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if test.exactMatch {
|
||||
assert.Equal(t, test.want, actual)
|
||||
}
|
||||
|
||||
assert.NotEmpty(t, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func sidRequired() *openid.Configuration {
|
||||
return &openid.Configuration{
|
||||
FrontchannelLogoutSessionSupported: true,
|
||||
FrontchannelLogoutSupported: true,
|
||||
}
|
||||
}
|
||||
|
||||
func sessionStateRequired() *openid.Configuration {
|
||||
return &openid.Configuration{
|
||||
CheckSessionIframe: "https://some-provider/some-endpoint",
|
||||
}
|
||||
}
|
||||
|
||||
func params(key, value string) url.Values {
|
||||
values := url.Values{}
|
||||
if len(key) > 0 && len(value) > 0 {
|
||||
values.Add(key, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func idTokenWithSid(sid string) *openid.IDToken {
|
||||
idToken := jwt.New()
|
||||
if len(sid) > 0 {
|
||||
idToken.Set("sid", sid)
|
||||
}
|
||||
serialized, err := jwt.NewSerializer().Serialize(idToken)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &openid.IDToken{
|
||||
Raw: string(serialized),
|
||||
Token: idToken,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user