diff --git a/pkg/mock/handler.go b/pkg/mock/handler.go index 3e2d8fe..6b1cf24 100644 --- a/pkg/mock/handler.go +++ b/pkg/mock/handler.go @@ -1,8 +1,6 @@ package mock import ( - "crypto/sha256" - "encoding/base64" "encoding/json" "fmt" "net/http" @@ -19,7 +17,6 @@ type identityProviderHandler struct { Codes map[string]authorizeRequest Provider TestProvider Sessions map[string]string - SessionStates map[string]string } func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler { @@ -27,7 +24,6 @@ func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler Codes: make(map[string]authorizeRequest), Provider: provider, Sessions: make(map[string]string), - SessionStates: make(map[string]string), } } @@ -44,7 +40,6 @@ type tokenResponse struct { RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` - SessionState string `json:"session_state"` } func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) { @@ -95,7 +90,8 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ v.Set("code", code) v.Set("state", state) if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() { - v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host))) + sessionID := uuid.New().String() + v.Set("session_state", sessionID) } u.RawQuery = v.Encode() @@ -172,9 +168,6 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) v := url.Values{} v.Set("error", "Unauthenticated") v.Set("error_description", "invalid client assertion") - if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() { - v.Set("session_state", ip.SessionStates[clientID]) - } v.Encode() w.Write([]byte(fmt.Sprintf(v.Encode()+"%+v", err))) return @@ -193,7 +186,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - sid := uuid.New().String() + sessionID := uuid.New().String() + ip.Sessions[sessionID] = clientID idToken := jwt.New() idToken.Set("sub", sub) @@ -206,9 +200,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) idToken.Set("exp", time.Now().Unix()+expires) // If the sid claim should be in token and in active session - if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() { - idToken.Set("sid", sid) - ip.Sessions[sid] = clientID + if ip.Provider.OpenIDConfiguration.SidClaimRequired() { + idToken.Set("sid", sessionID) } signedIdToken, err := ip.signToken(idToken) @@ -225,36 +218,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) ExpiresIn: expires, } - if ip.Provider.OpenIDConfiguration.SessionStateRequired() { - sessionState := ip.SessionStates[clientID] - token.SessionState = sessionState - ip.Sessions[sessionState] = clientID - } - w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(token) } -func (ip *identityProviderHandler) generateSessionState(state, originUrl string) string { - // Here, the session_state is calculated in this particular way, - // but it is entirely up to the OP how to do it under the - // requirements defined in this specification. - clientId := ip.Provider.ClientConfiguration.GetClientID() - salt := "some-salt" - saltedString := fmt.Sprintf("%s %s %s %s", clientId, state, originUrl, salt) - session := NewSHA256([]byte(saltedString)) - sessionState := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s.%s", session, NewSHA256([]byte(salt))))) - ip.SessionStates[clientId] = sessionState - return sessionState - -} - -func NewSHA256(data []byte) []byte { - hash := sha256.Sum256(data) - return hash[:] -} - -func (ip *identityProviderHandler) GetClientID(sessionID string) string { - return ip.Sessions[sessionID] -} diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 73d4ac9..c8e6409 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -6,7 +6,7 @@ import ( "github.com/go-chi/chi/v5" ) -func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *identityProviderHandler) { +func IdentityProviderServer() (*httptest.Server, TestProvider) { provider := NewTestProvider() handler := newIdentityProviderHandler(provider) router := identityProviderRouter(handler) @@ -18,14 +18,7 @@ func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *ident provider.OpenIDConfiguration.TokenEndpoint = server.URL + "/token" provider.OpenIDConfiguration.EndSessionEndpoint = server.URL + "/endsession" - if iframe { - provider.OpenIDConfiguration.CheckSessionIframe = server.URL + "/checksession" - } else { - provider.OpenIDConfiguration.FrontchannelLogoutSupported = true - provider.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true - } - - return server, provider, handler + return server, provider } func identityProviderRouter(ip *identityProviderHandler) chi.Router { diff --git a/pkg/mock/provider.go b/pkg/mock/provider.go index 0971a19..3e82b85 100644 --- a/pkg/mock/provider.go +++ b/pkg/mock/provider.go @@ -31,6 +31,17 @@ func (p TestProvider) PrivateJwkSet() *jwk.Set { return &p.JwksPair.Private } +func (p TestProvider) WithFrontChannelLogoutSupport() TestProvider { + p.OpenIDConfiguration.FrontchannelLogoutSupported = true + p.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true + return p +} + +func (p TestProvider) WithCheckSessionIFrameSupport(url string) TestProvider { + p.OpenIDConfiguration.CheckSessionIframe = url + return p +} + func NewTestProvider() TestProvider { jwksPair, err := crypto.NewJwkSet() if err != nil { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 210c95d..520ce08 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -7,7 +7,6 @@ import ( "net/http/cookiejar" "net/http/httptest" "net/url" - "strings" "testing" "time" @@ -45,7 +44,7 @@ func newHandler(provider openid.Provider) *router.Handler { } func TestHandler_Login(t *testing.T) { - idpserver, idp, _ := mock.IdentityProviderServer(false) + idpserver, idp := mock.IdentityProviderServer() h := newHandler(idp) r := router.New(h) @@ -101,7 +100,7 @@ func TestHandler_Login(t *testing.T) { } func TestHandler_Callback_and_Logout(t *testing.T) { - idpserver, idp, _ := mock.IdentityProviderServer(false) + idpserver, idp := mock.IdentityProviderServer() h := newHandler(idp) r := router.New(h) @@ -196,7 +195,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) { } func TestHandler_FrontChannelLogout(t *testing.T) { - _, idp, idpHandler := mock.IdentityProviderServer(false) + _, idp := mock.IdentityProviderServer() + idp.WithFrontChannelLogoutSupport() + h := newHandler(idp) r := router.New(h) server := httptest.NewServer(r) @@ -252,9 +253,6 @@ func TestHandler_FrontChannelLogout(t *testing.T) { sid, err := h.Crypter.Decrypt(ciphertext) assert.NoError(t, err) - clientID := idpHandler.GetClientID(parseSessionID(sid)) - assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID) - frontchannelLogoutURL, err := url.Parse(server.URL) assert.NoError(t, err) @@ -271,8 +269,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) { defer resp.Body.Close() } -func TestHandler_CheckSessionIframe(t *testing.T) { - _, idp, idpHandler := mock.IdentityProviderServer(true) +func TestHandler_SessionStateRequired(t *testing.T) { + idpServer, idp := mock.IdentityProviderServer() + idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession") h := newHandler(idp) r := router.New(h) server := httptest.NewServer(r) @@ -311,24 +310,9 @@ func TestHandler_CheckSessionIframe(t *testing.T) { callbackURL, err := url.Parse(location) assert.NoError(t, err) - // Follow redirect to callback - resp, err = client.Get(callbackURL.String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - cookies := client.Jar.Cookies(callbackURL) - sessionCookie := getCookieFromJar(router.SessionCookieName, cookies) - - assert.NotNil(t, sessionCookie) - - ciphertext, err := base64.StdEncoding.DecodeString(sessionCookie.Value) - assert.NoError(t, err) - - sessionState, err := h.Crypter.Decrypt(ciphertext) - assert.NoError(t, err) - - clientID := idpHandler.GetClientID(parseSessionID(sessionState)) - assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID) + params := callbackURL.Query() + sessionState := params.Get("session_state") + assert.NotEmpty(t, sessionState) } func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { @@ -340,7 +324,3 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { return nil } - -func parseSessionID(sessionID []byte) string { - return strings.Split(string(sessionID), ":")[2] -}