refactor(session): extract external ID function to separate file

This commit is contained in:
Trong Huu Nguyen
2023-02-20 12:39:46 +01:00
parent c6d3d11072
commit 94d4b1a524
2 changed files with 66 additions and 7 deletions

51
pkg/session/id.go Normal file
View File

@@ -0,0 +1,51 @@
package session
import (
"fmt"
"net/http"
"github.com/nais/wonderwall/pkg/openid"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/strings"
)
// ExternalID returns the external session ID, derived from the given request or id_token; e.g. `sid` or `session_state`.
// If none are present, a generated ID is returned.
func ExternalID(r *http.Request, cfg openidconfig.Provider, idToken *openid.IDToken) (string, error) {
// 1. check for 'sid' claim in id_token
sessionID, err := idToken.GetSidClaim()
if err == nil {
return sessionID, nil
}
// 1a. error if sid claim is required according to openid config
if err != nil && cfg.SidClaimRequired() {
return "", err
}
// 2. check for session_state in callback params
sessionID, err = getSessionStateFrom(r)
if err == nil {
return sessionID, nil
}
// 2a. error if session_state is required according to openid config
if err != nil && cfg.SessionStateRequired() {
return "", err
}
// 3. generate ID if all else fails
sessionID, err = strings.GenerateBase64(64)
if err != nil {
return "", fmt.Errorf("generating session ID: %w", err)
}
return sessionID, nil
}
func getSessionStateFrom(r *http.Request) (string, error) {
params := r.URL.Query()
sessionState := params.Get(openid.SessionState)
if len(sessionState) == 0 {
return "", fmt.Errorf("missing required '%s' in params", openid.SessionState)
}
return sessionState, nil
}

View File

@@ -1,6 +1,8 @@
package session_test
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
@@ -10,14 +12,14 @@ import (
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/session"
)
func TestNewSessionID(t *testing.T) {
func TestExternalID(t *testing.T) {
for _, test := range []struct {
name string
config config.Provider
config openidconfig.Provider
idToken *openid.IDToken
params url.Values
want string
@@ -99,7 +101,13 @@ func TestNewSessionID(t *testing.T) {
exactMatch: true,
},
} {
actual, err := session.NewSessionID(test.config, test.idToken, test.params)
req := httptest.NewRequest(http.MethodGet, "https://wonderwall/callback", nil)
if test.params != nil {
req.URL.RawQuery = test.params.Encode()
}
actual, err := session.ExternalID(req, test.config, test.idToken)
t.Run(test.name, func(t *testing.T) {
if test.expectErr {
@@ -123,18 +131,18 @@ func testConfiguration() *mock.TestConfiguration {
return idp.OpenIDConfig
}
func standardConfig() config.Provider {
func standardConfig() openidconfig.Provider {
return testConfiguration().Provider()
}
func sidRequired() config.Provider {
func sidRequired() openidconfig.Provider {
cfg := testConfiguration()
cfg.TestProvider.WithFrontChannelLogoutSupport()
return cfg.Provider()
}
func sessionStateRequired() config.Provider {
func sessionStateRequired() openidconfig.Provider {
endpoint := "https://some-provider/some-endpoint"
cfg := testConfiguration()