diff --git a/pkg/mock/client_configuration.go b/pkg/mock/client_configuration.go new file mode 100644 index 0000000..72ebf68 --- /dev/null +++ b/pkg/mock/client_configuration.go @@ -0,0 +1,88 @@ +package mock + +import ( + "crypto/rand" + "crypto/rsa" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwk" + + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/scopes" +) + +type TestClientConfiguration struct { + ClientID string + ClientJWK jwk.Key + RedirectURI string + PostLogoutRedirectURI string + Scopes scopes.Scopes + ACRValues openid.OptionalConfiguration + UILocales openid.OptionalConfiguration + WellKnownURL string +} + +func (c TestClientConfiguration) GetRedirectURI() string { + return c.RedirectURI +} + +func (c TestClientConfiguration) GetClientID() string { + return c.ClientID +} + +func (c TestClientConfiguration) GetClientJWK() jwk.Key { + return c.ClientJWK +} + +func (c TestClientConfiguration) GetPostLogoutRedirectURI() string { + return c.PostLogoutRedirectURI +} + +func (c TestClientConfiguration) GetScopes() scopes.Scopes { + return c.Scopes +} + +func (c TestClientConfiguration) GetACRValues() openid.OptionalConfiguration { + return c.ACRValues +} + +func (c TestClientConfiguration) GetUILocales() openid.OptionalConfiguration { + return c.UILocales +} + +func (c TestClientConfiguration) GetWellKnownURL() string { + return c.WellKnownURL +} + +func clientConfiguration() TestClientConfiguration { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + key, err := jwk.New(privateKey) + if err != nil { + panic(err) + } + key.Set(jwk.AlgorithmKey, jwa.RS256) + key.Set(jwk.KeyTypeKey, jwa.RSA) + key.Set(jwk.KeyIDKey, uuid.New().String()) + + return TestClientConfiguration{ + ClientID: "client_id", + ClientJWK: key, + RedirectURI: "http://localhost/callback", + WellKnownURL: "", + UILocales: openid.OptionalConfiguration{ + Enabled: true, + Value: "nb", + }, + ACRValues: openid.OptionalConfiguration{ + Enabled: true, + Value: "Level4", + }, + PostLogoutRedirectURI: "", + Scopes: scopes.Defaults(), + } +} diff --git a/pkg/mock/idporten.go b/pkg/mock/handler.go similarity index 60% rename from pkg/mock/idporten.go rename to pkg/mock/handler.go index 1b704dc..feb8a26 100644 --- a/pkg/mock/idporten.go +++ b/pkg/mock/handler.go @@ -1,72 +1,40 @@ package mock import ( - "crypto/rand" - "crypto/rsa" "encoding/json" "fmt" "net/http" "net/url" "time" - "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" - - "github.com/nais/wonderwall/pkg/config" ) -type IDPorten struct { - Clients map[string]string - Config config.IDPorten - Codes map[string]AuthRequest - Keys jwk.Set +type identityProviderHandler struct { + Codes map[string]authorizeRequest + Provider TestProvider Sessions map[string]string } -type AuthRequest struct { +func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler { + return &identityProviderHandler{ + Codes: make(map[string]authorizeRequest), + Provider: provider, + Sessions: make(map[string]string), + } +} + +type authorizeRequest struct { AcrLevel string CodeChallenge string Locale string Nonce string } -func NewIDPorten(clients map[string]string, config config.IDPorten) *IDPorten { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - - key, err := jwk.New(privateKey) - if err != nil { - panic(err) - } - - err = jwk.AssignKeyID(key) - if err != nil { - panic(err) - } - - err = key.Set(jwk.AlgorithmKey, jwa.RS256) - if err != nil { - panic(err) - } - - keys := jwk.NewSet() - keys.Add(key) - - return &IDPorten{ - Clients: clients, - Codes: make(map[string]AuthRequest), - Config: config, - Keys: keys, - Sessions: make(map[string]string), - } -} - -type TokenJSON struct { +type tokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token"` @@ -74,8 +42,9 @@ type TokenJSON struct { IDToken string `json:"id_token"` } -func (ip *IDPorten) signToken(token jwt.Token) (string, error) { - signer, ok := ip.Keys.Get(0) +func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) { + privateJwkSet := *ip.Provider.PrivateJwkSet() + signer, ok := privateJwkSet.Get(0) if !ok { return "", fmt.Errorf("could not get signer") } @@ -88,7 +57,49 @@ func (ip *IDPorten) signToken(token jwt.Token) (string, error) { return string(signedToken), nil } -func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { +func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + state := query.Get("state") + redirect := query.Get("redirect_uri") + acrLevel := query.Get("acr_values") + codeChallenge := query.Get("code_challenge") + locale := query.Get("ui_locales") + nonce := query.Get("nonce") + + if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing required fields")) + return + } + + code := uuid.New().String() + ip.Codes[code] = authorizeRequest{ + AcrLevel: acrLevel, + CodeChallenge: codeChallenge, + Locale: locale, + Nonce: nonce, + } + + u, err := url.Parse(redirect) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("couldn't parse redirect uri")) + return + } + v := url.Values{} + v.Set("code", code) + v.Set("state", state) + + u.RawQuery = v.Encode() + + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) +} + +func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, _ *http.Request) { + json.NewEncoder(w).Encode(ip.Provider.GetPublicJwkSet()) +} + +func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -116,9 +127,48 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { sub := uuid.New().String() sid := uuid.New().String() + clientID := r.PostForm.Get("client_id") + if len(clientID) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing client_id")) + return + } + + clientAssertion := r.PostForm.Get("client_assertion") + if len(clientID) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing client_assertion")) + return + } + + clientJwk := ip.Provider.GetClientConfiguration().GetClientJWK() + clientJwkSet := jwk.NewSet() + clientJwkSet.Add(clientJwk) + publicClientJwkSet, err := jwk.PublicSetOf(clientJwkSet) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("failed to create public client jwk set")) + return + } + + opts := []jwt.ParseOption{ + jwt.WithValidate(true), + jwt.WithKeySet(publicClientJwkSet), + jwt.WithIssuer(ip.Provider.GetClientConfiguration().GetClientID()), + jwt.WithSubject(ip.Provider.GetClientConfiguration().GetClientID()), + jwt.WithClaimValue("scope", ip.Provider.GetClientConfiguration().GetScopes().String()), + jwt.WithAudience(ip.Provider.GetOpenIDConfiguration().Issuer), + } + _, err = jwt.Parse([]byte(clientAssertion), opts...) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(fmt.Sprintf("invalid client assertion: %+v", err))) + return + } + accessToken := jwt.New() accessToken.Set("sub", sub) - accessToken.Set("iss", ip.Config.WellKnown.Issuer) + accessToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer) accessToken.Set("acr", auth.AcrLevel) accessToken.Set("iat", time.Now().Unix()) accessToken.Set("exp", time.Now().Unix()+expires) @@ -131,8 +181,8 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { idToken := jwt.New() idToken.Set("sub", sub) - idToken.Set("iss", ip.Config.WellKnown.Issuer) - idToken.Set("aud", ip.Config.ClientID) + idToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer) + idToken.Set("aud", clientID) idToken.Set("locale", auth.Locale) idToken.Set("nonce", auth.Nonce) idToken.Set("acr", auth.AcrLevel) @@ -147,8 +197,8 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { return } - ip.Sessions[sid] = ip.Config.ClientID - token := &TokenJSON{ + ip.Sessions[sid] = clientID + token := &tokenResponse{ AccessToken: signedAccessToken, TokenType: "Bearer", IDToken: signedIdToken, @@ -159,59 +209,3 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(token) } - -func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - state := query.Get("state") - redirect := query.Get("redirect_uri") - acrLevel := query.Get("acr_values") - codeChallenge := query.Get("code_challenge") - locale := query.Get("ui_locales") - nonce := query.Get("nonce") - - if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing required fields")) - return - } - - code := uuid.New().String() - ip.Codes[code] = AuthRequest{ - AcrLevel: acrLevel, - CodeChallenge: codeChallenge, - Locale: locale, - Nonce: nonce, - } - - u, err := url.Parse(redirect) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("couldn't parse redirect uri")) - return - } - v := url.Values{} - v.Set("code", code) - v.Set("state", state) - - u.RawQuery = v.Encode() - - http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) -} - -func (ip *IDPorten) Jwks(w http.ResponseWriter, r *http.Request) { - publicSet, err := jwk.PublicSetOf(ip.Keys) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("could not create public set: " + err.Error())) - return - } - json.NewEncoder(w).Encode(publicSet) -} - -func IDPortenRouter(ip *IDPorten) chi.Router { - r := chi.NewRouter() - r.Get("/authorize", ip.Authorize) - r.Post("/token", ip.Token) - r.Get("/jwks", ip.Jwks) - return r -} diff --git a/pkg/mock/provider.go b/pkg/mock/provider.go new file mode 100644 index 0000000..a60997f --- /dev/null +++ b/pkg/mock/provider.go @@ -0,0 +1,50 @@ +package mock + +import ( + "github.com/lestrrat-go/jwx/jwk" + log "github.com/sirupsen/logrus" + + "github.com/nais/wonderwall/pkg/jwks" + "github.com/nais/wonderwall/pkg/openid" +) + +type TestProvider struct { + ClientConfiguration *TestClientConfiguration + OpenIDConfiguration *openid.Configuration + JwksPair *jwks.Pair +} + +func (p TestProvider) GetClientConfiguration() openid.ClientConfiguration { + return p.ClientConfiguration +} + +func (p TestProvider) GetOpenIDConfiguration() *openid.Configuration { + return p.OpenIDConfiguration +} + +func (p TestProvider) GetPublicJwkSet() *jwk.Set { + return &p.JwksPair.Public +} + +func (p TestProvider) PrivateJwkSet() *jwk.Set { + return &p.JwksPair.Private +} + +func NewTestProvider() TestProvider { + jwksPair, err := jwks.NewJwksPair() + if err != nil { + log.Fatal(err) + } + + clientCfg := clientConfiguration() + provider := TestProvider{ + ClientConfiguration: &clientCfg, + OpenIDConfiguration: &openid.Configuration{ + ACRValuesSupported: openid.Supported{"Level3", "Level4"}, + UILocalesSupported: openid.Supported{"nb", "nb", "en", "se"}, + }, + JwksPair: jwksPair, + } + + return provider +} diff --git a/pkg/mock/router.go b/pkg/mock/router.go new file mode 100644 index 0000000..ec48324 --- /dev/null +++ b/pkg/mock/router.go @@ -0,0 +1,11 @@ +package mock + +import "github.com/go-chi/chi/v5" + +func identityProviderRouter(ip *identityProviderHandler) chi.Router { + r := chi.NewRouter() + r.Get("/authorize", ip.Authorize) + r.Post("/token", ip.Token) + r.Get("/jwks", ip.Jwks) + return r +} diff --git a/pkg/mock/server.go b/pkg/mock/server.go new file mode 100644 index 0000000..8525f49 --- /dev/null +++ b/pkg/mock/server.go @@ -0,0 +1,20 @@ +package mock + +import ( + "net/http/httptest" +) + +func IdentityProviderServer() (*httptest.Server, TestProvider) { + provider := NewTestProvider() + handler := newIdentityProviderHandler(provider) + router := identityProviderRouter(handler) + server := httptest.NewServer(router) + + provider.OpenIDConfiguration.Issuer = server.URL + provider.OpenIDConfiguration.JwksURI = server.URL + "/jwks" + provider.OpenIDConfiguration.AuthorizationEndpoint = server.URL + "/authorize" + provider.OpenIDConfiguration.TokenEndpoint = server.URL + "/token" + provider.OpenIDConfiguration.EndSessionEndpoint = server.URL + "/endsession" + + return server, provider +} diff --git a/pkg/router/login_url_test.go b/pkg/router/login_url_test.go index ea3690a..111907f 100644 --- a/pkg/router/login_url_test.go +++ b/pkg/router/login_url_test.go @@ -47,14 +47,14 @@ func TestLoginURL(t *testing.T) { for _, test := range tests { t.Run(test.url, func(t *testing.T) { - cfg := defaultConfig() req, err := http.NewRequest("GET", test.url, nil) assert.NoError(t, err) params, err := openid.GenerateLoginParameters() assert.NoError(t, err) - handler := handler(cfg) + provider := mock.NewTestProvider() + handler := handler(provider) _, err = handler.LoginURL(req, params) if test.error != nil { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 167d9e7..1abfbb1 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -1,14 +1,8 @@ package router_test import ( - "context" - "crypto/rand" - "crypto/rsa" "encoding/base64" - "encoding/json" "fmt" - "github.com/google/uuid" - "github.com/lestrrat-go/jwx/jwa" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -16,85 +10,31 @@ import ( "testing" "time" - "github.com/lestrrat-go/jwx/jwk" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cryptutil" "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/provider" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" ) -const clientID = "clientid" - -var encryptionKey = []byte(`G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`) // 256 bits AES - -var clients = map[string]string{ - clientID: "http://localhost/oauth2/logout/frontchannel", +var cfg = config.Config{ + EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES + Ingress: "/", + OpenID: config.OpenID{ + Provider: "test", + }, + SessionMaxLifetime: time.Hour, } -func defaultConfig() config.Config { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - - key, err := jwk.New(privateKey) - if err != nil { - panic(err) - } - key.Set(jwk.AlgorithmKey, jwa.RS256) - key.Set(jwk.KeyTypeKey, jwa.RSA) - key.Set(jwk.KeyIDKey, uuid.New().String()) - - clientJwk, err := json.Marshal(key) - if err != nil { - panic(err) - } - - return config.Config{IDPorten: config.IDPorten{ - ClientID: clientID, - ClientJWK: string(clientJwk), - RedirectURI: "http://localhost/callback", - WellKnownURL: "", - WellKnown: config.IDPortenWellKnown{ - Issuer: "issuer", - AuthorizationEndpoint: "http://localhost:1234/authorize", - ACRValuesSupported: config.Supported{"Level3", "Level4"}, - UILocalesSupported: config.Supported{"nb", "nb", "en", "se"}, - }, - Locale: config.IDPortenLocale{ - Enabled: true, - Value: "nb", - }, - SecurityLevel: config.IDPortenSecurityLevel{ - Enabled: true, - Value: "Level4", - }, - PostLogoutRedirectURI: "", - SessionMaxLifetime: time.Hour, - }} -} - -func handler(cfg config.Config) *router.Handler { - var jwkSet jwk.Set - var err error - - if len(cfg.IDPorten.WellKnown.JwksURI) == 0 { - jwk.NewSet() - } else { - jwkSet, err = jwk.Fetch(context.Background(), cfg.IDPorten.WellKnown.JwksURI) - } - if err != nil { - panic(err) - } - - crypter := cryptutil.New(encryptionKey) +func handler(provider provider.Provider) *router.Handler { + crypter := cryptutil.New([]byte(cfg.EncryptionKey)) sessionStore := session.NewMemory() - handler, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, jwkSet, sessionStore, "") + handler, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, provider, sessionStore) if err != nil { panic(err) } @@ -102,9 +42,8 @@ func handler(cfg config.Config) *router.Handler { } func TestHandler_Login(t *testing.T) { - cfg := defaultConfig() - - h := handler(cfg) + idpserver, idp := mock.IdentityProviderServer() + h := handler(idp) r := router.New(h) jar, err := cookiejar.New(nil) @@ -117,41 +56,38 @@ func TestHandler_Login(t *testing.T) { return http.ErrUseLastResponse } - idprouter := mock.IDPortenRouter(mock.NewIDPorten(clients, cfg.IDPorten)) - idpserver := httptest.NewServer(idprouter) - - h.Config.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" - loginURL, err := url.Parse(server.URL + "/oauth2/login") assert.NoError(t, err) - req, err := client.Get(loginURL.String()) + resp, err := client.Get(loginURL.String()) assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() cookies := client.Jar.Cookies(loginURL) loginCookie := getCookieFromJar(h.GetLoginCookieName(), cookies) assert.NotNil(t, loginCookie) - location := req.Header.Get("location") + location := resp.Header.Get("location") u, err := url.Parse(location) assert.NoError(t, err) assert.Equal(t, idpserver.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host)) assert.Equal(t, "/authorize", u.Path) - assert.Equal(t, cfg.IDPorten.SecurityLevel.Value, u.Query().Get("acr_values")) - assert.Equal(t, cfg.IDPorten.Locale.Value, u.Query().Get("ui_locales")) - assert.Equal(t, cfg.IDPorten.ClientID, u.Query().Get("client_id")) - assert.Equal(t, cfg.IDPorten.RedirectURI, u.Query().Get("redirect_uri")) + assert.Equal(t, idp.GetClientConfiguration().GetACRValues().Value, u.Query().Get("acr_values")) + assert.Equal(t, idp.GetClientConfiguration().GetUILocales().Value, u.Query().Get("ui_locales")) + assert.Equal(t, idp.GetClientConfiguration().GetClientID(), u.Query().Get("client_id")) + assert.Equal(t, idp.GetClientConfiguration().GetRedirectURI(), u.Query().Get("redirect_uri")) assert.NotEmpty(t, u.Query().Get("state")) assert.NotEmpty(t, u.Query().Get("nonce")) assert.NotEmpty(t, u.Query().Get("code_challenge")) - req, err = client.Get(u.String()) + resp, err = client.Get(u.String()) assert.NoError(t, err) - defer req.Body.Close() + defer resp.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - location = req.Header.Get("location") + location = resp.Header.Get("location") callbackURL, err := url.Parse(location) assert.NoError(t, err) @@ -160,21 +96,14 @@ func TestHandler_Login(t *testing.T) { } func TestHandler_Callback_and_Logout(t *testing.T) { - cfg := defaultConfig() + idpserver, idp := mock.IdentityProviderServer() - idprouter := mock.IDPortenRouter(mock.NewIDPorten(clients, cfg.IDPorten)) - idpserver := httptest.NewServer(idprouter) - cfg.IDPorten.WellKnown.JwksURI = idpserver.URL + "/jwks" - cfg.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" - cfg.IDPorten.WellKnown.TokenEndpoint = idpserver.URL + "/token" - cfg.IDPorten.WellKnown.EndSessionEndpoint = idpserver.URL + "/endsession" - - h := handler(cfg) + h := handler(idp) r := router.New(h) server := httptest.NewServer(r) - h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback" - h.Config.IDPorten.PostLogoutRedirectURI = server.URL + idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback" + idp.ClientConfiguration.PostLogoutRedirectURI = server.URL jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -187,9 +116,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) { // First, run /oauth2/login to set cookies loginURL, err := url.Parse(server.URL + "/oauth2/login") - req, err := client.Get(loginURL.String()) + resp, err := client.Get(loginURL.String()) assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() cookies := client.Jar.Cookies(loginURL) sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies) @@ -199,23 +129,25 @@ func TestHandler_Callback_and_Logout(t *testing.T) { assert.NotNil(t, loginCookie) // Get authorization URL - location := req.Header.Get("location") + location := resp.Header.Get("location") u, err := url.Parse(location) assert.NoError(t, err) - // Follow redirect to authorize with idporten - req, err = client.Get(u.String()) + // Follow redirect to authorize with identity provider + resp, err = client.Get(u.String()) assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() // Get callback URL after successful auth - location = req.Header.Get("location") + location = resp.Header.Get("location") callbackURL, err := url.Parse(location) assert.NoError(t, err) // Follow redirect to callback - req, err = client.Get(callbackURL.String()) + resp, err = client.Get(callbackURL.String()) assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies = client.Jar.Cookies(callbackURL) sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies) @@ -228,9 +160,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) { logoutURL, err := url.Parse(server.URL + "/oauth2/logout") assert.NoError(t, err) - req, err = client.Get(logoutURL.String()) + resp, err = client.Get(logoutURL.String()) assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() cookies = client.Jar.Cookies(logoutURL) sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies) @@ -238,7 +171,7 @@ func TestHandler_Callback_and_Logout(t *testing.T) { assert.Nil(t, sessionCookie) // Get endsession endpoint after local logout - location = req.Header.Get("location") + location = resp.Header.Get("location") endsessionURL, err := url.Parse(location) assert.NoError(t, err) @@ -249,27 +182,18 @@ func TestHandler_Callback_and_Logout(t *testing.T) { assert.Equal(t, idpserverURL.Host, endsessionURL.Host) assert.Equal(t, "/endsession", endsessionURL.Path) - assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{h.Config.IDPorten.PostLogoutRedirectURI}) + assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetPostLogoutRedirectURI()}) assert.NotEmpty(t, endsessionParams["id_token_hint"]) } func TestHandler_FrontChannelLogout(t *testing.T) { - cfg := defaultConfig() - - idp := mock.NewIDPorten(clients, cfg.IDPorten) - idprouter := mock.IDPortenRouter(idp) - idpserver := httptest.NewServer(idprouter) - - cfg.IDPorten.WellKnown.JwksURI = idpserver.URL + "/jwks" - cfg.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" - cfg.IDPorten.WellKnown.TokenEndpoint = idpserver.URL + "/token" - - h := handler(cfg) + _, idp := mock.IdentityProviderServer() + h := handler(idp) r := router.New(h) server := httptest.NewServer(r) - h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback" - h.Config.IDPorten.PostLogoutRedirectURI = server.URL + idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback" + idp.ClientConfiguration.PostLogoutRedirectURI = server.URL jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -281,28 +205,31 @@ func TestHandler_FrontChannelLogout(t *testing.T) { } // First, run /oauth2/login to set cookies - req, err := client.Get(server.URL + "/oauth2/login") + resp, err := client.Get(server.URL + "/oauth2/login") assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() // Get authorization URL - location := req.Header.Get("location") + location := resp.Header.Get("location") u, err := url.Parse(location) assert.NoError(t, err) // Follow redirect to authorize with idporten - req, err = client.Get(u.String()) + resp, err = client.Get(u.String()) assert.NoError(t, err) - defer req.Body.Close() + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() // Get callback URL after successful auth - location = req.Header.Get("location") + location = resp.Header.Get("location") callbackURL, err := url.Parse(location) assert.NoError(t, err) // Follow redirect to callback - req, err = client.Get(callbackURL.String()) + resp, err = client.Get(callbackURL.String()) assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := client.Jar.Cookies(callbackURL) sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies) @@ -323,14 +250,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) { values := url.Values{} values.Add("sid", string(sid)) - values.Add("iss", h.Config.IDPorten.WellKnown.Issuer) + values.Add("iss", idp.GetOpenIDConfiguration().Issuer) frontchannelLogoutURL.RawQuery = values.Encode() - req, err = client.Get(frontchannelLogoutURL.String()) + resp, err = client.Get(frontchannelLogoutURL.String()) assert.NoError(t, err) - defer req.Body.Close() - - assert.Equal(t, http.StatusOK, req.StatusCode) + assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() } func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index 755dc3e..beba328 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -2,18 +2,20 @@ package router_test import ( "encoding/base64" - "github.com/nais/wonderwall/pkg/router" - "github.com/nais/wonderwall/pkg/session" - "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "testing" "time" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/router" + "github.com/nais/wonderwall/pkg/session" ) func TestHandler_GetSessionFallback(t *testing.T) { - cfg := defaultConfig() - h := handler(cfg) + h := handler(mock.NewTestProvider()) t.Run("request without fallback session cookies", func(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -53,8 +55,7 @@ func TestHandler_GetSessionFallback(t *testing.T) { } func TestHandler_SetSessionFallback(t *testing.T) { - cfg := defaultConfig() - h := handler(cfg) + h := handler(mock.NewTestProvider()) // request should set session cookies in response writer := httptest.NewRecorder() @@ -87,8 +88,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { } func TestHandler_DeleteSessionFallback(t *testing.T) { - cfg := defaultConfig() - h := handler(cfg) + h := handler(mock.NewTestProvider()) writer := httptest.NewRecorder() h.DeleteSessionFallback(writer)