refactor(url): extract utility functions

This commit is contained in:
Trong Huu Nguyen
2023-02-06 20:24:05 +01:00
parent d13525f8a2
commit 5f74ee08bc
14 changed files with 244 additions and 284 deletions

View File

@@ -14,10 +14,10 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/handler/templates"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router/paths"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
const (

View File

@@ -15,9 +15,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/cookie"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
@@ -103,7 +103,7 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
logoutCallbackURI := resp.Location
req := idp.GetRequest(resp.Location.String())
expectedLogoutCallbackURL, err := urlpkg.LogoutCallbackURL(req)
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
assert.Contains(t, logoutCallbackURI.String(), expectedLogoutCallbackURL)

View File

@@ -8,8 +8,8 @@ import (
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogin(t *testing.T) {
@@ -24,7 +24,7 @@ func TestLogin(t *testing.T) {
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
expectedCallbackURL, err := urlpkg.LoginCallbackURL(req)
expectedCallbackURL, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))

View File

@@ -6,8 +6,8 @@ import (
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogout(t *testing.T) {
@@ -27,7 +27,7 @@ func TestLogout(t *testing.T) {
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallbackURL(req)
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()

View File

@@ -11,10 +11,10 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/handler/autologin"
"github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/loginstatus"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/url"
)
type ReverseProxySource interface {

View File

@@ -6,8 +6,8 @@ import (
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestReverseProxy(t *testing.T) {
@@ -79,7 +79,7 @@ func TestReverseProxy(t *testing.T) {
callbackEndpoint.RawQuery = ""
req := idp.GetRequest(callbackLocation.String())
expectedCallbackURL, err := urlpkg.LoginCallbackURL(req)
expectedCallbackURL, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
assert.Equal(t, expectedCallbackURL, callbackEndpoint.String())

View File

@@ -1,112 +0,0 @@
package url
import (
"fmt"
"net/http"
"net/url"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
)
const (
RedirectURLParameter = "redirect"
)
// CanonicalRedirect constructs a redirect URL that points back to the application.
func CanonicalRedirect(r *http.Request) string {
ingressPath, ok := mw.PathFrom(r.Context())
if len(ingressPath) == 0 || !ok {
ingressPath = "/"
}
// 1. Default
redirect := ingressPath
// 2. Redirect parameter is set
redirectParam := r.URL.Query().Get(RedirectURLParameter)
if len(redirectParam) > 0 {
redirect = redirectParam
}
// Ensure URL isn't encoded
redirect, err := url.QueryUnescape(redirect)
if err != nil {
return ingressPath
}
parsed, err := url.ParseRequestURI(redirect)
if err != nil {
// Silently fall back to ingress path
return ingressPath
}
// Strip scheme and host to avoid cross-domain redirects
parsed.Scheme = ""
parsed.Host = ""
redirect = parsed.String()
// Root path without trailing slash is empty
if len(parsed.Path) == 0 {
redirect = "/"
}
// Ensure that empty path redirections falls back to the ingress' context path if applicable
if len(redirect) == 0 {
redirect = ingressPath
}
return redirect
}
// Login constructs a URL string that points to the login path for the given target URL.
// The given redirect string should point to the location to be redirected to after login.
func Login(target *url.URL, redirect string) string {
u := target.JoinPath(paths.OAuth2, paths.Login)
v := u.Query()
v.Set(RedirectURLParameter, redirect)
u.RawQuery = v.Encode()
return u.String()
}
// LoginRelative constructs the relative URL with an absolute path that points to the application's login path, given an optional path prefix.
// The given redirect string should point to the location to be redirected to after login.
func LoginRelative(prefix, redirect string) string {
u := new(url.URL)
u.Path = prefix
if prefix == "" {
u.Path = "/"
}
return Login(u, redirect)
}
func LoginCallbackURL(r *http.Request) (string, error) {
return makeCallbackURL(r, paths.LoginCallback)
}
func LogoutCallbackURL(r *http.Request) (string, error) {
return makeCallbackURL(r, paths.LogoutCallback)
}
func makeCallbackURL(r *http.Request, callbackPath string) (string, error) {
u, err := Ingress(r)
if err != nil {
return "", err
}
return u.JoinPath(paths.OAuth2, callbackPath).String(), nil
}
func Ingress(r *http.Request) (*url.URL, error) {
ing, found := mw.IngressFrom(r.Context())
if !found {
return nil, fmt.Errorf("request host does not match any configured ingresses")
}
return ing.NewURL(), nil
}

View File

@@ -9,11 +9,11 @@ import (
"golang.org/x/oauth2"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/loginstatus"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/strings"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
const (
@@ -43,7 +43,7 @@ func NewLogin(c *Client, r *http.Request) (*Login, error) {
return nil, fmt.Errorf("generating parameters: %w", err)
}
callbackURL, err := urlpkg.LoginCallbackURL(r)
callbackURL, err := urlpkg.LoginCallback(r)
if err != nil {
return nil, fmt.Errorf("generating callback url: %w", err)
}

View File

@@ -8,8 +8,8 @@ import (
"golang.org/x/oauth2"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/openid"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
type LoginCallback struct {
@@ -26,7 +26,7 @@ func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*
// redirect_uri not set in cookie (e.g. login initiated at instance running older version, callback handled at newer version)
if len(cookie.RedirectURI) == 0 {
callbackURL, err := urlpkg.LoginCallbackURL(r)
callbackURL, err := urlpkg.LoginCallback(r)
if err != nil {
return nil, fmt.Errorf("generating callback url: %w", err)
}

View File

@@ -7,10 +7,10 @@ import (
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLoginCallback_StateMismatchError(t *testing.T) {
@@ -114,7 +114,7 @@ func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client
idp := mock.NewIdentityProvider(mock.Config())
req := idp.GetRequest(url)
redirect, err := urlpkg.LoginCallbackURL(req)
redirect, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{

View File

@@ -9,11 +9,11 @@ import (
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/loginstatus"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid/client"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogin_URL(t *testing.T) {
@@ -92,7 +92,7 @@ func TestLogin_URL(t *testing.T) {
assert.Contains(t, query, "code_challenge_method")
assert.NotContains(t, query, "resource")
callbackURL, err := urlpkg.LoginCallbackURL(req)
callbackURL, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
assert.ElementsMatch(t, query["response_type"], []string{"code"})

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"net/http"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/openid"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
type Logout struct {
@@ -15,7 +15,7 @@ type Logout struct {
}
func NewLogout(c *Client, r *http.Request) (*Logout, error) {
logoutCallbackURL, err := urlpkg.LogoutCallbackURL(r)
logoutCallbackURL, err := urlpkg.LogoutCallback(r)
if err != nil {
return nil, fmt.Errorf("generating logout callback url: %w", err)
}

82
pkg/url/url.go Normal file
View File

@@ -0,0 +1,82 @@
package url
import (
"errors"
"net/http"
"net/url"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
)
const (
RedirectQueryParameter = "redirect"
)
var (
ErrNoMatchingIngress = errors.New("request host does not match any configured ingresses")
)
// Login constructs a URL string that points to the login path for the given target URL.
// The given redirect string should point to the location to be redirected to after login.
func Login(target *url.URL, redirect string) string {
u := target.JoinPath(paths.OAuth2, paths.Login)
v := u.Query()
v.Set(RedirectQueryParameter, redirect)
u.RawQuery = v.Encode()
return u.String()
}
// LoginRelative constructs the relative URL with an absolute path that points to the application's login path, given an optional path prefix.
// The given redirect string should point to the location to be redirected to after login.
func LoginRelative(prefix, redirect string) string {
u := new(url.URL)
u.Path = prefix
if prefix == "" {
u.Path = "/"
}
return Login(u, redirect)
}
func LoginCallback(r *http.Request) (string, error) {
return makeCallbackURL(r, paths.LoginCallback)
}
func LogoutCallback(r *http.Request) (string, error) {
return makeCallbackURL(r, paths.LogoutCallback)
}
func makeCallbackURL(r *http.Request, callbackPath string) (string, error) {
u, err := MatchingIngress(r)
if err != nil {
return "", err
}
return u.JoinPath(paths.OAuth2, callbackPath).String(), nil
}
func MatchingPath(r *http.Request) *url.URL {
u := &url.URL{}
p, found := mw.PathFrom(r.Context())
if found && len(p) > 0 {
u.Path = p
} else {
u.Path = "/"
}
return u
}
func MatchingIngress(r *http.Request) (*url.URL, error) {
ing, found := mw.IngressFrom(r.Context())
if !found {
return nil, ErrNoMatchingIngress
}
return ing.NewURL(), nil
}

View File

@@ -1,148 +1,15 @@
package url_test
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/ingress"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestCanonicalRedirect(t *testing.T) {
t.Run("default redirect", func(t *testing.T) {
for _, test := range []struct {
name string
ingress string
expected string
}{
{
name: "root with trailing slash",
ingress: "http://localhost:8080/",
expected: "/",
},
{
name: "root without trailing slash",
ingress: "http://localhost:8080",
expected: "/",
},
{
name: "path with trailing slash",
ingress: "http://localhost:8080/path/",
expected: "/path",
},
{
name: "path without trailing slash",
ingress: "http://localhost:8080/path",
expected: "/path",
},
} {
t.Run(test.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, test.ingress+"/oauth2/login", nil)
parsed, err := ingress.ParseIngress(test.ingress)
assert.NoError(t, err)
r = mw.RequestWithPath(r, parsed.Path())
assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r))
})
}
})
// Default path is /some-path
defaultIngress := "http://localhost:8080/some-path"
r := httptest.NewRequest(http.MethodGet, defaultIngress+"/oauth2/login", nil)
r = mw.RequestWithPath(r, "/some-path")
// If redirect parameter is set, use that
t.Run("redirect parameter is set", func(t *testing.T) {
for _, test := range []struct {
name string
value string
expected string
}{
{
name: "complete url with parameters",
value: "http://localhost:8080/path/to/redirect?val1=foo&val2=bar",
expected: "/path/to/redirect?val1=foo&val2=bar",
},
{
name: "root url with trailing slash",
value: "http://localhost:8080/",
expected: "/",
},
{
name: "root url without trailing slash",
value: "http://localhost:8080",
expected: "/",
},
{
name: "url path with trailing slash",
value: "http://localhost:8080/path/",
expected: "/path/",
},
{
name: "url path without trailing slash",
value: "http://localhost:8080/path",
expected: "/path",
},
{
name: "absolute path",
value: "/path",
expected: "/path",
},
{
name: "absolute path with query parameters",
value: "/path?gnu=notunix",
expected: "/path?gnu=notunix",
},
{
name: "relative path",
value: "path",
expected: "/some-path", // should fall back to default path
},
{
name: "relative path with query parameters",
value: "path?gnu=notunix",
expected: "/some-path", // should fall back to default path
},
{
name: "url encoded path",
value: "%2Fpath",
expected: "/path",
},
{
name: "url encoded path and query parameters",
value: "%2Fpath%3Fgnu%3Dnotunix",
expected: "/path?gnu=notunix",
},
{
name: "url encoded url",
value: "http%3A%2F%2Flocalhost%3A8080%2Fpath",
expected: "/path",
},
{
name: "url encoded url and multiple query parameters",
value: "http%3A%2F%2Flocalhost%3A8080%2Fpath%3Fgnu%3Dnotunix%26foo%3Dbar",
expected: "/path?gnu=notunix&foo=bar",
},
} {
t.Run(test.name, func(t *testing.T) {
v := &url.Values{}
v.Set("redirect", test.value)
r.URL.RawQuery = v.Encode()
assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r))
})
}
})
}
func TestLogin(t *testing.T) {
for _, test := range []struct {
name string
@@ -224,7 +91,7 @@ func TestLoginRelative(t *testing.T) {
}
}
func TestLoginCallbackURL(t *testing.T) {
func TestLoginCallback(t *testing.T) {
cfg := mock.Config()
cfg.Ingresses = []string{
"https://nav.no",
@@ -235,9 +102,9 @@ func TestLoginCallbackURL(t *testing.T) {
ingresses := mock.Ingresses(cfg)
for _, test := range []struct {
input string
want string
err error
input string
want string
wantErr bool
}{
{
input: "https://nav.no/",
@@ -256,16 +123,16 @@ func TestLoginCallbackURL(t *testing.T) {
want: "https://nav.no/dagpenger/soknad/oauth2/callback",
},
{
input: "https://not-nav.no/",
err: fmt.Errorf("request host does not match any configured ingresses"),
input: "https://not-nav.no/",
wantErr: true,
},
} {
t.Run(test.input, func(t *testing.T) {
req := mock.NewGetRequest(test.input, ingresses)
actual, err := urlpkg.LoginCallbackURL(req)
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
actual, err := urlpkg.LoginCallback(req)
if test.wantErr {
assert.ErrorIs(t, err, urlpkg.ErrNoMatchingIngress)
} else {
assert.NoError(t, err)
assert.Equal(t, test.want, actual)
@@ -274,7 +141,7 @@ func TestLoginCallbackURL(t *testing.T) {
}
}
func TestLogoutCallbackURL(t *testing.T) {
func TestLogoutCallback(t *testing.T) {
cfg := mock.Config()
cfg.Ingresses = []string{
"https://nav.no",
@@ -285,9 +152,9 @@ func TestLogoutCallbackURL(t *testing.T) {
ingresses := mock.Ingresses(cfg)
for _, test := range []struct {
input string
want string
err error
input string
want string
wantErr bool
}{
{
input: "https://nav.no/",
@@ -306,16 +173,16 @@ func TestLogoutCallbackURL(t *testing.T) {
want: "https://nav.no/dagpenger/soknad/oauth2/logout/callback",
},
{
input: "https://not-nav.no/",
err: fmt.Errorf("request host does not match any configured ingresses"),
input: "https://not-nav.no/",
wantErr: true,
},
} {
t.Run(test.input, func(t *testing.T) {
req := mock.NewGetRequest(test.input, ingresses)
actual, err := urlpkg.LogoutCallbackURL(req)
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
actual, err := urlpkg.LogoutCallback(req)
if test.wantErr {
assert.ErrorIs(t, err, urlpkg.ErrNoMatchingIngress)
} else {
assert.NoError(t, err)
assert.Equal(t, test.want, actual)
@@ -323,3 +190,126 @@ func TestLogoutCallbackURL(t *testing.T) {
})
}
}
func TestMatchingPath(t *testing.T) {
cfg := mock.Config()
cfg.Ingresses = []string{
"http://wonderwall",
"http://wonderwall/some-path",
}
ingresses := mock.Ingresses(cfg)
t.Run("matching ingress path", func(t *testing.T) {
for _, tt := range []struct {
target string
expected string
}{
{
target: "/",
expected: "/",
},
{
target: "/some-path",
expected: "/some-path",
},
{
target: "/some-path/some-subpath",
expected: "/some-path",
},
{
target: "http://wonderwall",
expected: "/",
},
{
target: "http://wonderwall/some-path",
expected: "/some-path",
},
{
target: "http://wonderwall/some-path/some-subpath",
expected: "/some-path",
},
} {
t.Run(tt.target, func(t *testing.T) {
req := mock.NewGetRequest(tt.target, ingresses)
assert.Equal(t, tt.expected, urlpkg.MatchingPath(req).String())
})
}
})
t.Run("no matching path should fall back to root", func(t *testing.T) {
req := mock.NewGetRequest("http://wonderwall/no-match", ingresses)
assert.Equal(t, "/", urlpkg.MatchingPath(req).String())
})
}
func TestMatchingIngress(t *testing.T) {
cfg := mock.Config()
cfg.Ingresses = []string{
"http://wonderwall",
"http://wonderwall/some-path",
}
ingresses := mock.Ingresses(cfg)
t.Run("matching ingress path", func(t *testing.T) {
for _, tt := range []struct {
target string
expected string
}{
{
target: "http://wonderwall",
expected: "http://wonderwall",
},
{
target: "http://wonderwall/",
expected: "http://wonderwall",
},
{
target: "http://wonderwall/?val1=foo&val2=bar",
expected: "http://wonderwall",
},
{
target: "http://wonderwall/some-path",
expected: "http://wonderwall/some-path",
},
{
target: "http://wonderwall/some-path/",
expected: "http://wonderwall/some-path",
},
{
target: "http://wonderwall/some-path/some-subpath",
expected: "http://wonderwall/some-path",
},
} {
t.Run(tt.target, func(t *testing.T) {
req := mock.NewGetRequest(tt.target, ingresses)
actual, err := urlpkg.MatchingIngress(req)
assert.NoError(t, err)
assert.Equal(t, tt.expected, actual.String())
})
}
})
t.Run("relative URLs should return error", func(t *testing.T) {
for _, target := range []string{
"/",
"/some-path",
"/some-path/some-subpath",
} {
t.Run(target, func(t *testing.T) {
req := mock.NewGetRequest(target, ingresses)
_, err := urlpkg.MatchingIngress(req)
assert.ErrorIs(t, err, urlpkg.ErrNoMatchingIngress)
})
}
})
t.Run("no matching ingress should return error", func(t *testing.T) {
req := mock.NewGetRequest("http://not-wonderwall", ingresses)
_, err := urlpkg.MatchingIngress(req)
assert.ErrorIs(t, err, urlpkg.ErrNoMatchingIngress)
})
}