refactor: move request related utilities to own pkg

This commit is contained in:
Trong Huu Nguyen
2021-10-06 12:39:08 +02:00
parent fb4adc9cc5
commit 7979bb09fb
10 changed files with 67 additions and 79 deletions

View File

@@ -2,7 +2,7 @@ package errorhandler
import (
"github.com/go-chi/chi/v5/middleware"
"github.com/nais/wonderwall/pkg/url"
"github.com/nais/wonderwall/pkg/request"
"html/template"
"net/http"
@@ -21,7 +21,7 @@ func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause
t, _ := template.ParseFiles("templates/error.html")
errorPage := ErrorPage{
CorrelationID: middleware.GetReqID(r.Context()),
CanonicalRedirectURL: url.CanonicalRedirectURL(r),
CanonicalRedirectURL: request.CanonicalRedirectURL(r),
}
t.Execute(w, errorPage)
}

View File

@@ -0,0 +1,8 @@
package request
const (
LocaleURLParameter = "locale"
PostLogoutRedirectURIParameter = "post_logout_redirect_uri"
RedirectURLParameter = "redirect"
SecurityLevelURLParameter = "level"
)

View File

@@ -1,16 +1,37 @@
package router
package request
import (
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/config"
"net/http"
"net/url"
)
var (
InvalidLoginParameterError = errors.New("InvalidLoginParameter")
)
// CanonicalRedirectURL constructs a redirect URL that points back to the application.
func CanonicalRedirectURL(r *http.Request) string {
redirectURL := "/"
referer, err := url.Parse(r.Referer())
if err == nil && len(referer.Path) > 0 {
redirectURL = referer.Path
}
override := r.URL.Query().Get(RedirectURLParameter)
if len(override) > 0 {
referer, err = url.Parse(override)
if err == nil {
// strip scheme and host to avoid cross-domain redirects
referer.Scheme = ""
referer.Host = ""
redirectURL = referer.String()
}
}
return redirectURL
}
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
// The value must exist in the supplied list of supported values.
func LoginURLParameter(r *http.Request, parameter, fallback string, supported config.Supported) (string, error) {

View File

@@ -1,15 +1,32 @@
package router_test
package request_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/request"
"github.com/stretchr/testify/assert"
"net/http"
"net/url"
"testing"
)
func TestCanonicalRedirectURL(t *testing.T) {
r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil)
assert.NoError(t, err)
// Default URL is /
assert.Equal(t, "/", request.CanonicalRedirectURL(r))
// HTTP Referer header is 2nd priority
r.Header.Set("referer", "http://localhost:8080/foo/bar/baz?gnu=notunix")
assert.Equal(t, "/foo/bar/baz", request.CanonicalRedirectURL(r))
// If redirect parameter is set, use that
v := &url.Values{}
v.Set("redirect", "https://google.com/path/to/redirect?val1=foo&val2=bar")
r.URL.RawQuery = v.Encode()
assert.Equal(t, "/path/to/redirect?val1=foo&val2=bar", request.CanonicalRedirectURL(r))
}
func TestLoginURLParameter(t *testing.T) {
for _, test := range []struct {
name string
@@ -38,19 +55,19 @@ func TestLoginURLParameter(t *testing.T) {
{
name: "invalid URL parameter value should return error",
url: "http://localhost:8080/oauth2/login?param=invalid",
expectErr: router.InvalidLoginParameterError,
expectErr: request.InvalidLoginParameterError,
},
{
name: "invalid fallback value should return error",
fallback: "invalid",
url: "http://localhost:8080/oauth2/login",
expectErr: router.InvalidLoginParameterError,
expectErr: request.InvalidLoginParameterError,
},
{
name: "no supported values should return error",
url: "http://localhost:8080/oauth2/login",
supported: config.Supported{""},
expectErr: router.InvalidLoginParameterError,
expectErr: request.InvalidLoginParameterError,
},
} {
t.Run(test.name, func(t *testing.T) {
@@ -74,7 +91,7 @@ func TestLoginURLParameter(t *testing.T) {
supported = test.supported
}
val, err := router.LoginURLParameter(r, parameter, fallback, supported)
val, err := request.LoginURLParameter(r, parameter, fallback, supported)
if test.expectErr == nil {
assert.NoError(t, err)

View File

@@ -12,12 +12,6 @@ import (
"github.com/nais/wonderwall/pkg/session"
)
const (
SecurityLevelURLParameter = "level"
LocaleURLParameter = "locale"
PostLogoutRedirectURIParameter = "post_logout_redirect_uri"
)
type Handler struct {
Config config.IDPorten
Crypter cryptutil.Crypter

View File

@@ -3,7 +3,7 @@ package router
import (
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/url"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"github.com/nais/wonderwall/pkg/auth"
@@ -34,7 +34,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
State: params.State,
Nonce: params.Nonce,
CodeVerifier: params.CodeVerifier,
Referer: url.CanonicalRedirectURL(r),
Referer: request.CanonicalRedirectURL(r),
})
if err != nil {
errorhandler.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))

View File

@@ -2,6 +2,7 @@ package router
import (
"fmt"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
h.deleteCookie(w, h.GetSessionCookieName())
v := u.Query()
v.Add("post_logout_redirect_uri", PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI))
v.Add("post_logout_redirect_uri", request.PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI))
if len(idToken) != 0 {
v.Add("id_token_hint", idToken)

View File

@@ -3,6 +3,7 @@ package router
import (
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"net/url"
@@ -55,7 +56,7 @@ func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
fallback := h.Config.SecurityLevel.Value
supported := h.Config.WellKnown.ACRValuesSupported
securityLevel, err := LoginURLParameter(r, SecurityLevelURLParameter, fallback, supported)
securityLevel, err := request.LoginURLParameter(r, request.SecurityLevelURLParameter, fallback, supported)
if err != nil {
return err
}
@@ -72,7 +73,7 @@ func (h *Handler) withLocale(r *http.Request, v url.Values) error {
fallback := h.Config.Locale.Value
supported := h.Config.WellKnown.UILocalesSupported
locale, err := LoginURLParameter(r, LocaleURLParameter, fallback, supported)
locale, err := request.LoginURLParameter(r, request.LocaleURLParameter, fallback, supported)
if err != nil {
return err
}

View File

@@ -1,28 +0,0 @@
package url
import (
"net/http"
"net/url"
)
const RedirectURLParameter = "redirect"
// CanonicalRedirectURL constructs a redirect URL that points back to the application.
func CanonicalRedirectURL(r *http.Request) string {
redirectURL := "/"
referer, err := url.Parse(r.Referer())
if err == nil && len(referer.Path) > 0 {
redirectURL = referer.Path
}
override := r.URL.Query().Get(RedirectURLParameter)
if len(override) > 0 {
referer, err = url.Parse(override)
if err == nil {
// strip scheme and host to avoid cross-domain redirects
referer.Scheme = ""
referer.Host = ""
redirectURL = referer.String()
}
}
return redirectURL
}

View File

@@ -1,26 +0,0 @@
package url
import (
"github.com/stretchr/testify/assert"
"net/http"
"net/url"
"testing"
)
func TestCanonicalRedirectURL(t *testing.T) {
r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil)
assert.NoError(t, err)
// Default URL is /
assert.Equal(t, "/", CanonicalRedirectURL(r))
// HTTP Referer header is 2nd priority
r.Header.Set("referer", "http://localhost:8080/foo/bar/baz?gnu=notunix")
assert.Equal(t, "/foo/bar/baz", CanonicalRedirectURL(r))
// If redirect parameter is set, use that
v := &url.Values{}
v.Set("redirect", "https://google.com/path/to/redirect?val1=foo&val2=bar")
r.URL.RawQuery = v.Encode()
assert.Equal(t, "/path/to/redirect?val1=foo&val2=bar", CanonicalRedirectURL(r))
}