diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 9244291..557b88c 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -29,6 +29,7 @@ type IdentityProvider struct { Cfg *config.Config OpenIDConfig Configuration Provider TestProvider + ProviderHandler *IdentityProviderHandler ProviderServer *httptest.Server RelyingPartyHandler *router.Handler RelyingPartyServer *httptest.Server @@ -87,11 +88,12 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider { RelyingPartyServer: rpServer, OpenIDConfig: openidConfig, Provider: provider, + ProviderHandler: handler, ProviderServer: server, } } -func identityProviderRouter(ip *identityProviderHandler) chi.Router { +func identityProviderRouter(ip *IdentityProviderHandler) chi.Router { r := chi.NewRouter() r.Get("/authorize", ip.Authorize) r.Post("/token", ip.Token) @@ -100,23 +102,23 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router { return r } -type identityProviderHandler struct { - Codes map[string]authorizeRequest +type IdentityProviderHandler struct { + Codes map[string]AuthorizeRequest Config openidconfig.Config Provider TestProvider Sessions map[string]string } -func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *identityProviderHandler { - return &identityProviderHandler{ - Codes: make(map[string]authorizeRequest), +func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *IdentityProviderHandler { + return &IdentityProviderHandler{ + Codes: make(map[string]AuthorizeRequest), Config: cfg, Provider: provider, Sessions: make(map[string]string), } } -type authorizeRequest struct { +type AuthorizeRequest struct { AcrLevel string CodeChallenge string Locale string @@ -131,7 +133,7 @@ type tokenResponse struct { IDToken string `json:"id_token"` } -func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) { +func (ip *IdentityProviderHandler) signToken(token jwt.Token) (string, error) { privateJwkSet := *ip.Provider.PrivateJwkSet() signer, ok := privateJwkSet.Key(0) if !ok { @@ -146,7 +148,7 @@ func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) { return string(signedToken), nil } -func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) { +func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() state := query.Get("state") redirect := query.Get("redirect_uri") @@ -162,7 +164,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ } code := uuid.New().String() - ip.Codes[code] = authorizeRequest{ + ip.Codes[code] = AuthorizeRequest{ AcrLevel: acrLevel, CodeChallenge: codeChallenge, Locale: locale, @@ -188,12 +190,12 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } -func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) { +func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) { jwks, _ := ip.Provider.GetPublicJwkSet(r.Context()) json.NewEncoder(w).Encode(jwks) } -func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) { +func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -302,10 +304,11 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) } token := &tokenResponse{ - AccessToken: signedAccessToken, - TokenType: "Bearer", - IDToken: signedIdToken, - ExpiresIn: expires, + AccessToken: signedAccessToken, + TokenType: "Bearer", + IDToken: signedIdToken, + RefreshToken: code + "some-refresh-token", + ExpiresIn: expires, } w.Header().Set("content-type", "application/json") @@ -313,7 +316,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) json.NewEncoder(w).Encode(token) } -func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { +func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() state := query.Get("state") postLogoutRedirectURI := query.Get("post_logout_redirect_uri") diff --git a/pkg/openid/client/client_test.go b/pkg/openid/client/client_test.go index 5440711..465e9fd 100644 --- a/pkg/openid/client/client_test.go +++ b/pkg/openid/client/client_test.go @@ -48,3 +48,11 @@ func TestMakeAssertion(t *testing.T) { assert.True(t, assertion.Expiration().After(time.Now())) assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry))) } + +func newTestClientWithConfig(config mock.Configuration) client.Client { + return client.NewClient(config) +} + +func newTestClient() client.Client { + return newTestClientWithConfig(mock.NewTestConfiguration(mock.Config())) +} diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go new file mode 100644 index 0000000..e456ddb --- /dev/null +++ b/pkg/openid/client/login_callback_test.go @@ -0,0 +1,170 @@ +package client_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/openid/client" +) + +func TestLoginCallback_StateMismatchError(t *testing.T) { + cookie := &openid.LoginCookie{ + State: "some-state", + } + + t.Run("invalid state", func(t *testing.T) { + url := "http://wonderwall/oauth2/callback?state=some-other-state" + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + + err := lc.StateMismatchError() + assert.Error(t, err) + }) + + t.Run("missing state", func(t *testing.T) { + url := "http://wonderwall/oauth2/callback" + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + + err := lc.StateMismatchError() + assert.Error(t, err) + }) +} + +func TestLoginCallback_IdentityProviderError(t *testing.T) { + cookie := &openid.LoginCookie{ + State: "some-state", + } + + url := "http://wonderwall/oauth2/callback?error=invalid_client&error_description=client%20authenticaion%20failed" + + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + + err := lc.IdentityProviderError() + assert.Error(t, err) +} + +func TestLoginCallback_ExchangeAuthCode(t *testing.T) { + t.Run("valid code", func(t *testing.T) { + cookie := &openid.LoginCookie{} + url := "http://wonderwall/oauth2/callback?code=some-code" + + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{ + "some-code": {}, + } + + tokens, err := lc.ExchangeAuthCode(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, tokens) + + assert.NotEmpty(t, tokens.AccessToken) + assert.NotEmpty(t, tokens.RefreshToken) + assert.NotEmpty(t, tokens.Extra("id_token")) + assert.NotEmpty(t, tokens.TokenType) + assert.NotEmpty(t, tokens.Expiry) + + assert.Equal(t, "Bearer", tokens.TokenType) + + assert.True(t, time.Now().Before(tokens.Expiry)) + assert.True(t, tokens.Expiry.Before(time.Now().Add(time.Hour))) + }) + + t.Run("invalid code", func(t *testing.T) { + cookie := &openid.LoginCookie{} + url := "http://wonderwall/oauth2/callback?code=some-code" + + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{ + "some-other-code": {}, + "another-code": {}, + } + + tokens, err := lc.ExchangeAuthCode(context.Background()) + assert.Error(t, err) + assert.Nil(t, tokens) + }) +} + +func TestLoginCallback_ProcessTokens(t *testing.T) { + cookie := &openid.LoginCookie{ + State: "some-state", + Nonce: "some-nonce", + } + url := "http://wonderwall/oauth2/callback?code=some-code" + + t.Run("happy path", func(t *testing.T) { + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{ + Nonce: "some-nonce", + } + + rawTokens, err := lc.ExchangeAuthCode(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, rawTokens) + + tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + assert.NoError(t, err) + assert.NotNil(t, tokens) + }) + + t.Run("nonce mismatch", func(t *testing.T) { + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{ + Nonce: "some-other-nonce", + } + + rawTokens, err := lc.ExchangeAuthCode(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, rawTokens) + + tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + assert.Error(t, err) + assert.Nil(t, tokens) + }) + + t.Run("unexpected audience", func(t *testing.T) { + idp, lc := newLoginCallback(t, url, cookie) + defer idp.Close() + idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{ + Nonce: "some-nonce", + } + idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id" + + rawTokens, err := lc.ExchangeAuthCode(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, rawTokens) + + tokens, err := lc.ProcessTokens(context.Background(), rawTokens) + assert.Error(t, err) + assert.Nil(t, tokens) + }) +} + +func newLoginCallback(t *testing.T, url string, cookie *openid.LoginCookie) (mock.IdentityProvider, client.LoginCallback) { + req, err := http.NewRequest("GET", url, nil) + assert.NoError(t, err) + + idp := mock.NewIdentityProvider(mock.Config()) + + cfg := idp.OpenIDConfig + cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI + cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI + cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint + + loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie) + assert.NoError(t, err) + + return idp, loginCallback +} diff --git a/pkg/openid/client/logout_callback_test.go b/pkg/openid/client/logout_callback_test.go new file mode 100644 index 0000000..57cd692 --- /dev/null +++ b/pkg/openid/client/logout_callback_test.go @@ -0,0 +1,81 @@ +package client_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/openid/client" +) + +func TestLogoutCallback_ValidateRequest(t *testing.T) { + t.Run("nil cookie", func(t *testing.T) { + _, err := newLogoutCallback(t, "http://localhost/oauth2/logout/callback?state=some-state", nil) + assert.Error(t, err) + }) + + for _, test := range []struct { + name string + url string + cookie *openid.LogoutCookie + wantErr bool + }{ + { + name: "valid request", + url: "http://localhost/oauth2/logout/callback?state=some-state", + cookie: &openid.LogoutCookie{ + State: "some-state", + RedirectTo: "http://some-url", + }, + wantErr: false, + }, + { + name: "empty redirect", + url: "http://localhost/oauth2/logout/callback?state=some-state", + cookie: &openid.LogoutCookie{ + State: "some-state", + RedirectTo: "", + }, + wantErr: true, + }, + { + name: "empty state", + url: "http://localhost/oauth2/logout/callback", + cookie: &openid.LogoutCookie{ + State: "some-state", + RedirectTo: "http://some-url", + }, + wantErr: true, + }, + { + name: "state mismatch", + url: "http://localhost/oauth2/logout/callback?state=some-other-state", + cookie: &openid.LogoutCookie{ + State: "some-state", + RedirectTo: "http://some-url", + }, + wantErr: true, + }, + } { + t.Run(test.name, func(t *testing.T) { + lc, err := newLogoutCallback(t, test.url, test.cookie) + assert.NoError(t, err) + + err = lc.ValidateRequest() + if test.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func newLogoutCallback(t *testing.T, url string, cookie *openid.LogoutCookie) (client.LogoutCallback, error) { + req, err := http.NewRequest("GET", url, nil) + assert.NoError(t, err) + + return newTestClient().LogoutCallback(req, cookie) +} diff --git a/pkg/openid/client/logout_frontchannel_test.go b/pkg/openid/client/logout_frontchannel_test.go new file mode 100644 index 0000000..3110225 --- /dev/null +++ b/pkg/openid/client/logout_frontchannel_test.go @@ -0,0 +1,35 @@ +package client_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/openid/client" +) + +func TestLogoutFrontchannel_Sid(t *testing.T) { + t.Run("missing sid parameter in request", func(t *testing.T) { + url := "http://localhost/oauth2/logout/frontchannel" + lf := newLogoutFrontchannel(t, url) + + assert.Empty(t, lf.Sid()) + assert.True(t, lf.MissingSidParameter()) + }) + + t.Run("has sid parameter in request", func(t *testing.T) { + url := "http://localhost/oauth2/logout/frontchannel?sid=some-session-id" + lf := newLogoutFrontchannel(t, url) + + assert.Equal(t, "some-session-id", lf.Sid()) + assert.False(t, lf.MissingSidParameter()) + }) +} + +func newLogoutFrontchannel(t *testing.T, url string) client.LogoutFrontchannel { + req, err := http.NewRequest("GET", url, nil) + assert.NoError(t, err) + + return newTestClient().LogoutFrontchannel(req) +} diff --git a/pkg/openid/client/logout_test.go b/pkg/openid/client/logout_test.go new file mode 100644 index 0000000..ebef131 --- /dev/null +++ b/pkg/openid/client/logout_test.go @@ -0,0 +1,105 @@ +package client_test + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid/client" +) + +const ( + LogoutCallbackURI = "http://wonderwall/oauth2/logout/callback" + PostLogoutRedirectURI = "http://some-other-url" + EndSessionEndpoint = "http://provider/endsession" +) + +func TestLogout_CanonicalRedirect(t *testing.T) { + logout := newLogout(t) + canonicalRedirect := logout.CanonicalRedirect() + + assert.Equal(t, PostLogoutRedirectURI, canonicalRedirect) +} + +func TestLogout_Cookie(t *testing.T) { + logout := newLogout(t) + cookie := logout.Cookie() + + assert.NotNil(t, cookie) + assert.NotEmpty(t, cookie.State) + assert.NotEmpty(t, cookie.RedirectTo) +} + +func TestLogout_SingleLogoutURL(t *testing.T) { + t.Run("with id_token", func(t *testing.T) { + logout := newLogout(t) + cookie := logout.Cookie() + + state := cookie.State + idToken := "some-id-token" + + raw := logout.SingleLogoutURL(idToken) + assert.NotEmpty(t, raw) + + logoutUrl, err := url.Parse(raw) + assert.NoError(t, err) + + query := logoutUrl.Query() + assert.Len(t, query, 3) + + assert.Contains(t, query, "id_token_hint") + assert.Equal(t, idToken, query.Get("id_token_hint")) + + assert.Contains(t, query, "state") + assert.Equal(t, state, query.Get("state")) + + assert.Contains(t, query, "post_logout_redirect_uri") + assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) + + logoutUrl.RawQuery = "" + assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) + }) + + t.Run("without id_token", func(t *testing.T) { + logout := newLogout(t) + cookie := logout.Cookie() + + state := cookie.State + idToken := "" + + raw := logout.SingleLogoutURL(idToken) + assert.NotEmpty(t, raw) + + logoutUrl, err := url.Parse(raw) + assert.NoError(t, err) + + query := logoutUrl.Query() + assert.Len(t, query, 2) + + assert.NotContains(t, query, "id_token_hint") + assert.Equal(t, idToken, query.Get("id_token_hint")) + + assert.Contains(t, query, "state") + assert.Equal(t, state, query.Get("state")) + + assert.Contains(t, query, "post_logout_redirect_uri") + assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) + + logoutUrl.RawQuery = "" + assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) + }) +} + +func newLogout(t *testing.T) client.Logout { + cfg := mock.NewTestConfiguration(mock.Config()) + cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI + cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI + cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint + + logout, err := newTestClientWithConfig(cfg).Logout() + assert.NoError(t, err) + + return logout +}