From e3b9d33296c7035bf20058d5987f2096150b91fa Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 15 Jul 2022 07:44:54 +0200 Subject: [PATCH] refactor: split out packages from router --- cmd/wonderwall/main.go | 3 +- pkg/{router => handler}/handler.go | 2 +- pkg/{router => handler}/handler_callback.go | 4 +- pkg/{router => handler}/handler_default.go | 2 +- pkg/{router => handler}/handler_error.go | 36 +---- pkg/handler/handler_error_test.go | 1 + .../handler_frontchannellogout.go | 2 +- pkg/{router => handler}/handler_login.go | 4 +- pkg/{router => handler}/handler_logout.go | 4 +- .../handler_logout_callback.go | 4 +- .../handler_test.go} | 2 +- pkg/{router => handler}/session.go | 2 +- pkg/{router => handler}/session_fallback.go | 2 +- .../session_fallback_test.go | 8 +- .../templates/error.gohtml | 0 pkg/{router => }/middleware/correlationid.go | 0 pkg/{router => }/middleware/logentry.go | 0 pkg/{router => }/middleware/prometheus.go | 0 pkg/mock/openid.go | 11 +- pkg/openid/client/login.go | 4 +- pkg/router/request/request_test.go | 139 ------------------ pkg/router/router.go | 5 +- pkg/{router/request/request.go => url/url.go} | 33 ++++- .../handler_error_test.go => url/url_test.go} | 137 ++++++++++++++++- 24 files changed, 199 insertions(+), 206 deletions(-) rename pkg/{router => handler}/handler.go (98%) rename pkg/{router => handler}/handler_callback.go (97%) rename pkg/{router => handler}/handler_default.go (99%) rename pkg/{router => handler}/handler_error.go (65%) create mode 100644 pkg/handler/handler_error_test.go rename pkg/{router => handler}/handler_frontchannellogout.go (98%) rename pkg/{router => handler}/handler_login.go (97%) rename pkg/{router => handler}/handler_logout.go (94%) rename pkg/{router => handler}/handler_logout_callback.go (84%) rename pkg/{router/router_test.go => handler/handler_test.go} (99%) rename pkg/{router => handler}/session.go (99%) rename pkg/{router => handler}/session_fallback.go (97%) rename pkg/{router => handler}/session_fallback_test.go (94%) rename pkg/{router => handler}/templates/error.gohtml (100%) rename pkg/{router => }/middleware/correlationid.go (100%) rename pkg/{router => }/middleware/logentry.go (100%) rename pkg/{router => }/middleware/prometheus.go (100%) delete mode 100644 pkg/router/request/request_test.go rename pkg/{router/request/request.go => url/url.go} (58%) rename pkg/{router/handler_error_test.go => url/url_test.go} (62%) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 4c7e49d..a6e525e 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -9,6 +9,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/handler" "github.com/nais/wonderwall/pkg/logging" "github.com/nais/wonderwall/pkg/metrics" openidconfig "github.com/nais/wonderwall/pkg/openid/config" @@ -58,7 +59,7 @@ func run() error { crypt := crypto.NewCrypter(key) sessionStore := session.NewStore(cfg) httplogger := logging.NewHttpLogger(cfg) - h, err := router.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore) + h, err := handler.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } diff --git a/pkg/router/handler.go b/pkg/handler/handler.go similarity index 98% rename from pkg/router/handler.go rename to pkg/handler/handler.go index 6ea16ac..eb90ef0 100644 --- a/pkg/router/handler.go +++ b/pkg/handler/handler.go @@ -1,4 +1,4 @@ -package router +package handler import ( "context" diff --git a/pkg/router/handler_callback.go b/pkg/handler/handler_callback.go similarity index 97% rename from pkg/router/handler_callback.go rename to pkg/handler/handler_callback.go index 3d4fa45..1a1f1d7 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/handler/handler_callback.go @@ -1,4 +1,4 @@ -package router +package handler import ( "context" @@ -11,9 +11,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/nais/wonderwall/pkg/loginstatus" + logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" - logentry "github.com/nais/wonderwall/pkg/router/middleware" ) const ( diff --git a/pkg/router/handler_default.go b/pkg/handler/handler_default.go similarity index 99% rename from pkg/router/handler_default.go rename to pkg/handler/handler_default.go index bf47a7b..6111780 100644 --- a/pkg/router/handler_default.go +++ b/pkg/handler/handler_default.go @@ -1,4 +1,4 @@ -package router +package handler import ( "net/http" diff --git a/pkg/router/handler_error.go b/pkg/handler/handler_error.go similarity index 65% rename from pkg/router/handler_error.go rename to pkg/handler/handler_error.go index 1425b51..88f4b4f 100644 --- a/pkg/router/handler_error.go +++ b/pkg/handler/handler_error.go @@ -1,23 +1,18 @@ -package router +package handler import ( _ "embed" - "fmt" "html/template" "net/http" "net/url" "strconv" - "strings" "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" log "github.com/sirupsen/logrus" - "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/openid" - logentry "github.com/nais/wonderwall/pkg/router/middleware" - "github.com/nais/wonderwall/pkg/router/paths" - "github.com/nais/wonderwall/pkg/router/request" + logentry "github.com/nais/wonderwall/pkg/middleware" + urlpkg "github.com/nais/wonderwall/pkg/url" ) type ErrorPage struct { @@ -63,7 +58,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s errorPage := ErrorPage{ CorrelationID: middleware.GetReqID(r.Context()), - RetryURI: RetryURI(r, h.Cfg.Wonderwall().Ingress, loginCookie), + RetryURI: urlpkg.Retry(r, h.Cfg.Wonderwall().Ingress, loginCookie), } err = errorTemplate.Execute(w, errorPage) if err != nil { @@ -102,26 +97,3 @@ func (h *Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error func (h *Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) { h.respondError(w, r, http.StatusUnauthorized, cause, zerolog.WarnLevel) } - -// RetryURI returns a URI that should retry the desired route that failed. -// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes -// are related to the authentication flow, we default to redirecting back to the handled -// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow. -func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string { - retryURI := r.URL.Path - prefix := config.ParseIngress(ingress) - - if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) { - return prefix + retryURI - } - - redirect := request.CanonicalRedirectURL(r, ingress) - - if loginCookie != nil && len(loginCookie.Referer) > 0 { - redirect = loginCookie.Referer - } - - retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login) - retryURI = retryURI + fmt.Sprintf("?%s=%s", request.RedirectURLParameter, redirect) - return retryURI -} diff --git a/pkg/handler/handler_error_test.go b/pkg/handler/handler_error_test.go new file mode 100644 index 0000000..13e2707 --- /dev/null +++ b/pkg/handler/handler_error_test.go @@ -0,0 +1 @@ +package handler_test diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/handler/handler_frontchannellogout.go similarity index 98% rename from pkg/router/handler_frontchannellogout.go rename to pkg/handler/handler_frontchannellogout.go index 9f64e14..0a55bab 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/handler/handler_frontchannellogout.go @@ -1,4 +1,4 @@ -package router +package handler import ( "net/http" diff --git a/pkg/router/handler_login.go b/pkg/handler/handler_login.go similarity index 97% rename from pkg/router/handler_login.go rename to pkg/handler/handler_login.go index 04b4755..28595c1 100644 --- a/pkg/router/handler_login.go +++ b/pkg/handler/handler_login.go @@ -1,4 +1,4 @@ -package router +package handler import ( "encoding/json" @@ -8,9 +8,9 @@ import ( "time" "github.com/nais/wonderwall/pkg/cookie" + logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" - logentry "github.com/nais/wonderwall/pkg/router/middleware" ) const ( diff --git a/pkg/router/handler_logout.go b/pkg/handler/handler_logout.go similarity index 94% rename from pkg/router/handler_logout.go rename to pkg/handler/handler_logout.go index 0d3fc26..80a3f89 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/handler/handler_logout.go @@ -1,4 +1,4 @@ -package router +package handler import ( "errors" @@ -8,7 +8,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/nais/wonderwall/pkg/cookie" - logentry "github.com/nais/wonderwall/pkg/router/middleware" + logentry "github.com/nais/wonderwall/pkg/middleware" ) // Logout triggers self-initiated for the current user diff --git a/pkg/router/handler_logout_callback.go b/pkg/handler/handler_logout_callback.go similarity index 84% rename from pkg/router/handler_logout_callback.go rename to pkg/handler/handler_logout_callback.go index 1aa7d02..d98b923 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/handler/handler_logout_callback.go @@ -1,9 +1,9 @@ -package router +package handler import ( "net/http" - logentry "github.com/nais/wonderwall/pkg/router/middleware" + logentry "github.com/nais/wonderwall/pkg/middleware" ) // LogoutCallback handles the callback from the self-initiated logout for the current user diff --git a/pkg/router/router_test.go b/pkg/handler/handler_test.go similarity index 99% rename from pkg/router/router_test.go rename to pkg/handler/handler_test.go index 437fd33..e88ab1a 100644 --- a/pkg/router/router_test.go +++ b/pkg/handler/handler_test.go @@ -1,4 +1,4 @@ -package router_test +package handler_test import ( "encoding/base64" diff --git a/pkg/router/session.go b/pkg/handler/session.go similarity index 99% rename from pkg/router/session.go rename to pkg/handler/session.go index 2c67680..e90240f 100644 --- a/pkg/router/session.go +++ b/pkg/handler/session.go @@ -1,4 +1,4 @@ -package router +package handler import ( "context" diff --git a/pkg/router/session_fallback.go b/pkg/handler/session_fallback.go similarity index 97% rename from pkg/router/session_fallback.go rename to pkg/handler/session_fallback.go index f352e8c..729d6b2 100644 --- a/pkg/router/session_fallback.go +++ b/pkg/handler/session_fallback.go @@ -1,4 +1,4 @@ -package router +package handler import ( "net/http" diff --git a/pkg/router/session_fallback_test.go b/pkg/handler/session_fallback_test.go similarity index 94% rename from pkg/router/session_fallback_test.go rename to pkg/handler/session_fallback_test.go index 7be1dd7..3ccadc2 100644 --- a/pkg/router/session_fallback_test.go +++ b/pkg/handler/session_fallback_test.go @@ -1,4 +1,4 @@ -package router_test +package handler_test import ( "context" @@ -13,9 +13,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/nais/wonderwall/pkg/handler" "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" ) @@ -117,7 +117,7 @@ func TestHandler_DeleteSessionFallback(t *testing.T) { }) } -func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request { +func makeRequestWithFallbackCookies(t *testing.T, h *handler.Handler, tokens *openid.Tokens) *http.Request { writer := httptest.NewRecorder() expiresIn := time.Minute data := session.NewData("sid", tokens, nil) @@ -150,7 +150,7 @@ func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie assert.Empty(t, expired.Value) } -func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedValue string, cookies []*http.Cookie) { +func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) { desiredCookie := getCookieFromJar(cookieName, cookies) assert.NotNil(t, desiredCookie) diff --git a/pkg/router/templates/error.gohtml b/pkg/handler/templates/error.gohtml similarity index 100% rename from pkg/router/templates/error.gohtml rename to pkg/handler/templates/error.gohtml diff --git a/pkg/router/middleware/correlationid.go b/pkg/middleware/correlationid.go similarity index 100% rename from pkg/router/middleware/correlationid.go rename to pkg/middleware/correlationid.go diff --git a/pkg/router/middleware/logentry.go b/pkg/middleware/logentry.go similarity index 100% rename from pkg/router/middleware/logentry.go rename to pkg/middleware/logentry.go diff --git a/pkg/router/middleware/prometheus.go b/pkg/middleware/prometheus.go similarity index 100% rename from pkg/router/middleware/prometheus.go rename to pkg/middleware/prometheus.go diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index a963b22..e058981 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -20,7 +20,8 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/openid/client" + handlerpkg "github.com/nais/wonderwall/pkg/handler" + openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" scopespkg "github.com/nais/wonderwall/pkg/openid/scopes" "github.com/nais/wonderwall/pkg/router" @@ -34,7 +35,7 @@ type IdentityProvider struct { Provider TestProvider ProviderHandler *IdentityProviderHandler ProviderServer *httptest.Server - RelyingPartyHandler *router.Handler + RelyingPartyHandler *handlerpkg.Handler RelyingPartyServer *httptest.Server } @@ -76,7 +77,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider { sessionStore := session.NewMemory() ctx, cancel := context.WithCancel(context.Background()) - rpHandler, err := router.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore) + rpHandler, err := handlerpkg.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore) if err != nil { panic(err) } @@ -88,7 +89,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider { openidConfig.ClientConfig.CallbackURI = rpServer.URL + "/oauth2/callback" openidConfig.ClientConfig.PostLogoutRedirectURI = rpServer.URL openidConfig.ClientConfig.LogoutCallbackURI = rpServer.URL + "/oauth2/logout/callback" - rpHandler.Client = client.NewClient(openidConfig) + rpHandler.Client = openidclient.NewClient(openidConfig) return IdentityProvider{ cancelFunc: cancel, @@ -357,7 +358,7 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - expectedCodeChallenge := client.CodeChallenge(codeVerifier) + expectedCodeChallenge := openidclient.CodeChallenge(codeVerifier) if expectedCodeChallenge != auth.CodeChallenge { w.WriteHeader(http.StatusBadRequest) diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index d18747b..00d425f 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -11,8 +11,8 @@ import ( "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" - "github.com/nais/wonderwall/pkg/router/request" "github.com/nais/wonderwall/pkg/strings" + urlpkg "github.com/nais/wonderwall/pkg/url" ) const ( @@ -53,7 +53,7 @@ func NewLogin(c Client, r *http.Request) (Login, error) { return nil, fmt.Errorf("generating auth code url: %w", err) } - redirect := request.CanonicalRedirectURL(r, c.config().Wonderwall().Ingress) + redirect := urlpkg.CanonicalRedirect(r, c.config().Wonderwall().Ingress) cookie := params.cookie(redirect) return &login{ diff --git a/pkg/router/request/request_test.go b/pkg/router/request/request_test.go deleted file mode 100644 index 5596de2..0000000 --- a/pkg/router/request/request_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package request_test - -import ( - "net/http" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/router/request" -) - -func TestCanonicalRedirectURL(t *testing.T) { - r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil) - assert.NoError(t, err) - - t.Run("default redirect", func(t *testing.T) { - for _, test := range []struct { - name string - ingress string - expected string - }{ - { - name: "root with trailing slash", - ingress: "http://localhost:8080/", - expected: "/", - }, - { - name: "root without trailing slash", - ingress: "http://localhost:8080", - expected: "/", - }, - { - name: "path with trailing slash", - ingress: "http://localhost:8080/path/", - expected: "/path", - }, - { - name: "path without trailing slash", - ingress: "http://localhost:8080/path", - expected: "/path", - }, - } { - t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, test.ingress)) - }) - } - }) - - // Default path is /some-path - ingress := "http://localhost:8080/some-path" - - // HTTP Referer header is 2nd priority - t.Run("Referer header is set", func(t *testing.T) { - for _, test := range []struct { - name string - value string - expected string - }{ - { - name: "full URL", - value: "http://localhost:8080/foo/bar/baz", - expected: "/foo/bar/baz", - }, - { - name: "full URL with query parameters", - value: "http://localhost:8080/foo/bar/baz?gnu=notunix", - expected: "/foo/bar/baz?gnu=notunix", - }, - { - name: "absolute path", - value: "/foo/bar/baz", - expected: "/foo/bar/baz", - }, - { - name: "absolute path with query parameters", - value: "/foo/bar/baz?gnu=notunix", - expected: "/foo/bar/baz?gnu=notunix", - }, - } { - t.Run(test.name, func(t *testing.T) { - r.Header.Set("Referer", test.value) - assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, ingress)) - }) - } - }) - - // If redirect parameter is set, use that - t.Run("redirect parameter is set", func(t *testing.T) { - for _, test := range []struct { - name string - value string - expected string - }{ - { - name: "complete url with parameters", - value: "http://localhost:8080/path/to/redirect?val1=foo&val2=bar", - expected: "/path/to/redirect?val1=foo&val2=bar", - }, - { - name: "root url with trailing slash", - value: "http://localhost:8080/", - expected: "/", - }, - { - name: "root url without trailing slash", - value: "http://localhost:8080", - expected: "/", - }, - { - name: "url path with trailing slash", - value: "http://localhost:8080/path/", - expected: "/path/", - }, - { - name: "url path without trailing slash", - value: "http://localhost:8080/path", - expected: "/path", - }, - { - name: "absolute path", - value: "/path", - expected: "/path", - }, - { - name: "absolute path with query parameters", - value: "/path?gnu=notunix", - expected: "/path?gnu=notunix", - }, - } { - t.Run(test.name, func(t *testing.T) { - v := &url.Values{} - v.Set("redirect", test.value) - r.URL.RawQuery = v.Encode() - assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, ingress)) - }) - } - }) -} diff --git a/pkg/router/router.go b/pkg/router/router.go index 8370d15..667c87f 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -5,11 +5,12 @@ import ( chi_middleware "github.com/go-chi/chi/v5/middleware" "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/router/middleware" + "github.com/nais/wonderwall/pkg/handler" + "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/router/paths" ) -func New(handler *Handler) chi.Router { +func New(handler *handler.Handler) chi.Router { r := chi.NewRouter() r.Use(middleware.CorrelationIDHandler) r.Use(chi_middleware.Recoverer) diff --git a/pkg/router/request/request.go b/pkg/url/url.go similarity index 58% rename from pkg/router/request/request.go rename to pkg/url/url.go index d0eb31b..767addf 100644 --- a/pkg/router/request/request.go +++ b/pkg/url/url.go @@ -1,18 +1,22 @@ -package request +package url import ( + "fmt" "net/http" "net/url" + "strings" "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router/paths" ) const ( RedirectURLParameter = "redirect" ) -// CanonicalRedirectURL constructs a redirect URL that points back to the application. -func CanonicalRedirectURL(r *http.Request, ingress string) string { +// CanonicalRedirect constructs a redirect URL that points back to the application. +func CanonicalRedirect(r *http.Request, ingress string) string { // 1. Default defaultPath := defaultRedirectURL(ingress) redirect := defaultPath @@ -37,6 +41,29 @@ func CanonicalRedirectURL(r *http.Request, ingress string) string { return redirect } +// Retry returns a URI that should retry the desired route that failed. +// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes +// are related to the authentication flow, we default to redirecting back to the handled +// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow. +func Retry(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string { + retryURI := r.URL.Path + prefix := config.ParseIngress(ingress) + + if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) { + return prefix + retryURI + } + + redirect := CanonicalRedirect(r, ingress) + + if loginCookie != nil && len(loginCookie.Referer) > 0 { + redirect = loginCookie.Referer + } + + retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login) + retryURI = retryURI + fmt.Sprintf("?%s=%s", RedirectURLParameter, redirect) + return retryURI +} + func defaultRedirectURL(ingress string) string { defaultPath := "/" ingressPath := config.ParseIngress(ingress) diff --git a/pkg/router/handler_error_test.go b/pkg/url/url_test.go similarity index 62% rename from pkg/router/handler_error_test.go rename to pkg/url/url_test.go index 64f2eb3..43fda9a 100644 --- a/pkg/router/handler_error_test.go +++ b/pkg/url/url_test.go @@ -1,16 +1,145 @@ -package router_test +package url_test import ( "net/http" + "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/router" + urlpkg "github.com/nais/wonderwall/pkg/url" ) -func TestRetryURI(t *testing.T) { +func TestCanonicalRedirect(t *testing.T) { + r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil) + assert.NoError(t, err) + + t.Run("default redirect", func(t *testing.T) { + for _, test := range []struct { + name string + ingress string + expected string + }{ + { + name: "root with trailing slash", + ingress: "http://localhost:8080/", + expected: "/", + }, + { + name: "root without trailing slash", + ingress: "http://localhost:8080", + expected: "/", + }, + { + name: "path with trailing slash", + ingress: "http://localhost:8080/path/", + expected: "/path", + }, + { + name: "path without trailing slash", + ingress: "http://localhost:8080/path", + expected: "/path", + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, test.ingress)) + }) + } + }) + + // Default path is /some-path + ingress := "http://localhost:8080/some-path" + + // HTTP Referer header is 2nd priority + t.Run("Referer header is set", func(t *testing.T) { + for _, test := range []struct { + name string + value string + expected string + }{ + { + name: "full URL", + value: "http://localhost:8080/foo/bar/baz", + expected: "/foo/bar/baz", + }, + { + name: "full URL with query parameters", + value: "http://localhost:8080/foo/bar/baz?gnu=notunix", + expected: "/foo/bar/baz?gnu=notunix", + }, + { + name: "absolute path", + value: "/foo/bar/baz", + expected: "/foo/bar/baz", + }, + { + name: "absolute path with query parameters", + value: "/foo/bar/baz?gnu=notunix", + expected: "/foo/bar/baz?gnu=notunix", + }, + } { + t.Run(test.name, func(t *testing.T) { + r.Header.Set("Referer", test.value) + assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, ingress)) + }) + } + }) + + // If redirect parameter is set, use that + t.Run("redirect parameter is set", func(t *testing.T) { + for _, test := range []struct { + name string + value string + expected string + }{ + { + name: "complete url with parameters", + value: "http://localhost:8080/path/to/redirect?val1=foo&val2=bar", + expected: "/path/to/redirect?val1=foo&val2=bar", + }, + { + name: "root url with trailing slash", + value: "http://localhost:8080/", + expected: "/", + }, + { + name: "root url without trailing slash", + value: "http://localhost:8080", + expected: "/", + }, + { + name: "url path with trailing slash", + value: "http://localhost:8080/path/", + expected: "/path/", + }, + { + name: "url path without trailing slash", + value: "http://localhost:8080/path", + expected: "/path", + }, + { + name: "absolute path", + value: "/path", + expected: "/path", + }, + { + name: "absolute path with query parameters", + value: "/path?gnu=notunix", + expected: "/path?gnu=notunix", + }, + } { + t.Run(test.name, func(t *testing.T) { + v := &url.Values{} + v.Set("redirect", test.value) + r.URL.RawQuery = v.Encode() + assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, ingress)) + }) + } + }) +} + +func TestRetry(t *testing.T) { httpRequest := func(url string, referer ...string) *http.Request { req, _ := http.NewRequest(http.MethodGet, url, nil) if len(referer) > 0 { @@ -165,7 +294,7 @@ func TestRetryURI(t *testing.T) { test.ingress = "/" } - retryURI := router.RetryURI(test.request, test.ingress, test.loginCookie) + retryURI := urlpkg.Retry(test.request, test.ingress, test.loginCookie) assert.Equal(t, test.want, retryURI) }) }