refactor: split out session ID generation to own file, add tests

Co-Authored-By: Youssef Bel Mekki <youssef.bel.mekki@nav.no>
This commit is contained in:
Trong Huu Nguyen
2022-01-25 15:33:42 +01:00
parent abc8bd1835
commit 24cae11ba2
5 changed files with 176 additions and 62 deletions

View File

@@ -94,7 +94,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
v := url.Values{}
v.Set("code", code)
v.Set("state", state)
if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() {
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host)))
}
@@ -172,7 +172,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
v := url.Values{}
v.Set("error", "Unauthenticated")
v.Set("error_description", "invalid client assertion")
if ip.Provider.GetOpenIDConfiguration().GetCheckSessionIframe() {
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
v.Set("session_state", ip.SessionStates[clientID])
}
v.Encode()
@@ -206,7 +206,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
idToken.Set("exp", time.Now().Unix()+expires)
// If the sid claim should be in token and in active session
if !ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
idToken.Set("sid", sid)
ip.Sessions[sid] = clientID
}
@@ -225,7 +225,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
ExpiresIn: expires,
}
if ip.Provider.OpenIDConfiguration.GetCheckSessionIframe() {
if ip.Provider.OpenIDConfiguration.SessionStateRequired() {
sessionState := ip.SessionStates[clientID]
token.SessionState = sessionState
ip.Sessions[sessionState] = clientID

View File

@@ -73,7 +73,7 @@ func (c *Configuration) FetchJwkSet(ctx context.Context) (*jwk.Set, error) {
return &jwkSet, nil
}
func (c *Configuration) GetCheckSessionIframe() bool {
func (c *Configuration) SessionStateRequired() bool {
return len(c.CheckSessionIframe) > 0
}

View File

@@ -2,10 +2,7 @@ package router
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"time"
@@ -49,12 +46,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
sessionID, err := h.validateIDToken(idToken, loginCookie, params)
err = h.validateIDToken(idToken, loginCookie, params)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err))
return
}
sessionID, err := SessionID(h.Provider.GetOpenIDConfiguration(), idToken, params)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: generating session ID: %w", err))
return
}
err = h.createSession(w, r, sessionID, tokens, idToken)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
@@ -86,7 +89,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid.
return tokens, nil
}
func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) (string, error) {
func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.LoginCookie, params url.Values) error {
openIDconfig := h.Provider.GetOpenIDConfiguration()
clientConfig := h.Provider.GetClientConfiguration()
@@ -105,56 +108,5 @@ func (h *Handler) validateIDToken(idToken *openid.IDToken, loginCookie *openid.L
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
}
err := idToken.Validate(validateOpts...)
if err != nil {
return "", err
}
sessionID, err := h.SessionId(idToken, params)
if err != nil {
return "", fmt.Errorf("getting external session ID from id_token: %w", err)
}
return sessionID, nil
}
func (h *Handler) SessionId(idToken *openid.IDToken, params url.Values) (string, error) {
var openIDconfig = h.Provider.GetOpenIDConfiguration()
var sessionID string
var err error
switch {
case openIDconfig.SidClaimRequired():
sessionID, err = idToken.GetStringClaim("sid")
case openIDconfig.GetCheckSessionIframe():
sessionID, err = getSessionStateFrom(params)
default:
sessionID, err = h.GenerateSessionID()
}
if err != nil {
return "", err
}
return sessionID, nil
}
func getSessionStateFrom(params url.Values) (string, error) {
var sessionStateKey = "session_state"
sessionState := params.Get(sessionStateKey)
if sessionState == "" {
return "", fmt.Errorf("missing required '%s' in params", sessionStateKey)
}
return sessionState, nil
}
func (h *Handler) GenerateSessionID() (string, error) {
rawID := make([]byte, 64)
_, err := io.ReadFull(rand.Reader, rawID)
if err != nil {
return "", fmt.Errorf("generating session ID: %w", err)
}
return base64.RawURLEncoding.EncodeToString(rawID), nil
return idToken.Validate(validateOpts...)
}

54
pkg/router/session_id.go Normal file
View File

@@ -0,0 +1,54 @@
package router
import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/url"
"github.com/nais/wonderwall/pkg/openid"
)
const (
SessionStateParamKey = "session_state"
)
func SessionID(cfg *openid.Configuration, idToken *openid.IDToken, params url.Values) (string, error) {
var sessionID string
var err error
switch {
case cfg.SidClaimRequired():
sessionID, err = idToken.GetStringClaim("sid")
case cfg.SessionStateRequired():
sessionID, err = getSessionStateFrom(params)
default:
sessionID, err = generateSessionID()
}
if err != nil {
return "", err
}
return sessionID, nil
}
func getSessionStateFrom(params url.Values) (string, error) {
sessionState := params.Get(SessionStateParamKey)
if len(sessionState) == 0 {
return "", fmt.Errorf("missing required '%s' in params", SessionStateParamKey)
}
return sessionState, nil
}
func generateSessionID() (string, error) {
rawID := make([]byte, 64)
_, err := io.ReadFull(rand.Reader, rawID)
if err != nil {
return "", fmt.Errorf("generating session ID: %w", err)
}
return base64.RawURLEncoding.EncodeToString(rawID), nil
}

View File

@@ -0,0 +1,108 @@
package router_test
import (
"net/url"
"testing"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
)
func TestSessionID(t *testing.T) {
for _, test := range []struct {
name string
config *openid.Configuration
idToken *openid.IDToken
params url.Values
want string
exactMatch bool
expectErr bool
}{
{
name: "Support for front channel session with required sid claim",
config: sidRequired(),
idToken: idTokenWithSid("some-sid"),
want: "some-sid",
exactMatch: true,
},
{
name: "Support for front channel session without required sid claim",
config: sidRequired(),
idToken: idTokenWithSid(""),
expectErr: true,
},
{
name: "Support for session management with required param",
config: sessionStateRequired(),
params: params("session_state", "some-session"),
want: "some-session",
exactMatch: true,
},
{
name: "Support for session management with missing required param",
config: sessionStateRequired(),
params: params("not_session_state", "some-session"),
expectErr: true,
},
{
name: "No support for front-channel logout nor session management",
config: &openid.Configuration{},
want: "some-session",
},
} {
actual, err := router.SessionID(test.config, test.idToken, test.params)
t.Run(test.name, func(t *testing.T) {
if test.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if test.exactMatch {
assert.Equal(t, test.want, actual)
}
assert.NotEmpty(t, actual)
}
})
}
}
func sidRequired() *openid.Configuration {
return &openid.Configuration{
FrontchannelLogoutSessionSupported: true,
FrontchannelLogoutSupported: true,
}
}
func sessionStateRequired() *openid.Configuration {
return &openid.Configuration{
CheckSessionIframe: "https://some-provider/some-endpoint",
}
}
func params(key, value string) url.Values {
values := url.Values{}
if len(key) > 0 && len(value) > 0 {
values.Add(key, value)
}
return values
}
func idTokenWithSid(sid string) *openid.IDToken {
idToken := jwt.New()
if len(sid) > 0 {
idToken.Set("sid", sid)
}
serialized, err := jwt.NewSerializer().Serialize(idToken)
if err != nil {
panic(err)
}
return &openid.IDToken{
Raw: string(serialized),
Token: idToken,
}
}