refactor: split out packages from router

This commit is contained in:
Trong Huu Nguyen
2022-07-15 07:44:54 +02:00
parent fd630e6dbd
commit e3b9d33296
24 changed files with 199 additions and 206 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/logging"
"github.com/nais/wonderwall/pkg/metrics"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
@@ -58,7 +59,7 @@ func run() error {
crypt := crypto.NewCrypter(key)
sessionStore := session.NewStore(cfg)
httplogger := logging.NewHttpLogger(cfg)
h, err := router.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore)
h, err := handler.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore)
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)
}

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"context"

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"context"
@@ -11,9 +11,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/loginstatus"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
const (

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"net/http"

View File

@@ -1,23 +1,18 @@
package router
package handler
import (
_ "embed"
"fmt"
"html/template"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/go-chi/chi/v5/middleware"
"github.com/rs/zerolog"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
"github.com/nais/wonderwall/pkg/router/request"
logentry "github.com/nais/wonderwall/pkg/middleware"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
type ErrorPage struct {
@@ -63,7 +58,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s
errorPage := ErrorPage{
CorrelationID: middleware.GetReqID(r.Context()),
RetryURI: RetryURI(r, h.Cfg.Wonderwall().Ingress, loginCookie),
RetryURI: urlpkg.Retry(r, h.Cfg.Wonderwall().Ingress, loginCookie),
}
err = errorTemplate.Execute(w, errorPage)
if err != nil {
@@ -102,26 +97,3 @@ func (h *Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error
func (h *Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
h.respondError(w, r, http.StatusUnauthorized, cause, zerolog.WarnLevel)
}
// RetryURI returns a URI that should retry the desired route that failed.
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
// are related to the authentication flow, we default to redirecting back to the handled
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
retryURI := r.URL.Path
prefix := config.ParseIngress(ingress)
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
return prefix + retryURI
}
redirect := request.CanonicalRedirectURL(r, ingress)
if loginCookie != nil && len(loginCookie.Referer) > 0 {
redirect = loginCookie.Referer
}
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
retryURI = retryURI + fmt.Sprintf("?%s=%s", request.RedirectURLParameter, redirect)
return retryURI
}

View File

@@ -0,0 +1 @@
package handler_test

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"net/http"

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"encoding/json"
@@ -8,9 +8,9 @@ import (
"time"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
const (

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"errors"
@@ -8,7 +8,7 @@ import (
"github.com/go-redis/redis/v8"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
logentry "github.com/nais/wonderwall/pkg/middleware"
)
// Logout triggers self-initiated for the current user

View File

@@ -1,9 +1,9 @@
package router
package handler
import (
"net/http"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
logentry "github.com/nais/wonderwall/pkg/middleware"
)
// LogoutCallback handles the callback from the self-initiated logout for the current user

View File

@@ -1,4 +1,4 @@
package router_test
package handler_test
import (
"encoding/base64"

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"context"

View File

@@ -1,4 +1,4 @@
package router
package handler
import (
"net/http"

View File

@@ -1,4 +1,4 @@
package router_test
package handler_test
import (
"context"
@@ -13,9 +13,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
)
@@ -117,7 +117,7 @@ func TestHandler_DeleteSessionFallback(t *testing.T) {
})
}
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request {
func makeRequestWithFallbackCookies(t *testing.T, h *handler.Handler, tokens *openid.Tokens) *http.Request {
writer := httptest.NewRecorder()
expiresIn := time.Minute
data := session.NewData("sid", tokens, nil)
@@ -150,7 +150,7 @@ func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie
assert.Empty(t, expired.Value)
}
func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
desiredCookie := getCookieFromJar(cookieName, cookies)
assert.NotNil(t, desiredCookie)

View File

@@ -20,7 +20,8 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/openid/client"
handlerpkg "github.com/nais/wonderwall/pkg/handler"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
scopespkg "github.com/nais/wonderwall/pkg/openid/scopes"
"github.com/nais/wonderwall/pkg/router"
@@ -34,7 +35,7 @@ type IdentityProvider struct {
Provider TestProvider
ProviderHandler *IdentityProviderHandler
ProviderServer *httptest.Server
RelyingPartyHandler *router.Handler
RelyingPartyHandler *handlerpkg.Handler
RelyingPartyServer *httptest.Server
}
@@ -76,7 +77,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider {
sessionStore := session.NewMemory()
ctx, cancel := context.WithCancel(context.Background())
rpHandler, err := router.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore)
rpHandler, err := handlerpkg.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore)
if err != nil {
panic(err)
}
@@ -88,7 +89,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider {
openidConfig.ClientConfig.CallbackURI = rpServer.URL + "/oauth2/callback"
openidConfig.ClientConfig.PostLogoutRedirectURI = rpServer.URL
openidConfig.ClientConfig.LogoutCallbackURI = rpServer.URL + "/oauth2/logout/callback"
rpHandler.Client = client.NewClient(openidConfig)
rpHandler.Client = openidclient.NewClient(openidConfig)
return IdentityProvider{
cancelFunc: cancel,
@@ -357,7 +358,7 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
return
}
expectedCodeChallenge := client.CodeChallenge(codeVerifier)
expectedCodeChallenge := openidclient.CodeChallenge(codeVerifier)
if expectedCodeChallenge != auth.CodeChallenge {
w.WriteHeader(http.StatusBadRequest)

View File

@@ -11,8 +11,8 @@ import (
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/router/request"
"github.com/nais/wonderwall/pkg/strings"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
const (
@@ -53,7 +53,7 @@ func NewLogin(c Client, r *http.Request) (Login, error) {
return nil, fmt.Errorf("generating auth code url: %w", err)
}
redirect := request.CanonicalRedirectURL(r, c.config().Wonderwall().Ingress)
redirect := urlpkg.CanonicalRedirect(r, c.config().Wonderwall().Ingress)
cookie := params.cookie(redirect)
return &login{

View File

@@ -1,139 +0,0 @@
package request_test
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/router/request"
)
func TestCanonicalRedirectURL(t *testing.T) {
r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil)
assert.NoError(t, err)
t.Run("default redirect", func(t *testing.T) {
for _, test := range []struct {
name string
ingress string
expected string
}{
{
name: "root with trailing slash",
ingress: "http://localhost:8080/",
expected: "/",
},
{
name: "root without trailing slash",
ingress: "http://localhost:8080",
expected: "/",
},
{
name: "path with trailing slash",
ingress: "http://localhost:8080/path/",
expected: "/path",
},
{
name: "path without trailing slash",
ingress: "http://localhost:8080/path",
expected: "/path",
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, test.ingress))
})
}
})
// Default path is /some-path
ingress := "http://localhost:8080/some-path"
// HTTP Referer header is 2nd priority
t.Run("Referer header is set", func(t *testing.T) {
for _, test := range []struct {
name string
value string
expected string
}{
{
name: "full URL",
value: "http://localhost:8080/foo/bar/baz",
expected: "/foo/bar/baz",
},
{
name: "full URL with query parameters",
value: "http://localhost:8080/foo/bar/baz?gnu=notunix",
expected: "/foo/bar/baz?gnu=notunix",
},
{
name: "absolute path",
value: "/foo/bar/baz",
expected: "/foo/bar/baz",
},
{
name: "absolute path with query parameters",
value: "/foo/bar/baz?gnu=notunix",
expected: "/foo/bar/baz?gnu=notunix",
},
} {
t.Run(test.name, func(t *testing.T) {
r.Header.Set("Referer", test.value)
assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, ingress))
})
}
})
// If redirect parameter is set, use that
t.Run("redirect parameter is set", func(t *testing.T) {
for _, test := range []struct {
name string
value string
expected string
}{
{
name: "complete url with parameters",
value: "http://localhost:8080/path/to/redirect?val1=foo&val2=bar",
expected: "/path/to/redirect?val1=foo&val2=bar",
},
{
name: "root url with trailing slash",
value: "http://localhost:8080/",
expected: "/",
},
{
name: "root url without trailing slash",
value: "http://localhost:8080",
expected: "/",
},
{
name: "url path with trailing slash",
value: "http://localhost:8080/path/",
expected: "/path/",
},
{
name: "url path without trailing slash",
value: "http://localhost:8080/path",
expected: "/path",
},
{
name: "absolute path",
value: "/path",
expected: "/path",
},
{
name: "absolute path with query parameters",
value: "/path?gnu=notunix",
expected: "/path?gnu=notunix",
},
} {
t.Run(test.name, func(t *testing.T) {
v := &url.Values{}
v.Set("redirect", test.value)
r.URL.RawQuery = v.Encode()
assert.Equal(t, test.expected, request.CanonicalRedirectURL(r, ingress))
})
}
})
}

View File

@@ -5,11 +5,12 @@ import (
chi_middleware "github.com/go-chi/chi/v5/middleware"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/router/middleware"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
)
func New(handler *Handler) chi.Router {
func New(handler *handler.Handler) chi.Router {
r := chi.NewRouter()
r.Use(middleware.CorrelationIDHandler)
r.Use(chi_middleware.Recoverer)

View File

@@ -1,18 +1,22 @@
package request
package url
import (
"fmt"
"net/http"
"net/url"
"strings"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router/paths"
)
const (
RedirectURLParameter = "redirect"
)
// CanonicalRedirectURL constructs a redirect URL that points back to the application.
func CanonicalRedirectURL(r *http.Request, ingress string) string {
// CanonicalRedirect constructs a redirect URL that points back to the application.
func CanonicalRedirect(r *http.Request, ingress string) string {
// 1. Default
defaultPath := defaultRedirectURL(ingress)
redirect := defaultPath
@@ -37,6 +41,29 @@ func CanonicalRedirectURL(r *http.Request, ingress string) string {
return redirect
}
// Retry returns a URI that should retry the desired route that failed.
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
// are related to the authentication flow, we default to redirecting back to the handled
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
func Retry(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
retryURI := r.URL.Path
prefix := config.ParseIngress(ingress)
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
return prefix + retryURI
}
redirect := CanonicalRedirect(r, ingress)
if loginCookie != nil && len(loginCookie.Referer) > 0 {
redirect = loginCookie.Referer
}
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
retryURI = retryURI + fmt.Sprintf("?%s=%s", RedirectURLParameter, redirect)
return retryURI
}
func defaultRedirectURL(ingress string) string {
defaultPath := "/"
ingressPath := config.ParseIngress(ingress)

View File

@@ -1,16 +1,145 @@
package router_test
package url_test
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestRetryURI(t *testing.T) {
func TestCanonicalRedirect(t *testing.T) {
r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil)
assert.NoError(t, err)
t.Run("default redirect", func(t *testing.T) {
for _, test := range []struct {
name string
ingress string
expected string
}{
{
name: "root with trailing slash",
ingress: "http://localhost:8080/",
expected: "/",
},
{
name: "root without trailing slash",
ingress: "http://localhost:8080",
expected: "/",
},
{
name: "path with trailing slash",
ingress: "http://localhost:8080/path/",
expected: "/path",
},
{
name: "path without trailing slash",
ingress: "http://localhost:8080/path",
expected: "/path",
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, test.ingress))
})
}
})
// Default path is /some-path
ingress := "http://localhost:8080/some-path"
// HTTP Referer header is 2nd priority
t.Run("Referer header is set", func(t *testing.T) {
for _, test := range []struct {
name string
value string
expected string
}{
{
name: "full URL",
value: "http://localhost:8080/foo/bar/baz",
expected: "/foo/bar/baz",
},
{
name: "full URL with query parameters",
value: "http://localhost:8080/foo/bar/baz?gnu=notunix",
expected: "/foo/bar/baz?gnu=notunix",
},
{
name: "absolute path",
value: "/foo/bar/baz",
expected: "/foo/bar/baz",
},
{
name: "absolute path with query parameters",
value: "/foo/bar/baz?gnu=notunix",
expected: "/foo/bar/baz?gnu=notunix",
},
} {
t.Run(test.name, func(t *testing.T) {
r.Header.Set("Referer", test.value)
assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, ingress))
})
}
})
// If redirect parameter is set, use that
t.Run("redirect parameter is set", func(t *testing.T) {
for _, test := range []struct {
name string
value string
expected string
}{
{
name: "complete url with parameters",
value: "http://localhost:8080/path/to/redirect?val1=foo&val2=bar",
expected: "/path/to/redirect?val1=foo&val2=bar",
},
{
name: "root url with trailing slash",
value: "http://localhost:8080/",
expected: "/",
},
{
name: "root url without trailing slash",
value: "http://localhost:8080",
expected: "/",
},
{
name: "url path with trailing slash",
value: "http://localhost:8080/path/",
expected: "/path/",
},
{
name: "url path without trailing slash",
value: "http://localhost:8080/path",
expected: "/path",
},
{
name: "absolute path",
value: "/path",
expected: "/path",
},
{
name: "absolute path with query parameters",
value: "/path?gnu=notunix",
expected: "/path?gnu=notunix",
},
} {
t.Run(test.name, func(t *testing.T) {
v := &url.Values{}
v.Set("redirect", test.value)
r.URL.RawQuery = v.Encode()
assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r, ingress))
})
}
})
}
func TestRetry(t *testing.T) {
httpRequest := func(url string, referer ...string) *http.Request {
req, _ := http.NewRequest(http.MethodGet, url, nil)
if len(referer) > 0 {
@@ -165,7 +294,7 @@ func TestRetryURI(t *testing.T) {
test.ingress = "/"
}
retryURI := router.RetryURI(test.request, test.ingress, test.loginCookie)
retryURI := urlpkg.Retry(test.request, test.ingress, test.loginCookie)
assert.Equal(t, test.want, retryURI)
})
}