diff --git a/pkg/session/id.go b/pkg/session/id.go new file mode 100644 index 0000000..b283941 --- /dev/null +++ b/pkg/session/id.go @@ -0,0 +1,51 @@ +package session + +import ( + "fmt" + "net/http" + + "github.com/nais/wonderwall/pkg/openid" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" + "github.com/nais/wonderwall/pkg/strings" +) + +// ExternalID returns the external session ID, derived from the given request or id_token; e.g. `sid` or `session_state`. +// If none are present, a generated ID is returned. +func ExternalID(r *http.Request, cfg openidconfig.Provider, idToken *openid.IDToken) (string, error) { + // 1. check for 'sid' claim in id_token + sessionID, err := idToken.GetSidClaim() + if err == nil { + return sessionID, nil + } + // 1a. error if sid claim is required according to openid config + if err != nil && cfg.SidClaimRequired() { + return "", err + } + + // 2. check for session_state in callback params + sessionID, err = getSessionStateFrom(r) + if err == nil { + return sessionID, nil + } + // 2a. error if session_state is required according to openid config + if err != nil && cfg.SessionStateRequired() { + return "", err + } + + // 3. generate ID if all else fails + sessionID, err = strings.GenerateBase64(64) + if err != nil { + return "", fmt.Errorf("generating session ID: %w", err) + } + return sessionID, nil +} + +func getSessionStateFrom(r *http.Request) (string, error) { + params := r.URL.Query() + + sessionState := params.Get(openid.SessionState) + if len(sessionState) == 0 { + return "", fmt.Errorf("missing required '%s' in params", openid.SessionState) + } + return sessionState, nil +} diff --git a/pkg/session/handler_test.go b/pkg/session/id_test.go similarity index 89% rename from pkg/session/handler_test.go rename to pkg/session/id_test.go index 39899d2..8e08ddf 100644 --- a/pkg/session/handler_test.go +++ b/pkg/session/id_test.go @@ -1,6 +1,8 @@ package session_test import ( + "net/http" + "net/http/httptest" "net/url" "testing" "time" @@ -10,14 +12,14 @@ import ( "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/openid/config" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/session" ) -func TestNewSessionID(t *testing.T) { +func TestExternalID(t *testing.T) { for _, test := range []struct { name string - config config.Provider + config openidconfig.Provider idToken *openid.IDToken params url.Values want string @@ -99,7 +101,13 @@ func TestNewSessionID(t *testing.T) { exactMatch: true, }, } { - actual, err := session.NewSessionID(test.config, test.idToken, test.params) + req := httptest.NewRequest(http.MethodGet, "https://wonderwall/callback", nil) + + if test.params != nil { + req.URL.RawQuery = test.params.Encode() + } + + actual, err := session.ExternalID(req, test.config, test.idToken) t.Run(test.name, func(t *testing.T) { if test.expectErr { @@ -123,18 +131,18 @@ func testConfiguration() *mock.TestConfiguration { return idp.OpenIDConfig } -func standardConfig() config.Provider { +func standardConfig() openidconfig.Provider { return testConfiguration().Provider() } -func sidRequired() config.Provider { +func sidRequired() openidconfig.Provider { cfg := testConfiguration() cfg.TestProvider.WithFrontChannelLogoutSupport() return cfg.Provider() } -func sessionStateRequired() config.Provider { +func sessionStateRequired() openidconfig.Provider { endpoint := "https://some-provider/some-endpoint" cfg := testConfiguration()