From 10dddd00bc969b3d3abdd1c6a964a9cfe10a3fdf Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 21 Jun 2022 10:48:54 +0200 Subject: [PATCH] refactor(router): begin extraction of openid client --- pkg/mock/client.go | 54 ++++++ pkg/mock/{handler.go => idp_handler.go} | 0 pkg/openid/client.go | 95 ++++++++++ pkg/openid/login.go | 185 +++++++++++++++++- pkg/openid/login_callback.go | 21 +++ pkg/openid/login_test.go | 205 ++++++++++++++++++++ pkg/openid/logout.go | 13 ++ pkg/openid/logout_callback.go | 9 + pkg/router/handler.go | 18 +- pkg/router/handler_callback.go | 4 +- pkg/router/handler_error.go | 30 ++- pkg/router/handler_error_test.go | 172 +++++++++++++++++ pkg/router/handler_login.go | 31 +-- pkg/router/handler_logout.go | 16 +- pkg/router/handler_logout_callback.go | 2 +- pkg/router/login_url.go | 88 --------- pkg/router/login_url_test.go | 117 ------------ pkg/router/middleware/logentry.go | 4 + pkg/router/request/parameters.go | 7 - pkg/router/request/request.go | 48 +---- pkg/router/request/request_test.go | 238 ------------------------ pkg/router/router_test.go | 39 +--- pkg/router/session_fallback_test.go | 6 +- 23 files changed, 816 insertions(+), 586 deletions(-) create mode 100644 pkg/mock/client.go rename pkg/mock/{handler.go => idp_handler.go} (100%) create mode 100644 pkg/openid/client.go create mode 100644 pkg/openid/login_callback.go create mode 100644 pkg/openid/login_test.go create mode 100644 pkg/openid/logout.go create mode 100644 pkg/openid/logout_callback.go create mode 100644 pkg/router/handler_error_test.go delete mode 100644 pkg/router/login_url.go delete mode 100644 pkg/router/login_url_test.go delete mode 100644 pkg/router/request/parameters.go diff --git a/pkg/mock/client.go b/pkg/mock/client.go new file mode 100644 index 0000000..1d43fda --- /dev/null +++ b/pkg/mock/client.go @@ -0,0 +1,54 @@ +package mock + +import ( + "time" + + "github.com/rs/zerolog" + + "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router" + "github.com/nais/wonderwall/pkg/session" +) + +func Config() *config.Config { + return &config.Config{ + EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES + Ingress: "/", + OpenID: config.OpenID{ + Provider: "test", + }, + SessionMaxLifetime: time.Hour, + } +} + +func NewClient(provider openid.Provider) openid.Client { + return openid.NewClient(*Config(), provider) +} + +func NewClientWithCfg(cfg *config.Config, provider openid.Provider) openid.Client { + return openid.NewClient(*cfg, provider) +} + +func NewHandler(provider openid.Provider) *router.Handler { + cfg := Config() + return NewHandlerWithCfg(cfg, provider) +} + +func NewHandlerWithCfg(cfg *config.Config, provider openid.Provider) *router.Handler { + if cfg == nil { + cfg = Config() + } + + crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey)) + sessionStore := session.NewMemory() + + h, err := router.NewHandler(*cfg, crypter, zerolog.Logger{}, provider, sessionStore) + if err != nil { + panic(err) + } + + h.CookieOptions = h.CookieOptions.WithSecure(false) + return h +} diff --git a/pkg/mock/handler.go b/pkg/mock/idp_handler.go similarity index 100% rename from pkg/mock/handler.go rename to pkg/mock/idp_handler.go diff --git a/pkg/openid/client.go b/pkg/openid/client.go new file mode 100644 index 0000000..5be6aca --- /dev/null +++ b/pkg/openid/client.go @@ -0,0 +1,95 @@ +package openid + +import ( + "context" + "fmt" + "net/http" + + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/config" +) + +type Client interface { + Config() config.Config + Provider() Provider + OAuth2Config() *oauth2.Config + + Login(r *http.Request) (Login, error) + LoginCallback(r *http.Request) error + Logout(r *http.Request) error + LogoutCallback(r *http.Request) error + + AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) + RefreshGrant(r *http.Request) error +} + +type client struct { + cfg config.Config + provider Provider + oauth2Config *oauth2.Config +} + +func NewClient(cfg config.Config, provider Provider) Client { + oauth2Config := &oauth2.Config{ + ClientID: provider.GetClientConfiguration().GetClientID(), + Endpoint: oauth2.Endpoint{ + AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint, + TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: provider.GetClientConfiguration().GetCallbackURI(), + Scopes: provider.GetClientConfiguration().GetScopes(), + } + + return &client{ + cfg: cfg, + provider: provider, + oauth2Config: oauth2Config, + } +} + +func (c client) Config() config.Config { + return c.cfg +} + +func (c client) Provider() Provider { + return c.provider +} + +func (c client) OAuth2Config() *oauth2.Config { + return c.oauth2Config +} + +func (c client) Login(r *http.Request) (Login, error) { + login, err := NewLogin(c, r) + if err != nil { + return nil, fmt.Errorf("login: %w", err) + } + + return login, nil +} + +func (c client) LoginCallback(r *http.Request) error { + //TODO implement me + panic("implement me") +} + +func (c client) Logout(r *http.Request) error { + //TODO implement me + panic("implement me") +} + +func (c client) LogoutCallback(r *http.Request) error { + //TODO implement me + panic("implement me") +} + +func (c client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) { + return c.oauth2Config.Exchange(ctx, code, opts...) +} + +func (c client) RefreshGrant(r *http.Request) error { + //TODO implement me + panic("implement me") +} diff --git a/pkg/openid/login.go b/pkg/openid/login.go index 7c7e07a..ddd53b6 100644 --- a/pkg/openid/login.go +++ b/pkg/openid/login.go @@ -3,19 +3,109 @@ package openid import ( "crypto/sha256" "encoding/base64" + "errors" "fmt" + "net/http" + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/router/request" "github.com/nais/wonderwall/pkg/strings" ) -type LoginParameters struct { +const ( + LocaleURLParameter = "locale" + SecurityLevelURLParameter = "level" +) + +var ( + InvalidSecurityLevelError = errors.New("InvalidSecurityLevel") + InvalidLocaleError = errors.New("InvalidLocale") + InvalidLoginParameterError = errors.New("InvalidLoginParameter") + + // LoginParameterMapping maps incoming login parameters to OpenID Connect parameters + LoginParameterMapping = map[string]string{ + LocaleURLParameter: "ui_locales", + SecurityLevelURLParameter: "acr_values", + } +) + +type Login interface { + AuthCodeURL() string + CanonicalRedirect() string + CodeChallenge() string + CodeVerifier() string + Cookie() *LoginCookie + Nonce() string + State() string +} + +func NewLogin(c Client, r *http.Request) (Login, error) { + params, err := newLoginParameters(c) + if err != nil { + return nil, fmt.Errorf("generating login parameters: %w", err) + } + + url, err := params.authCodeURL(r) + if err != nil { + return nil, fmt.Errorf("generating login url: %w", err) + } + + redirect := request.CanonicalRedirectURL(r, c.Config().Ingress) + cookie := params.cookie(redirect) + + return login{ + authCodeURL: url, + canonicalRedirect: redirect, + cookie: cookie, + params: params, + }, nil +} + +type login struct { + authCodeURL string + canonicalRedirect string + cookie *LoginCookie + params *loginParameters +} + +func (l login) CodeChallenge() string { + return l.params.CodeChallenge +} + +func (l login) CodeVerifier() string { + return l.params.CodeVerifier +} + +func (l login) Nonce() string { + return l.params.Nonce +} + +func (l login) State() string { + return l.params.State +} + +func (l login) AuthCodeURL() string { + return l.authCodeURL +} + +func (l login) CanonicalRedirect() string { + return l.canonicalRedirect +} + +func (l login) Cookie() *LoginCookie { + return l.cookie +} + +type loginParameters struct { + Client CodeVerifier string CodeChallenge string Nonce string State string } -func GenerateLoginParameters() (*LoginParameters, error) { +func newLoginParameters(c Client) (*loginParameters, error) { codeVerifier, err := strings.GenerateBase64(64) if err != nil { return nil, fmt.Errorf("creating code verifier: %w", err) @@ -31,15 +121,100 @@ func GenerateLoginParameters() (*LoginParameters, error) { return nil, fmt.Errorf("creating state: %w", err) } - return &LoginParameters{ + return &loginParameters{ + Client: c, CodeVerifier: codeVerifier, - CodeChallenge: CodeChallenge(codeVerifier), + CodeChallenge: codeChallenge(codeVerifier), Nonce: nonce, State: state, }, nil } -func CodeChallenge(codeVerifier string) string { +func (in *loginParameters) authCodeURL(r *http.Request) (string, error) { + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("scope", in.Provider().GetClientConfiguration().GetScopes().String()), + oauth2.SetAuthURLParam("nonce", in.Nonce), + oauth2.SetAuthURLParam("response_mode", "query"), + oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + } + + if in.Config().Loginstatus.NeedsResourceIndicator() { + opts = append(opts, oauth2.SetAuthURLParam("resource", in.Config().Loginstatus.ResourceIndicator)) + } + + opts, err := in.withSecurityLevel(r, opts) + if err != nil { + return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err) + } + + opts, err = in.withLocale(r, opts) + if err != nil { + return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err) + } + + authCodeUrl := in.OAuth2Config().AuthCodeURL(in.State, opts...) + return authCodeUrl, nil +} + +func (in *loginParameters) cookie(redirect string) *LoginCookie { + return &LoginCookie{ + State: in.State, + Nonce: in.Nonce, + CodeVerifier: in.CodeVerifier, + Referer: redirect, + } +} + +func (in *loginParameters) withLocale(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) { + return withParamMapping(r, + opts, + LocaleURLParameter, + in.Provider().GetClientConfiguration().GetUILocales(), + in.Provider().GetOpenIDConfiguration().UILocalesSupported, + ) +} + +func (in *loginParameters) withSecurityLevel(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) { + return withParamMapping(r, + opts, + SecurityLevelURLParameter, + in.Provider().GetClientConfiguration().GetACRValues(), + in.Provider().GetOpenIDConfiguration().ACRValuesSupported, + ) +} + +func withParamMapping(r *http.Request, opts []oauth2.AuthCodeOption, param, fallback string, supported Supported) ([]oauth2.AuthCodeOption, error) { + if len(fallback) == 0 { + return opts, nil + } + + value, err := LoginURLParameter(r, param, fallback, supported) + if err != nil { + return nil, err + } + + opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[param], value)) + return opts, nil +} + +// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found. +// The value must exist in the supplied list of supported values. +func LoginURLParameter(r *http.Request, parameter, fallback string, supported Supported) (string, error) { + value := r.URL.Query().Get(parameter) + + if len(value) == 0 { + value = fallback + } + + if supported.Contains(value) { + return value, nil + } + + return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value) +} + +func codeChallenge(codeVerifier string) string { hasher := sha256.New() hasher.Write([]byte(codeVerifier)) codeVerifierHash := hasher.Sum(nil) diff --git a/pkg/openid/login_callback.go b/pkg/openid/login_callback.go new file mode 100644 index 0000000..bbb6094 --- /dev/null +++ b/pkg/openid/login_callback.go @@ -0,0 +1,21 @@ +package openid + +type LoginCallback struct { + Client +} + +func (in LoginCallback) IdentityProviderError() (bool, error) { + panic("not implemented") +} + +func (in LoginCallback) ValidateRequest() error { + panic("not implemented") +} + +func (in LoginCallback) RedeemCode() error { + panic("not implemented") +} + +func (in LoginCallback) ParseAndValidateTokens() error { + panic("not implemented") +} diff --git a/pkg/openid/login_test.go b/pkg/openid/login_test.go new file mode 100644 index 0000000..4bdf0cc --- /dev/null +++ b/pkg/openid/login_test.go @@ -0,0 +1,205 @@ +package openid_test + +import ( + "errors" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" +) + +func TestLogin_URL(t *testing.T) { + type loginURLTest struct { + url string + extraParams map[string]string + error error + } + + tests := []loginURLTest{ + { + url: "http://localhost:1234/oauth2/login?level=Level4", + extraParams: map[string]string{ + "acr_values": "Level4", + }, + error: nil, + }, + { + url: "http://localhost:1234/oauth2/login", + error: nil, + }, + { + url: "http://localhost:1234/oauth2/login?level=NoLevel", + error: openid.InvalidSecurityLevelError, + }, + { + url: "http://localhost:1234/oauth2/login?locale=nb", + extraParams: map[string]string{ + "ui_locales": "nb", + }, + error: nil, + }, + { + url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb", + extraParams: map[string]string{ + "acr_values": "Level4", + "ui_locales": "nb", + }, + error: nil, + }, + { + url: "http://localhost:1234/oauth2/login?locale=es", + error: openid.InvalidLocaleError, + }, + } + + for _, test := range tests { + t.Run(test.url, func(t *testing.T) { + req, err := http.NewRequest("GET", test.url, nil) + assert.NoError(t, err) + + provider := mock.NewTestProvider() + provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize" + client := mock.NewClient(provider) + result, err := client.Login(req) + + if test.error != nil { + assert.True(t, errors.Is(err, test.error)) + } else { + assert.NoError(t, err) + + parsed, err := url.Parse(result.AuthCodeURL()) + assert.NoError(t, err) + + query := parsed.Query() + assert.Contains(t, query, "response_type") + assert.Contains(t, query, "client_id") + assert.Contains(t, query, "redirect_uri") + assert.Contains(t, query, "scope") + assert.Contains(t, query, "state") + assert.Contains(t, query, "nonce") + assert.Contains(t, query, "response_mode") + assert.Contains(t, query, "code_challenge") + assert.Contains(t, query, "code_challenge_method") + assert.NotContains(t, query, "resource") + + assert.ElementsMatch(t, query["response_type"], []string{"code"}) + assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID}) + assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI}) + assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()}) + assert.ElementsMatch(t, query["state"], []string{result.State()}) + assert.ElementsMatch(t, query["nonce"], []string{result.Nonce()}) + assert.ElementsMatch(t, query["response_mode"], []string{"query"}) + assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()}) + assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"}) + + if test.extraParams != nil { + for key, value := range test.extraParams { + assert.Contains(t, query, key) + assert.ElementsMatch(t, query[key], []string{value}) + } + } + } + }) + } +} + +func TestLoginURL_WithResourceIndicator(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil) + assert.NoError(t, err) + + provider := mock.NewTestProvider() + provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize" + cfg := mock.Config() + cfg.Loginstatus.Enabled = true + cfg.Loginstatus.ResourceIndicator = "https://some-resource" + client := mock.NewClientWithCfg(cfg, provider) + result, err := client.Login(req) + + assert.NotEmpty(t, result) + parsed, err := url.Parse(result.AuthCodeURL()) + assert.NoError(t, err) + + query := parsed.Query() + assert.Contains(t, query, "resource") + assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"}) +} + +func TestLoginURLParameter(t *testing.T) { + for _, test := range []struct { + name string + parameter string + fallback string + supported openid.Supported + url string + expectErr error + expected string + }{ + { + name: "no URL parameter should use fallback value", + url: "http://localhost:8080/oauth2/login", + expected: "valid", + }, + { + name: "non-matching URL parameter should be ignored", + url: "http://localhost:8080/oauth2/login?other_param=value2", + expected: "valid", + }, + { + name: "matching URL parameter should take precedence", + url: "http://localhost:8080/oauth2/login?param=valid2", + expected: "valid2", + }, + { + name: "invalid URL parameter value should return error", + url: "http://localhost:8080/oauth2/login?param=invalid", + expectErr: openid.InvalidLoginParameterError, + }, + { + name: "invalid fallback value should return error", + fallback: "invalid", + url: "http://localhost:8080/oauth2/login", + expectErr: openid.InvalidLoginParameterError, + }, + { + name: "no supported values should return error", + url: "http://localhost:8080/oauth2/login", + supported: openid.Supported{""}, + expectErr: openid.InvalidLoginParameterError, + }, + } { + t.Run(test.name, func(t *testing.T) { + r, err := http.NewRequest("GET", test.url, nil) + assert.NoError(t, err) + + // default test values + parameter := "param" + fallback := "valid" + supported := openid.Supported{"valid", "valid2"} + + if len(test.parameter) > 0 { + parameter = test.parameter + } + + if len(test.fallback) > 0 { + fallback = test.fallback + } + + if len(test.supported) > 0 { + supported = test.supported + } + + val, err := openid.LoginURLParameter(r, parameter, fallback, supported) + + if test.expectErr == nil { + assert.NoError(t, err) + assert.Equal(t, test.expected, val) + } else { + assert.Error(t, err) + } + }) + } +} diff --git a/pkg/openid/logout.go b/pkg/openid/logout.go new file mode 100644 index 0000000..ad6d52f --- /dev/null +++ b/pkg/openid/logout.go @@ -0,0 +1,13 @@ +package openid + +type Logout struct { + Client +} + +func (in Logout) URL() string { + panic("not implemented") +} + +func (in Logout) Cookie() LogoutCookie { + panic("not implemented") +} diff --git a/pkg/openid/logout_callback.go b/pkg/openid/logout_callback.go new file mode 100644 index 0000000..66499f8 --- /dev/null +++ b/pkg/openid/logout_callback.go @@ -0,0 +1,9 @@ +package openid + +type LogoutCallback struct { + Client +} + +func (in LogoutCallback) ValidateRequest() (bool, error) { + panic("not implemented") +} diff --git a/pkg/router/handler.go b/pkg/router/handler.go index 695c06c..d1d21c7 100644 --- a/pkg/router/handler.go +++ b/pkg/router/handler.go @@ -2,10 +2,8 @@ package router import ( "net/http" - "sync" "github.com/rs/zerolog" - "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" @@ -16,14 +14,13 @@ import ( ) type Handler struct { + Client openid.Client Config config.Config CookieOptions cookie.Options Crypter crypto.Crypter - OauthConfig oauth2.Config Loginstatus loginstatus.Client Provider openid.Provider Sessions session.Store - lock sync.Mutex Httplogger zerolog.Logger } @@ -34,28 +31,19 @@ func NewHandler( provider openid.Provider, sessionStore session.Store, ) (*Handler, error) { - oauthConfig := oauth2.Config{ - ClientID: provider.GetClientConfiguration().GetClientID(), - Endpoint: oauth2.Endpoint{ - AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint, - TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint, - }, - RedirectURL: provider.GetClientConfiguration().GetCallbackURI(), - Scopes: provider.GetClientConfiguration().GetScopes(), - } + client := openid.NewClient(cfg, provider) loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient) cookiePath := config.ParseIngress(cfg.Ingress) cookieOpts := cookie.DefaultOptions().WithPath(cookiePath) return &Handler{ + Client: client, Config: cfg, CookieOptions: cookieOpts, Crypter: crypter, Httplogger: httplogger, - lock: sync.Mutex{}, Loginstatus: loginstatusClient, - OauthConfig: oauthConfig, Provider: provider, Sessions: sessionStore, }, nil diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 05b37fc..cb782bf 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -112,7 +112,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid. oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), } - tokens, err = h.OauthConfig.Exchange(ctx, code, opts...) + tokens, err = h.Client.AuthCodeGrant(ctx, code, opts) if err != nil { log.Warnf("callback: exchanging authorization code for token; retrying: %+v", err) return retry.RetryableError(err) @@ -154,7 +154,7 @@ func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) { "claims": tokens.Claims(), } - logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("callback: successful login") } diff --git a/pkg/router/handler_error.go b/pkg/router/handler_error.go index ad4ce2c..8a122f2 100644 --- a/pkg/router/handler_error.go +++ b/pkg/router/handler_error.go @@ -2,16 +2,21 @@ package router import ( _ "embed" + "fmt" "html/template" "net/http" "net/url" "strconv" + "strings" "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/openid" logentry "github.com/nais/wonderwall/pkg/router/middleware" + "github.com/nais/wonderwall/pkg/router/paths" "github.com/nais/wonderwall/pkg/router/request" ) @@ -58,7 +63,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s errorPage := ErrorPage{ CorrelationID: middleware.GetReqID(r.Context()), - RetryURI: request.RetryURI(r, h.Config.Ingress, loginCookie), + RetryURI: RetryURI(r, h.Config.Ingress, loginCookie), } err = errorTemplate.Execute(w, errorPage) if err != nil { @@ -97,3 +102,26 @@ func (h *Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error func (h *Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) { h.respondError(w, r, http.StatusUnauthorized, cause, zerolog.WarnLevel) } + +// RetryURI returns a URI that should retry the desired route that failed. +// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes +// are related to the authentication flow, we default to redirecting back to the handled +// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow. +func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string { + retryURI := r.URL.Path + prefix := config.ParseIngress(ingress) + + if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) { + return prefix + retryURI + } + + redirect := request.CanonicalRedirectURL(r, ingress) + + if loginCookie != nil && len(loginCookie.Referer) > 0 { + redirect = loginCookie.Referer + } + + retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login) + retryURI = retryURI + fmt.Sprintf("?%s=%s", request.RedirectURLParameter, redirect) + return retryURI +} diff --git a/pkg/router/handler_error_test.go b/pkg/router/handler_error_test.go new file mode 100644 index 0000000..64f2eb3 --- /dev/null +++ b/pkg/router/handler_error_test.go @@ -0,0 +1,172 @@ +package router_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router" +) + +func TestRetryURI(t *testing.T) { + httpRequest := func(url string, referer ...string) *http.Request { + req, _ := http.NewRequest(http.MethodGet, url, nil) + if len(referer) > 0 { + req.Header.Add("Referer", referer[0]) + } + return req + } + + for _, test := range []struct { + name string + request *http.Request + ingress string + loginCookie *openid.LoginCookie + want string + }{ + { + name: "login path", + request: httpRequest("/oauth2/login"), + want: "/oauth2/login?redirect=/", + }, + { + name: "callback path", + request: httpRequest("/oauth2/callback"), + want: "/oauth2/login?redirect=/", + }, + { + name: "logout path", + request: httpRequest("/oauth2/logout"), + want: "/oauth2/logout", + }, + { + name: "front-channel logout path", + request: httpRequest("/oauth2/logout/frontchannel"), + want: "/oauth2/logout/frontchannel", + }, + { + name: "login with non-default ingress", + request: httpRequest("/oauth2/login"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/domene", + }, + { + name: "logout with non-default ingress", + request: httpRequest("/oauth2/logout"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/logout", + }, + { + name: "login with referer", + request: httpRequest("/oauth2/login", "/api/me"), + want: "/oauth2/login?redirect=/api/me", + }, + { + name: "login with referer on non-default ingress", + request: httpRequest("/oauth2/login", "/api/me"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/api/me", + }, + { + name: "login with root referer", + request: httpRequest("/oauth2/login", "/"), + want: "/oauth2/login?redirect=/", + }, + { + name: "login with root referer on non-default ingress", + request: httpRequest("/oauth2/login", "/"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/", + }, + { + name: "login with cookie referer", + request: httpRequest("/oauth2/login"), + loginCookie: &openid.LoginCookie{Referer: "/"}, + want: "/oauth2/login?redirect=/", + }, + { + name: "login with empty cookie referer", + request: httpRequest("/oauth2/login"), + loginCookie: &openid.LoginCookie{Referer: ""}, + want: "/oauth2/login?redirect=/", + }, + { + name: "login with cookie referer takes precedence over referer header", + request: httpRequest("/oauth2/login", "/api/me"), + loginCookie: &openid.LoginCookie{Referer: "/api/headers"}, + want: "/oauth2/login?redirect=/api/headers", + }, + { + name: "login with cookie referer on non-default ingress", + request: httpRequest("/oauth2/login"), + loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/domene/api/me", + }, + { + name: "login with redirect parameter set", + request: httpRequest("/oauth2/login?redirect=/api/me"), + want: "/oauth2/login?redirect=/api/me", + }, + { + name: "login with redirect parameter set and query parameters", + request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"), + want: "/oauth2/login?redirect=/api/me?a=b&c=d", + }, + { + name: "login with redirect parameter set on non-default ingress", + request: httpRequest("/oauth2/login?redirect=/api/me"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/api/me", + }, + { + name: "login with redirect parameter set takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=/other", "/api/me"), + want: "/oauth2/login?redirect=/other", + }, + { + name: "login with redirect parameter set to relative root takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=/", "/api/me"), + want: "/oauth2/login?redirect=/", + }, + { + name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=/", "/api/me"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/", + }, + { + name: "login with redirect parameter set to absolute url takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"), + want: "/oauth2/login?redirect=/", + }, + { + name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"), + want: "/oauth2/login?redirect=/", + }, + { + name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header", + request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"), + ingress: "https://test.nav.no/domene", + want: "/domene/oauth2/login?redirect=/", + }, + { + name: "login with cookie referer takes precedence over redirect parameter", + request: httpRequest("/oauth2/login?redirect=/other"), + loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, + want: "/oauth2/login?redirect=/domene/api/me", + }, + } { + t.Run(test.name, func(t *testing.T) { + if len(test.ingress) == 0 { + test.ingress = "/" + } + + retryURI := router.RetryURI(test.request, test.ingress, test.loginCookie) + assert.Equal(t, test.want, retryURI) + }) + } +} diff --git a/pkg/router/handler_login.go b/pkg/router/handler_login.go index e2e3296..34ccc45 100644 --- a/pkg/router/handler_login.go +++ b/pkg/router/handler_login.go @@ -10,7 +10,6 @@ import ( "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/openid" logentry "github.com/nais/wonderwall/pkg/router/middleware" - "github.com/nais/wonderwall/pkg/router/request" ) const ( @@ -18,44 +17,30 @@ const ( ) func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { - params, err := openid.GenerateLoginParameters() + login, err := h.Client.Login(r) if err != nil { - h.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err)) - return - } - - loginURL, err := h.LoginURL(r, params) - if err != nil { - cause := fmt.Errorf("login: creating login URL: %w", err) - - if errors.Is(err, InvalidSecurityLevelError) || errors.Is(err, InvalidLocaleError) { - h.BadRequest(w, r, cause) + if errors.Is(err, openid.InvalidSecurityLevelError) || errors.Is(err, openid.InvalidLocaleError) { + h.BadRequest(w, r, err) } else { - h.InternalError(w, r, cause) + h.InternalError(w, r, err) } return } - redirect := request.CanonicalRedirectURL(r, h.Config.Ingress) - err = h.setLoginCookies(w, &openid.LoginCookie{ - State: params.State, - Nonce: params.Nonce, - CodeVerifier: params.CodeVerifier, - Referer: redirect, - }) + err = h.setLoginCookies(w, login.Cookie()) if err != nil { h.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err)) return } fields := map[string]interface{}{ - "redirect_to": redirect, + "redirect_after_login": login.CanonicalRedirect(), } - logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("login: redirecting to identity provider") - http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect) + http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect) } func (h *Handler) getLoginCookie(r *http.Request) (*openid.LoginCookie, error) { diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 56d84bd..8c84cf0 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -17,12 +17,6 @@ import ( // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint) - if err != nil { - h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err)) - return - } - var idToken string sessionData, err := h.getSessionFromCookie(w, r) @@ -37,7 +31,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { fields := map[string]interface{}{ "claims": sessionData.Claims, } - logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("logout: successful local logout") } @@ -47,6 +41,12 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { h.Loginstatus.ClearCookie(w, h.CookieOptions) } + u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint) + if err != nil { + h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err)) + return + } + logoutCookie, err := h.logoutCookie() if err != nil { h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err)) @@ -72,7 +72,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { fields := map[string]interface{}{ "redirect_to": logoutCookie.RedirectTo, } - logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger := logentry.LogEntryWithFields(r.Context(), fields) logger.Info().Msg("logout: redirecting to identity provider") http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go index b583dd2..bd3bb40 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/router/handler_logout_callback.go @@ -33,7 +33,7 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { actualState := params.Get("state") if expectedState != actualState { - logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState) + logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState) http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) return } diff --git a/pkg/router/login_url.go b/pkg/router/login_url.go deleted file mode 100644 index 1990b0e..0000000 --- a/pkg/router/login_url.go +++ /dev/null @@ -1,88 +0,0 @@ -package router - -import ( - "errors" - "fmt" - "net/http" - "net/url" - - "github.com/nais/wonderwall/pkg/openid" - request2 "github.com/nais/wonderwall/pkg/router/request" -) - -var ( - InvalidSecurityLevelError = errors.New("InvalidSecurityLevel") - InvalidLocaleError = errors.New("InvalidLocale") -) - -func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (string, error) { - u, err := url.Parse(h.Provider.GetOpenIDConfiguration().AuthorizationEndpoint) - if err != nil { - return "", err - } - - v := u.Query() - v.Add("response_type", "code") - v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID()) - v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI()) - v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String()) - v.Add("state", params.State) - v.Add("nonce", params.Nonce) - v.Add("response_mode", "query") - v.Add("code_challenge", params.CodeChallenge) - v.Add("code_challenge_method", "S256") - - if h.Config.Loginstatus.NeedsResourceIndicator() { - v.Add("resource", h.Config.Loginstatus.ResourceIndicator) - } - - err = h.withSecurityLevel(r, v) - if err != nil { - return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err) - } - - err = h.withLocale(r, v) - if err != nil { - return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err) - } - - u.RawQuery = v.Encode() - - return u.String(), nil -} - -func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error { - acrValues := h.Provider.GetClientConfiguration().GetACRValues() - if len(acrValues) == 0 { - return nil - } - - fallback := acrValues - supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported - - securityLevel, err := request2.LoginURLParameter(r, request2.SecurityLevelURLParameter, fallback, supported) - if err != nil { - return err - } - - v.Add("acr_values", securityLevel) - return nil -} - -func (h *Handler) withLocale(r *http.Request, v url.Values) error { - uiLocales := h.Provider.GetClientConfiguration().GetUILocales() - if len(uiLocales) == 0 { - return nil - } - - fallback := uiLocales - supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported - - locale, err := request2.LoginURLParameter(r, request2.LocaleURLParameter, fallback, supported) - if err != nil { - return err - } - - v.Add("ui_locales", locale) - return nil -} diff --git a/pkg/router/login_url_test.go b/pkg/router/login_url_test.go deleted file mode 100644 index 01b4c61..0000000 --- a/pkg/router/login_url_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package router_test - -import ( - "errors" - "net/http" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/router" -) - -func TestLoginURL(t *testing.T) { - type loginURLTest struct { - url string - error error - } - - tests := []loginURLTest{ - { - url: "http://localhost:1234/oauth2/login?level=Level4", - error: nil, - }, - { - url: "http://localhost:1234/oauth2/login", - error: nil, - }, - { - url: "http://localhost:1234/oauth2/login?level=NoLevel", - error: router.InvalidSecurityLevelError, - }, - { - url: "http://localhost:1234/oauth2/login?locale=nb", - error: nil, - }, - { - url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb", - error: nil, - }, - { - url: "http://localhost:1234/oauth2/login?locale=es", - error: router.InvalidLocaleError, - }, - } - - for _, test := range tests { - t.Run(test.url, func(t *testing.T) { - req, err := http.NewRequest("GET", test.url, nil) - assert.NoError(t, err) - - params, err := openid.GenerateLoginParameters() - assert.NoError(t, err) - - provider := mock.NewTestProvider() - provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize" - handler := newHandler(provider) - result, err := handler.LoginURL(req, params) - - if test.error != nil { - assert.True(t, errors.Is(err, test.error)) - } else { - assert.NoError(t, err) - - parsed, err := url.Parse(result) - assert.NoError(t, err) - - query := parsed.Query() - assert.Contains(t, query, "response_type") - assert.Contains(t, query, "client_id") - assert.Contains(t, query, "redirect_uri") - assert.Contains(t, query, "scope") - assert.Contains(t, query, "state") - assert.Contains(t, query, "nonce") - assert.Contains(t, query, "response_mode") - assert.Contains(t, query, "code_challenge") - assert.Contains(t, query, "code_challenge_method") - assert.NotContains(t, query, "resource") - - assert.ElementsMatch(t, query["response_type"], []string{"code"}) - assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID}) - assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI}) - assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()}) - assert.ElementsMatch(t, query["state"], []string{params.State}) - assert.ElementsMatch(t, query["nonce"], []string{params.Nonce}) - assert.ElementsMatch(t, query["response_mode"], []string{"query"}) - assert.ElementsMatch(t, query["code_challenge"], []string{params.CodeChallenge}) - assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"}) - } - }) - } -} - -func TestLoginURL_WithResourceIndicator(t *testing.T) { - req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil) - assert.NoError(t, err) - - params, err := openid.GenerateLoginParameters() - assert.NoError(t, err) - - provider := mock.NewTestProvider() - provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize" - handler := newHandler(provider) - handler.Config.Loginstatus.Enabled = true - handler.Config.Loginstatus.ResourceIndicator = "https://some-resource" - result, err := handler.LoginURL(req, params) - - assert.NotEmpty(t, result) - parsed, err := url.Parse(result) - assert.NoError(t, err) - - query := parsed.Query() - assert.Contains(t, query, "resource") - assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"}) -} diff --git a/pkg/router/middleware/logentry.go b/pkg/router/middleware/logentry.go index 686dd02..cd8ced3 100644 --- a/pkg/router/middleware/logentry.go +++ b/pkg/router/middleware/logentry.go @@ -52,6 +52,10 @@ func LogEntry(ctx context.Context) zerolog.Logger { return httplog.NewLogger("wonderwall") } +func LogEntryWithFields(ctx context.Context, fields any) zerolog.Logger { + return LogEntry(ctx).With().Fields(fields).Logger() +} + type requestLogger struct { Logger zerolog.Logger } diff --git a/pkg/router/request/parameters.go b/pkg/router/request/parameters.go deleted file mode 100644 index d6532f4..0000000 --- a/pkg/router/request/parameters.go +++ /dev/null @@ -1,7 +0,0 @@ -package request - -const ( - LocaleURLParameter = "locale" - RedirectURLParameter = "redirect" - SecurityLevelURLParameter = "level" -) diff --git a/pkg/router/request/request.go b/pkg/router/request/request.go index aa6f5f4..d0eb31b 100644 --- a/pkg/router/request/request.go +++ b/pkg/router/request/request.go @@ -1,19 +1,14 @@ package request import ( - "errors" - "fmt" "net/http" "net/url" - "strings" "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/router/paths" ) -var ( - InvalidLoginParameterError = errors.New("InvalidLoginParameter") +const ( + RedirectURLParameter = "redirect" ) // CanonicalRedirectURL constructs a redirect URL that points back to the application. @@ -78,22 +73,6 @@ func parseRedirectParam(r *http.Request) (string, bool) { return redirectParamURLString, true } -// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found. -// The value must exist in the supplied list of supported values. -func LoginURLParameter(r *http.Request, parameter, fallback string, supported openid.Supported) (string, error) { - value := r.URL.Query().Get(parameter) - - if len(value) == 0 { - value = fallback - } - - if supported.Contains(value) { - return value, nil - } - - return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value) -} - func refererPath(r *http.Request) string { if len(r.Referer()) == 0 { return "" @@ -109,26 +88,3 @@ func refererPath(r *http.Request) string { referer.Host = "" return referer.String() } - -// RetryURI returns a URI that should retry the desired route that failed. -// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes -// are related to the authentication flow, we default to redirecting back to the handled -// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow. -func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string { - retryURI := r.URL.Path - prefix := config.ParseIngress(ingress) - - if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) { - return prefix + retryURI - } - - redirect := CanonicalRedirectURL(r, ingress) - - if loginCookie != nil && len(loginCookie.Referer) > 0 { - redirect = loginCookie.Referer - } - - retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login) - retryURI = retryURI + fmt.Sprintf("?%s=%s", RedirectURLParameter, redirect) - return retryURI -} diff --git a/pkg/router/request/request_test.go b/pkg/router/request/request_test.go index d759bc0..5596de2 100644 --- a/pkg/router/request/request_test.go +++ b/pkg/router/request/request_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router/request" ) @@ -138,240 +137,3 @@ func TestCanonicalRedirectURL(t *testing.T) { } }) } - -func TestLoginURLParameter(t *testing.T) { - for _, test := range []struct { - name string - parameter string - fallback string - supported openid.Supported - url string - expectErr error - expected string - }{ - { - name: "no URL parameter should use fallback value", - url: "http://localhost:8080/oauth2/login", - expected: "valid", - }, - { - name: "non-matching URL parameter should be ignored", - url: "http://localhost:8080/oauth2/login?other_param=value2", - expected: "valid", - }, - { - name: "matching URL parameter should take precedence", - url: "http://localhost:8080/oauth2/login?param=valid2", - expected: "valid2", - }, - { - name: "invalid URL parameter value should return error", - url: "http://localhost:8080/oauth2/login?param=invalid", - expectErr: request.InvalidLoginParameterError, - }, - { - name: "invalid fallback value should return error", - fallback: "invalid", - url: "http://localhost:8080/oauth2/login", - expectErr: request.InvalidLoginParameterError, - }, - { - name: "no supported values should return error", - url: "http://localhost:8080/oauth2/login", - supported: openid.Supported{""}, - expectErr: request.InvalidLoginParameterError, - }, - } { - t.Run(test.name, func(t *testing.T) { - r, err := http.NewRequest("GET", test.url, nil) - assert.NoError(t, err) - - // default test values - parameter := "param" - fallback := "valid" - supported := openid.Supported{"valid", "valid2"} - - if len(test.parameter) > 0 { - parameter = test.parameter - } - - if len(test.fallback) > 0 { - fallback = test.fallback - } - - if len(test.supported) > 0 { - supported = test.supported - } - - val, err := request.LoginURLParameter(r, parameter, fallback, supported) - - if test.expectErr == nil { - assert.NoError(t, err) - assert.Equal(t, test.expected, val) - } else { - assert.Error(t, err) - } - }) - } -} - -func TestRetryURI(t *testing.T) { - httpRequest := func(url string, referer ...string) *http.Request { - req, _ := http.NewRequest(http.MethodGet, url, nil) - if len(referer) > 0 { - req.Header.Add("Referer", referer[0]) - } - return req - } - - for _, test := range []struct { - name string - request *http.Request - ingress string - loginCookie *openid.LoginCookie - want string - }{ - { - name: "login path", - request: httpRequest("/oauth2/login"), - want: "/oauth2/login?redirect=/", - }, - { - name: "callback path", - request: httpRequest("/oauth2/callback"), - want: "/oauth2/login?redirect=/", - }, - { - name: "logout path", - request: httpRequest("/oauth2/logout"), - want: "/oauth2/logout", - }, - { - name: "front-channel logout path", - request: httpRequest("/oauth2/logout/frontchannel"), - want: "/oauth2/logout/frontchannel", - }, - { - name: "login with non-default ingress", - request: httpRequest("/oauth2/login"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/domene", - }, - { - name: "logout with non-default ingress", - request: httpRequest("/oauth2/logout"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/logout", - }, - { - name: "login with referer", - request: httpRequest("/oauth2/login", "/api/me"), - want: "/oauth2/login?redirect=/api/me", - }, - { - name: "login with referer on non-default ingress", - request: httpRequest("/oauth2/login", "/api/me"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/api/me", - }, - { - name: "login with root referer", - request: httpRequest("/oauth2/login", "/"), - want: "/oauth2/login?redirect=/", - }, - { - name: "login with root referer on non-default ingress", - request: httpRequest("/oauth2/login", "/"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/", - }, - { - name: "login with cookie referer", - request: httpRequest("/oauth2/login"), - loginCookie: &openid.LoginCookie{Referer: "/"}, - want: "/oauth2/login?redirect=/", - }, - { - name: "login with empty cookie referer", - request: httpRequest("/oauth2/login"), - loginCookie: &openid.LoginCookie{Referer: ""}, - want: "/oauth2/login?redirect=/", - }, - { - name: "login with cookie referer takes precedence over referer header", - request: httpRequest("/oauth2/login", "/api/me"), - loginCookie: &openid.LoginCookie{Referer: "/api/headers"}, - want: "/oauth2/login?redirect=/api/headers", - }, - { - name: "login with cookie referer on non-default ingress", - request: httpRequest("/oauth2/login"), - loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/domene/api/me", - }, - { - name: "login with redirect parameter set", - request: httpRequest("/oauth2/login?redirect=/api/me"), - want: "/oauth2/login?redirect=/api/me", - }, - { - name: "login with redirect parameter set and query parameters", - request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"), - want: "/oauth2/login?redirect=/api/me?a=b&c=d", - }, - { - name: "login with redirect parameter set on non-default ingress", - request: httpRequest("/oauth2/login?redirect=/api/me"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/api/me", - }, - { - name: "login with redirect parameter set takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=/other", "/api/me"), - want: "/oauth2/login?redirect=/other", - }, - { - name: "login with redirect parameter set to relative root takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=/", "/api/me"), - want: "/oauth2/login?redirect=/", - }, - { - name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=/", "/api/me"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/", - }, - { - name: "login with redirect parameter set to absolute url takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"), - want: "/oauth2/login?redirect=/", - }, - { - name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"), - want: "/oauth2/login?redirect=/", - }, - { - name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header", - request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"), - ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=/", - }, - { - name: "login with cookie referer takes precedence over redirect parameter", - request: httpRequest("/oauth2/login?redirect=/other"), - loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, - want: "/oauth2/login?redirect=/domene/api/me", - }, - } { - t.Run(test.name, func(t *testing.T) { - if len(test.ingress) == 0 { - test.ingress = "/" - } - - retryURI := request.RetryURI(test.request, test.ingress, test.loginCookie) - assert.Equal(t, test.want, retryURI) - }) - } -} diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 93f11d3..4822f1c 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -8,45 +8,17 @@ import ( "net/http/httptest" "net/url" "testing" - "time" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router" - "github.com/nais/wonderwall/pkg/session" ) -var cfg = config.Config{ - EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES - Ingress: "/", - OpenID: config.OpenID{ - Provider: "test", - }, - SessionMaxLifetime: time.Hour, -} - -func newHandler(provider openid.Provider) *router.Handler { - crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey)) - sessionStore := session.NewMemory() - - h, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, provider, sessionStore) - if err != nil { - panic(err) - } - - h.CookieOptions = h.CookieOptions.WithSecure(false) - return h -} - func TestHandler_Login(t *testing.T) { idpserver, idp := mock.IdentityProviderServer() - h := newHandler(idp) + h := mock.NewHandler(idp) r := router.New(h) jar, err := cookiejar.New(nil) @@ -103,13 +75,14 @@ func TestHandler_Login(t *testing.T) { func TestHandler_Callback_and_Logout(t *testing.T) { idpserver, idp := mock.IdentityProviderServer() - h := newHandler(idp) + h := mock.NewHandler(idp) r := router.New(h) server := httptest.NewServer(r) idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback" + h.Client = mock.NewClient(idp) jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -239,12 +212,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) { _, idp := mock.IdentityProviderServer() idp.WithFrontChannelLogoutSupport() - h := newHandler(idp) + h := mock.NewHandler(idp) r := router.New(h) server := httptest.NewServer(r) idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL + h.Client = mock.NewClient(idp) jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -313,12 +287,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) { func TestHandler_SessionStateRequired(t *testing.T) { idpServer, idp := mock.IdentityProviderServer() idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession") - h := newHandler(idp) + h := mock.NewHandler(idp) r := router.New(h) server := httptest.NewServer(r) idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL + h.Client = mock.NewClient(idp) jar, err := cookiejar.New(nil) assert.NoError(t, err) diff --git a/pkg/router/session_fallback_test.go b/pkg/router/session_fallback_test.go index 648ccc0..54dc9a3 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/router/session_fallback_test.go @@ -21,7 +21,7 @@ import ( func TestHandler_GetSessionFallback(t *testing.T) { p := mock.NewTestProvider() - h := newHandler(p) + h := mock.NewHandler(p) tokens := makeTokens(p) t.Run("request without fallback session cookies", func(t *testing.T) { @@ -47,7 +47,7 @@ func TestHandler_GetSessionFallback(t *testing.T) { func TestHandler_SetSessionFallback(t *testing.T) { provider := mock.NewTestProvider() - h := newHandler(provider) + h := mock.NewHandler(provider) // request should set session cookies in response writer := httptest.NewRecorder() @@ -82,7 +82,7 @@ func TestHandler_SetSessionFallback(t *testing.T) { func TestHandler_DeleteSessionFallback(t *testing.T) { p := mock.NewTestProvider() - h := newHandler(p) + h := mock.NewHandler(p) tokens := makeTokens(p) t.Run("expire cookies if they are set", func(t *testing.T) {