mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
refactor(router): begin extraction of openid client
This commit is contained in:
54
pkg/mock/client.go
Normal file
54
pkg/mock/client.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func Config() *config.Config {
|
||||
return &config.Config{
|
||||
EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES
|
||||
Ingress: "/",
|
||||
OpenID: config.OpenID{
|
||||
Provider: "test",
|
||||
},
|
||||
SessionMaxLifetime: time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClient(provider openid.Provider) openid.Client {
|
||||
return openid.NewClient(*Config(), provider)
|
||||
}
|
||||
|
||||
func NewClientWithCfg(cfg *config.Config, provider openid.Provider) openid.Client {
|
||||
return openid.NewClient(*cfg, provider)
|
||||
}
|
||||
|
||||
func NewHandler(provider openid.Provider) *router.Handler {
|
||||
cfg := Config()
|
||||
return NewHandlerWithCfg(cfg, provider)
|
||||
}
|
||||
|
||||
func NewHandlerWithCfg(cfg *config.Config, provider openid.Provider) *router.Handler {
|
||||
if cfg == nil {
|
||||
cfg = Config()
|
||||
}
|
||||
|
||||
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
|
||||
sessionStore := session.NewMemory()
|
||||
|
||||
h, err := router.NewHandler(*cfg, crypter, zerolog.Logger{}, provider, sessionStore)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
h.CookieOptions = h.CookieOptions.WithSecure(false)
|
||||
return h
|
||||
}
|
||||
95
pkg/openid/client.go
Normal file
95
pkg/openid/client.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package openid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
Config() config.Config
|
||||
Provider() Provider
|
||||
OAuth2Config() *oauth2.Config
|
||||
|
||||
Login(r *http.Request) (Login, error)
|
||||
LoginCallback(r *http.Request) error
|
||||
Logout(r *http.Request) error
|
||||
LogoutCallback(r *http.Request) error
|
||||
|
||||
AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
RefreshGrant(r *http.Request) error
|
||||
}
|
||||
|
||||
type client struct {
|
||||
cfg config.Config
|
||||
provider Provider
|
||||
oauth2Config *oauth2.Config
|
||||
}
|
||||
|
||||
func NewClient(cfg config.Config, provider Provider) Client {
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: provider.GetClientConfiguration().GetClientID(),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
|
||||
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
|
||||
Scopes: provider.GetClientConfiguration().GetScopes(),
|
||||
}
|
||||
|
||||
return &client{
|
||||
cfg: cfg,
|
||||
provider: provider,
|
||||
oauth2Config: oauth2Config,
|
||||
}
|
||||
}
|
||||
|
||||
func (c client) Config() config.Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c client) Provider() Provider {
|
||||
return c.provider
|
||||
}
|
||||
|
||||
func (c client) OAuth2Config() *oauth2.Config {
|
||||
return c.oauth2Config
|
||||
}
|
||||
|
||||
func (c client) Login(r *http.Request) (Login, error) {
|
||||
login, err := NewLogin(c, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
return login, nil
|
||||
}
|
||||
|
||||
func (c client) LoginCallback(r *http.Request) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c client) Logout(r *http.Request) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c client) LogoutCallback(r *http.Request) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return c.oauth2Config.Exchange(ctx, code, opts...)
|
||||
}
|
||||
|
||||
func (c client) RefreshGrant(r *http.Request) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
@@ -3,19 +3,109 @@ package openid
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
"github.com/nais/wonderwall/pkg/strings"
|
||||
)
|
||||
|
||||
type LoginParameters struct {
|
||||
const (
|
||||
LocaleURLParameter = "locale"
|
||||
SecurityLevelURLParameter = "level"
|
||||
)
|
||||
|
||||
var (
|
||||
InvalidSecurityLevelError = errors.New("InvalidSecurityLevel")
|
||||
InvalidLocaleError = errors.New("InvalidLocale")
|
||||
InvalidLoginParameterError = errors.New("InvalidLoginParameter")
|
||||
|
||||
// LoginParameterMapping maps incoming login parameters to OpenID Connect parameters
|
||||
LoginParameterMapping = map[string]string{
|
||||
LocaleURLParameter: "ui_locales",
|
||||
SecurityLevelURLParameter: "acr_values",
|
||||
}
|
||||
)
|
||||
|
||||
type Login interface {
|
||||
AuthCodeURL() string
|
||||
CanonicalRedirect() string
|
||||
CodeChallenge() string
|
||||
CodeVerifier() string
|
||||
Cookie() *LoginCookie
|
||||
Nonce() string
|
||||
State() string
|
||||
}
|
||||
|
||||
func NewLogin(c Client, r *http.Request) (Login, error) {
|
||||
params, err := newLoginParameters(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating login parameters: %w", err)
|
||||
}
|
||||
|
||||
url, err := params.authCodeURL(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating login url: %w", err)
|
||||
}
|
||||
|
||||
redirect := request.CanonicalRedirectURL(r, c.Config().Ingress)
|
||||
cookie := params.cookie(redirect)
|
||||
|
||||
return login{
|
||||
authCodeURL: url,
|
||||
canonicalRedirect: redirect,
|
||||
cookie: cookie,
|
||||
params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type login struct {
|
||||
authCodeURL string
|
||||
canonicalRedirect string
|
||||
cookie *LoginCookie
|
||||
params *loginParameters
|
||||
}
|
||||
|
||||
func (l login) CodeChallenge() string {
|
||||
return l.params.CodeChallenge
|
||||
}
|
||||
|
||||
func (l login) CodeVerifier() string {
|
||||
return l.params.CodeVerifier
|
||||
}
|
||||
|
||||
func (l login) Nonce() string {
|
||||
return l.params.Nonce
|
||||
}
|
||||
|
||||
func (l login) State() string {
|
||||
return l.params.State
|
||||
}
|
||||
|
||||
func (l login) AuthCodeURL() string {
|
||||
return l.authCodeURL
|
||||
}
|
||||
|
||||
func (l login) CanonicalRedirect() string {
|
||||
return l.canonicalRedirect
|
||||
}
|
||||
|
||||
func (l login) Cookie() *LoginCookie {
|
||||
return l.cookie
|
||||
}
|
||||
|
||||
type loginParameters struct {
|
||||
Client
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
Nonce string
|
||||
State string
|
||||
}
|
||||
|
||||
func GenerateLoginParameters() (*LoginParameters, error) {
|
||||
func newLoginParameters(c Client) (*loginParameters, error) {
|
||||
codeVerifier, err := strings.GenerateBase64(64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating code verifier: %w", err)
|
||||
@@ -31,15 +121,100 @@ func GenerateLoginParameters() (*LoginParameters, error) {
|
||||
return nil, fmt.Errorf("creating state: %w", err)
|
||||
}
|
||||
|
||||
return &LoginParameters{
|
||||
return &loginParameters{
|
||||
Client: c,
|
||||
CodeVerifier: codeVerifier,
|
||||
CodeChallenge: CodeChallenge(codeVerifier),
|
||||
CodeChallenge: codeChallenge(codeVerifier),
|
||||
Nonce: nonce,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CodeChallenge(codeVerifier string) string {
|
||||
func (in *loginParameters) authCodeURL(r *http.Request) (string, error) {
|
||||
opts := []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("scope", in.Provider().GetClientConfiguration().GetScopes().String()),
|
||||
oauth2.SetAuthURLParam("nonce", in.Nonce),
|
||||
oauth2.SetAuthURLParam("response_mode", "query"),
|
||||
oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
}
|
||||
|
||||
if in.Config().Loginstatus.NeedsResourceIndicator() {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("resource", in.Config().Loginstatus.ResourceIndicator))
|
||||
}
|
||||
|
||||
opts, err := in.withSecurityLevel(r, opts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
|
||||
}
|
||||
|
||||
opts, err = in.withLocale(r, opts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
|
||||
}
|
||||
|
||||
authCodeUrl := in.OAuth2Config().AuthCodeURL(in.State, opts...)
|
||||
return authCodeUrl, nil
|
||||
}
|
||||
|
||||
func (in *loginParameters) cookie(redirect string) *LoginCookie {
|
||||
return &LoginCookie{
|
||||
State: in.State,
|
||||
Nonce: in.Nonce,
|
||||
CodeVerifier: in.CodeVerifier,
|
||||
Referer: redirect,
|
||||
}
|
||||
}
|
||||
|
||||
func (in *loginParameters) withLocale(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) {
|
||||
return withParamMapping(r,
|
||||
opts,
|
||||
LocaleURLParameter,
|
||||
in.Provider().GetClientConfiguration().GetUILocales(),
|
||||
in.Provider().GetOpenIDConfiguration().UILocalesSupported,
|
||||
)
|
||||
}
|
||||
|
||||
func (in *loginParameters) withSecurityLevel(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) {
|
||||
return withParamMapping(r,
|
||||
opts,
|
||||
SecurityLevelURLParameter,
|
||||
in.Provider().GetClientConfiguration().GetACRValues(),
|
||||
in.Provider().GetOpenIDConfiguration().ACRValuesSupported,
|
||||
)
|
||||
}
|
||||
|
||||
func withParamMapping(r *http.Request, opts []oauth2.AuthCodeOption, param, fallback string, supported Supported) ([]oauth2.AuthCodeOption, error) {
|
||||
if len(fallback) == 0 {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
value, err := LoginURLParameter(r, param, fallback, supported)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[param], value))
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
|
||||
// The value must exist in the supplied list of supported values.
|
||||
func LoginURLParameter(r *http.Request, parameter, fallback string, supported Supported) (string, error) {
|
||||
value := r.URL.Query().Get(parameter)
|
||||
|
||||
if len(value) == 0 {
|
||||
value = fallback
|
||||
}
|
||||
|
||||
if supported.Contains(value) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
|
||||
}
|
||||
|
||||
func codeChallenge(codeVerifier string) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(codeVerifier))
|
||||
codeVerifierHash := hasher.Sum(nil)
|
||||
|
||||
21
pkg/openid/login_callback.go
Normal file
21
pkg/openid/login_callback.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openid
|
||||
|
||||
type LoginCallback struct {
|
||||
Client
|
||||
}
|
||||
|
||||
func (in LoginCallback) IdentityProviderError() (bool, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (in LoginCallback) ValidateRequest() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (in LoginCallback) RedeemCode() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (in LoginCallback) ParseAndValidateTokens() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
205
pkg/openid/login_test.go
Normal file
205
pkg/openid/login_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package openid_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
|
||||
func TestLogin_URL(t *testing.T) {
|
||||
type loginURLTest struct {
|
||||
url string
|
||||
extraParams map[string]string
|
||||
error error
|
||||
}
|
||||
|
||||
tests := []loginURLTest{
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=Level4",
|
||||
extraParams: map[string]string{
|
||||
"acr_values": "Level4",
|
||||
},
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login",
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=NoLevel",
|
||||
error: openid.InvalidSecurityLevelError,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?locale=nb",
|
||||
extraParams: map[string]string{
|
||||
"ui_locales": "nb",
|
||||
},
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb",
|
||||
extraParams: map[string]string{
|
||||
"acr_values": "Level4",
|
||||
"ui_locales": "nb",
|
||||
},
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?locale=es",
|
||||
error: openid.InvalidLocaleError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.url, func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider := mock.NewTestProvider()
|
||||
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
|
||||
client := mock.NewClient(provider)
|
||||
result, err := client.Login(req)
|
||||
|
||||
if test.error != nil {
|
||||
assert.True(t, errors.Is(err, test.error))
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
parsed, err := url.Parse(result.AuthCodeURL())
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := parsed.Query()
|
||||
assert.Contains(t, query, "response_type")
|
||||
assert.Contains(t, query, "client_id")
|
||||
assert.Contains(t, query, "redirect_uri")
|
||||
assert.Contains(t, query, "scope")
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Contains(t, query, "nonce")
|
||||
assert.Contains(t, query, "response_mode")
|
||||
assert.Contains(t, query, "code_challenge")
|
||||
assert.Contains(t, query, "code_challenge_method")
|
||||
assert.NotContains(t, query, "resource")
|
||||
|
||||
assert.ElementsMatch(t, query["response_type"], []string{"code"})
|
||||
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
|
||||
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
|
||||
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
|
||||
assert.ElementsMatch(t, query["state"], []string{result.State()})
|
||||
assert.ElementsMatch(t, query["nonce"], []string{result.Nonce()})
|
||||
assert.ElementsMatch(t, query["response_mode"], []string{"query"})
|
||||
assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()})
|
||||
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
|
||||
|
||||
if test.extraParams != nil {
|
||||
for key, value := range test.extraParams {
|
||||
assert.Contains(t, query, key)
|
||||
assert.ElementsMatch(t, query[key], []string{value})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginURL_WithResourceIndicator(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider := mock.NewTestProvider()
|
||||
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
|
||||
cfg := mock.Config()
|
||||
cfg.Loginstatus.Enabled = true
|
||||
cfg.Loginstatus.ResourceIndicator = "https://some-resource"
|
||||
client := mock.NewClientWithCfg(cfg, provider)
|
||||
result, err := client.Login(req)
|
||||
|
||||
assert.NotEmpty(t, result)
|
||||
parsed, err := url.Parse(result.AuthCodeURL())
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := parsed.Query()
|
||||
assert.Contains(t, query, "resource")
|
||||
assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"})
|
||||
}
|
||||
|
||||
func TestLoginURLParameter(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
parameter string
|
||||
fallback string
|
||||
supported openid.Supported
|
||||
url string
|
||||
expectErr error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no URL parameter should use fallback value",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
expected: "valid",
|
||||
},
|
||||
{
|
||||
name: "non-matching URL parameter should be ignored",
|
||||
url: "http://localhost:8080/oauth2/login?other_param=value2",
|
||||
expected: "valid",
|
||||
},
|
||||
{
|
||||
name: "matching URL parameter should take precedence",
|
||||
url: "http://localhost:8080/oauth2/login?param=valid2",
|
||||
expected: "valid2",
|
||||
},
|
||||
{
|
||||
name: "invalid URL parameter value should return error",
|
||||
url: "http://localhost:8080/oauth2/login?param=invalid",
|
||||
expectErr: openid.InvalidLoginParameterError,
|
||||
},
|
||||
{
|
||||
name: "invalid fallback value should return error",
|
||||
fallback: "invalid",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
expectErr: openid.InvalidLoginParameterError,
|
||||
},
|
||||
{
|
||||
name: "no supported values should return error",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
supported: openid.Supported{""},
|
||||
expectErr: openid.InvalidLoginParameterError,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// default test values
|
||||
parameter := "param"
|
||||
fallback := "valid"
|
||||
supported := openid.Supported{"valid", "valid2"}
|
||||
|
||||
if len(test.parameter) > 0 {
|
||||
parameter = test.parameter
|
||||
}
|
||||
|
||||
if len(test.fallback) > 0 {
|
||||
fallback = test.fallback
|
||||
}
|
||||
|
||||
if len(test.supported) > 0 {
|
||||
supported = test.supported
|
||||
}
|
||||
|
||||
val, err := openid.LoginURLParameter(r, parameter, fallback, supported)
|
||||
|
||||
if test.expectErr == nil {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expected, val)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
pkg/openid/logout.go
Normal file
13
pkg/openid/logout.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package openid
|
||||
|
||||
type Logout struct {
|
||||
Client
|
||||
}
|
||||
|
||||
func (in Logout) URL() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (in Logout) Cookie() LogoutCookie {
|
||||
panic("not implemented")
|
||||
}
|
||||
9
pkg/openid/logout_callback.go
Normal file
9
pkg/openid/logout_callback.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package openid
|
||||
|
||||
type LogoutCallback struct {
|
||||
Client
|
||||
}
|
||||
|
||||
func (in LogoutCallback) ValidateRequest() (bool, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -2,10 +2,8 @@ package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
@@ -16,14 +14,13 @@ import (
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Client openid.Client
|
||||
Config config.Config
|
||||
CookieOptions cookie.Options
|
||||
Crypter crypto.Crypter
|
||||
OauthConfig oauth2.Config
|
||||
Loginstatus loginstatus.Client
|
||||
Provider openid.Provider
|
||||
Sessions session.Store
|
||||
lock sync.Mutex
|
||||
Httplogger zerolog.Logger
|
||||
}
|
||||
|
||||
@@ -34,28 +31,19 @@ func NewHandler(
|
||||
provider openid.Provider,
|
||||
sessionStore session.Store,
|
||||
) (*Handler, error) {
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: provider.GetClientConfiguration().GetClientID(),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
|
||||
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
|
||||
},
|
||||
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
|
||||
Scopes: provider.GetClientConfiguration().GetScopes(),
|
||||
}
|
||||
client := openid.NewClient(cfg, provider)
|
||||
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)
|
||||
|
||||
cookiePath := config.ParseIngress(cfg.Ingress)
|
||||
cookieOpts := cookie.DefaultOptions().WithPath(cookiePath)
|
||||
|
||||
return &Handler{
|
||||
Client: client,
|
||||
Config: cfg,
|
||||
CookieOptions: cookieOpts,
|
||||
Crypter: crypter,
|
||||
Httplogger: httplogger,
|
||||
lock: sync.Mutex{},
|
||||
Loginstatus: loginstatusClient,
|
||||
OauthConfig: oauthConfig,
|
||||
Provider: provider,
|
||||
Sessions: sessionStore,
|
||||
}, nil
|
||||
|
||||
@@ -112,7 +112,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid.
|
||||
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
|
||||
}
|
||||
|
||||
tokens, err = h.OauthConfig.Exchange(ctx, code, opts...)
|
||||
tokens, err = h.Client.AuthCodeGrant(ctx, code, opts)
|
||||
if err != nil {
|
||||
log.Warnf("callback: exchanging authorization code for token; retrying: %+v", err)
|
||||
return retry.RetryableError(err)
|
||||
@@ -154,7 +154,7 @@ func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) {
|
||||
"claims": tokens.Claims(),
|
||||
}
|
||||
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
logger.Info().Msg("callback: successful login")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,16 +2,21 @@ package router
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/rs/zerolog"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
logentry "github.com/nais/wonderwall/pkg/router/middleware"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
)
|
||||
|
||||
@@ -58,7 +63,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s
|
||||
|
||||
errorPage := ErrorPage{
|
||||
CorrelationID: middleware.GetReqID(r.Context()),
|
||||
RetryURI: request.RetryURI(r, h.Config.Ingress, loginCookie),
|
||||
RetryURI: RetryURI(r, h.Config.Ingress, loginCookie),
|
||||
}
|
||||
err = errorTemplate.Execute(w, errorPage)
|
||||
if err != nil {
|
||||
@@ -97,3 +102,26 @@ func (h *Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error
|
||||
func (h *Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
h.respondError(w, r, http.StatusUnauthorized, cause, zerolog.WarnLevel)
|
||||
}
|
||||
|
||||
// RetryURI returns a URI that should retry the desired route that failed.
|
||||
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
|
||||
// are related to the authentication flow, we default to redirecting back to the handled
|
||||
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
|
||||
func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
|
||||
retryURI := r.URL.Path
|
||||
prefix := config.ParseIngress(ingress)
|
||||
|
||||
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
|
||||
return prefix + retryURI
|
||||
}
|
||||
|
||||
redirect := request.CanonicalRedirectURL(r, ingress)
|
||||
|
||||
if loginCookie != nil && len(loginCookie.Referer) > 0 {
|
||||
redirect = loginCookie.Referer
|
||||
}
|
||||
|
||||
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
|
||||
retryURI = retryURI + fmt.Sprintf("?%s=%s", request.RedirectURLParameter, redirect)
|
||||
return retryURI
|
||||
}
|
||||
|
||||
172
pkg/router/handler_error_test.go
Normal file
172
pkg/router/handler_error_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
)
|
||||
|
||||
func TestRetryURI(t *testing.T) {
|
||||
httpRequest := func(url string, referer ...string) *http.Request {
|
||||
req, _ := http.NewRequest(http.MethodGet, url, nil)
|
||||
if len(referer) > 0 {
|
||||
req.Header.Add("Referer", referer[0])
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
ingress string
|
||||
loginCookie *openid.LoginCookie
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "login path",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "callback path",
|
||||
request: httpRequest("/oauth2/callback"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "logout path",
|
||||
request: httpRequest("/oauth2/logout"),
|
||||
want: "/oauth2/logout",
|
||||
},
|
||||
{
|
||||
name: "front-channel logout path",
|
||||
request: httpRequest("/oauth2/logout/frontchannel"),
|
||||
want: "/oauth2/logout/frontchannel",
|
||||
},
|
||||
{
|
||||
name: "login with non-default ingress",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/domene",
|
||||
},
|
||||
{
|
||||
name: "logout with non-default ingress",
|
||||
request: httpRequest("/oauth2/logout"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/logout",
|
||||
},
|
||||
{
|
||||
name: "login with referer",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with root referer",
|
||||
request: httpRequest("/oauth2/login", "/"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with root referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login", "/"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/"},
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with empty cookie referer",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: ""},
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/api/headers"},
|
||||
want: "/oauth2/login?redirect=/api/headers",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/domene/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me"),
|
||||
want: "/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set and query parameters",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"),
|
||||
want: "/oauth2/login?redirect=/api/me?a=b&c=d",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set on non-default ingress",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/other", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/other",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to relative root takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer takes precedence over redirect parameter",
|
||||
request: httpRequest("/oauth2/login?redirect=/other"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
|
||||
want: "/oauth2/login?redirect=/domene/api/me",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if len(test.ingress) == 0 {
|
||||
test.ingress = "/"
|
||||
}
|
||||
|
||||
retryURI := router.RetryURI(test.request, test.ingress, test.loginCookie)
|
||||
assert.Equal(t, test.want, retryURI)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
logentry "github.com/nais/wonderwall/pkg/router/middleware"
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,44 +17,30 @@ const (
|
||||
)
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := openid.GenerateLoginParameters()
|
||||
login, err := h.Client.Login(r)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
loginURL, err := h.LoginURL(r, params)
|
||||
if err != nil {
|
||||
cause := fmt.Errorf("login: creating login URL: %w", err)
|
||||
|
||||
if errors.Is(err, InvalidSecurityLevelError) || errors.Is(err, InvalidLocaleError) {
|
||||
h.BadRequest(w, r, cause)
|
||||
if errors.Is(err, openid.InvalidSecurityLevelError) || errors.Is(err, openid.InvalidLocaleError) {
|
||||
h.BadRequest(w, r, err)
|
||||
} else {
|
||||
h.InternalError(w, r, cause)
|
||||
h.InternalError(w, r, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
redirect := request.CanonicalRedirectURL(r, h.Config.Ingress)
|
||||
err = h.setLoginCookies(w, &openid.LoginCookie{
|
||||
State: params.State,
|
||||
Nonce: params.Nonce,
|
||||
CodeVerifier: params.CodeVerifier,
|
||||
Referer: redirect,
|
||||
})
|
||||
err = h.setLoginCookies(w, login.Cookie())
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"redirect_to": redirect,
|
||||
"redirect_after_login": login.CanonicalRedirect(),
|
||||
}
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
logger.Info().Msg("login: redirecting to identity provider")
|
||||
|
||||
http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
|
||||
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) getLoginCookie(r *http.Request) (*openid.LoginCookie, error) {
|
||||
|
||||
@@ -17,12 +17,6 @@ import (
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
var idToken string
|
||||
|
||||
sessionData, err := h.getSessionFromCookie(w, r)
|
||||
@@ -37,7 +31,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
fields := map[string]interface{}{
|
||||
"claims": sessionData.Claims,
|
||||
}
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
logger.Info().Msg("logout: successful local logout")
|
||||
}
|
||||
|
||||
@@ -47,6 +41,12 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
h.Loginstatus.ClearCookie(w, h.CookieOptions)
|
||||
}
|
||||
|
||||
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
logoutCookie, err := h.logoutCookie()
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err))
|
||||
@@ -72,7 +72,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
fields := map[string]interface{}{
|
||||
"redirect_to": logoutCookie.RedirectTo,
|
||||
}
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger := logentry.LogEntryWithFields(r.Context(), fields)
|
||||
logger.Info().Msg("logout: redirecting to identity provider")
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
actualState := params.Get("state")
|
||||
|
||||
if expectedState != actualState {
|
||||
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState)
|
||||
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState)
|
||||
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
request2 "github.com/nais/wonderwall/pkg/router/request"
|
||||
)
|
||||
|
||||
var (
|
||||
InvalidSecurityLevelError = errors.New("InvalidSecurityLevel")
|
||||
InvalidLocaleError = errors.New("InvalidLocale")
|
||||
)
|
||||
|
||||
func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (string, error) {
|
||||
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().AuthorizationEndpoint)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
v := u.Query()
|
||||
v.Add("response_type", "code")
|
||||
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
|
||||
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI())
|
||||
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
|
||||
v.Add("state", params.State)
|
||||
v.Add("nonce", params.Nonce)
|
||||
v.Add("response_mode", "query")
|
||||
v.Add("code_challenge", params.CodeChallenge)
|
||||
v.Add("code_challenge_method", "S256")
|
||||
|
||||
if h.Config.Loginstatus.NeedsResourceIndicator() {
|
||||
v.Add("resource", h.Config.Loginstatus.ResourceIndicator)
|
||||
}
|
||||
|
||||
err = h.withSecurityLevel(r, v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
|
||||
}
|
||||
|
||||
err = h.withLocale(r, v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
|
||||
}
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
|
||||
acrValues := h.Provider.GetClientConfiguration().GetACRValues()
|
||||
if len(acrValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fallback := acrValues
|
||||
supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported
|
||||
|
||||
securityLevel, err := request2.LoginURLParameter(r, request2.SecurityLevelURLParameter, fallback, supported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.Add("acr_values", securityLevel)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) withLocale(r *http.Request, v url.Values) error {
|
||||
uiLocales := h.Provider.GetClientConfiguration().GetUILocales()
|
||||
if len(uiLocales) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fallback := uiLocales
|
||||
supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported
|
||||
|
||||
locale, err := request2.LoginURLParameter(r, request2.LocaleURLParameter, fallback, supported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.Add("ui_locales", locale)
|
||||
return nil
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
)
|
||||
|
||||
func TestLoginURL(t *testing.T) {
|
||||
type loginURLTest struct {
|
||||
url string
|
||||
error error
|
||||
}
|
||||
|
||||
tests := []loginURLTest{
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=Level4",
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login",
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=NoLevel",
|
||||
error: router.InvalidSecurityLevelError,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?locale=nb",
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb",
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
url: "http://localhost:1234/oauth2/login?locale=es",
|
||||
error: router.InvalidLocaleError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.url, func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params, err := openid.GenerateLoginParameters()
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider := mock.NewTestProvider()
|
||||
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
|
||||
handler := newHandler(provider)
|
||||
result, err := handler.LoginURL(req, params)
|
||||
|
||||
if test.error != nil {
|
||||
assert.True(t, errors.Is(err, test.error))
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
parsed, err := url.Parse(result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := parsed.Query()
|
||||
assert.Contains(t, query, "response_type")
|
||||
assert.Contains(t, query, "client_id")
|
||||
assert.Contains(t, query, "redirect_uri")
|
||||
assert.Contains(t, query, "scope")
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Contains(t, query, "nonce")
|
||||
assert.Contains(t, query, "response_mode")
|
||||
assert.Contains(t, query, "code_challenge")
|
||||
assert.Contains(t, query, "code_challenge_method")
|
||||
assert.NotContains(t, query, "resource")
|
||||
|
||||
assert.ElementsMatch(t, query["response_type"], []string{"code"})
|
||||
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
|
||||
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
|
||||
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
|
||||
assert.ElementsMatch(t, query["state"], []string{params.State})
|
||||
assert.ElementsMatch(t, query["nonce"], []string{params.Nonce})
|
||||
assert.ElementsMatch(t, query["response_mode"], []string{"query"})
|
||||
assert.ElementsMatch(t, query["code_challenge"], []string{params.CodeChallenge})
|
||||
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginURL_WithResourceIndicator(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params, err := openid.GenerateLoginParameters()
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider := mock.NewTestProvider()
|
||||
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
|
||||
handler := newHandler(provider)
|
||||
handler.Config.Loginstatus.Enabled = true
|
||||
handler.Config.Loginstatus.ResourceIndicator = "https://some-resource"
|
||||
result, err := handler.LoginURL(req, params)
|
||||
|
||||
assert.NotEmpty(t, result)
|
||||
parsed, err := url.Parse(result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := parsed.Query()
|
||||
assert.Contains(t, query, "resource")
|
||||
assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"})
|
||||
}
|
||||
@@ -52,6 +52,10 @@ func LogEntry(ctx context.Context) zerolog.Logger {
|
||||
return httplog.NewLogger("wonderwall")
|
||||
}
|
||||
|
||||
func LogEntryWithFields(ctx context.Context, fields any) zerolog.Logger {
|
||||
return LogEntry(ctx).With().Fields(fields).Logger()
|
||||
}
|
||||
|
||||
type requestLogger struct {
|
||||
Logger zerolog.Logger
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package request
|
||||
|
||||
const (
|
||||
LocaleURLParameter = "locale"
|
||||
RedirectURLParameter = "redirect"
|
||||
SecurityLevelURLParameter = "level"
|
||||
)
|
||||
@@ -1,19 +1,14 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
|
||||
var (
|
||||
InvalidLoginParameterError = errors.New("InvalidLoginParameter")
|
||||
const (
|
||||
RedirectURLParameter = "redirect"
|
||||
)
|
||||
|
||||
// CanonicalRedirectURL constructs a redirect URL that points back to the application.
|
||||
@@ -78,22 +73,6 @@ func parseRedirectParam(r *http.Request) (string, bool) {
|
||||
return redirectParamURLString, true
|
||||
}
|
||||
|
||||
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
|
||||
// The value must exist in the supplied list of supported values.
|
||||
func LoginURLParameter(r *http.Request, parameter, fallback string, supported openid.Supported) (string, error) {
|
||||
value := r.URL.Query().Get(parameter)
|
||||
|
||||
if len(value) == 0 {
|
||||
value = fallback
|
||||
}
|
||||
|
||||
if supported.Contains(value) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
|
||||
}
|
||||
|
||||
func refererPath(r *http.Request) string {
|
||||
if len(r.Referer()) == 0 {
|
||||
return ""
|
||||
@@ -109,26 +88,3 @@ func refererPath(r *http.Request) string {
|
||||
referer.Host = ""
|
||||
return referer.String()
|
||||
}
|
||||
|
||||
// RetryURI returns a URI that should retry the desired route that failed.
|
||||
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
|
||||
// are related to the authentication flow, we default to redirecting back to the handled
|
||||
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
|
||||
func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
|
||||
retryURI := r.URL.Path
|
||||
prefix := config.ParseIngress(ingress)
|
||||
|
||||
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
|
||||
return prefix + retryURI
|
||||
}
|
||||
|
||||
redirect := CanonicalRedirectURL(r, ingress)
|
||||
|
||||
if loginCookie != nil && len(loginCookie.Referer) > 0 {
|
||||
redirect = loginCookie.Referer
|
||||
}
|
||||
|
||||
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
|
||||
retryURI = retryURI + fmt.Sprintf("?%s=%s", RedirectURLParameter, redirect)
|
||||
return retryURI
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
)
|
||||
|
||||
@@ -138,240 +137,3 @@ func TestCanonicalRedirectURL(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginURLParameter(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
parameter string
|
||||
fallback string
|
||||
supported openid.Supported
|
||||
url string
|
||||
expectErr error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no URL parameter should use fallback value",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
expected: "valid",
|
||||
},
|
||||
{
|
||||
name: "non-matching URL parameter should be ignored",
|
||||
url: "http://localhost:8080/oauth2/login?other_param=value2",
|
||||
expected: "valid",
|
||||
},
|
||||
{
|
||||
name: "matching URL parameter should take precedence",
|
||||
url: "http://localhost:8080/oauth2/login?param=valid2",
|
||||
expected: "valid2",
|
||||
},
|
||||
{
|
||||
name: "invalid URL parameter value should return error",
|
||||
url: "http://localhost:8080/oauth2/login?param=invalid",
|
||||
expectErr: request.InvalidLoginParameterError,
|
||||
},
|
||||
{
|
||||
name: "invalid fallback value should return error",
|
||||
fallback: "invalid",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
expectErr: request.InvalidLoginParameterError,
|
||||
},
|
||||
{
|
||||
name: "no supported values should return error",
|
||||
url: "http://localhost:8080/oauth2/login",
|
||||
supported: openid.Supported{""},
|
||||
expectErr: request.InvalidLoginParameterError,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// default test values
|
||||
parameter := "param"
|
||||
fallback := "valid"
|
||||
supported := openid.Supported{"valid", "valid2"}
|
||||
|
||||
if len(test.parameter) > 0 {
|
||||
parameter = test.parameter
|
||||
}
|
||||
|
||||
if len(test.fallback) > 0 {
|
||||
fallback = test.fallback
|
||||
}
|
||||
|
||||
if len(test.supported) > 0 {
|
||||
supported = test.supported
|
||||
}
|
||||
|
||||
val, err := request.LoginURLParameter(r, parameter, fallback, supported)
|
||||
|
||||
if test.expectErr == nil {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expected, val)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryURI(t *testing.T) {
|
||||
httpRequest := func(url string, referer ...string) *http.Request {
|
||||
req, _ := http.NewRequest(http.MethodGet, url, nil)
|
||||
if len(referer) > 0 {
|
||||
req.Header.Add("Referer", referer[0])
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
ingress string
|
||||
loginCookie *openid.LoginCookie
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "login path",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "callback path",
|
||||
request: httpRequest("/oauth2/callback"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "logout path",
|
||||
request: httpRequest("/oauth2/logout"),
|
||||
want: "/oauth2/logout",
|
||||
},
|
||||
{
|
||||
name: "front-channel logout path",
|
||||
request: httpRequest("/oauth2/logout/frontchannel"),
|
||||
want: "/oauth2/logout/frontchannel",
|
||||
},
|
||||
{
|
||||
name: "login with non-default ingress",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/domene",
|
||||
},
|
||||
{
|
||||
name: "logout with non-default ingress",
|
||||
request: httpRequest("/oauth2/logout"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/logout",
|
||||
},
|
||||
{
|
||||
name: "login with referer",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with root referer",
|
||||
request: httpRequest("/oauth2/login", "/"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with root referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login", "/"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/"},
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with empty cookie referer",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: ""},
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login", "/api/me"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/api/headers"},
|
||||
want: "/oauth2/login?redirect=/api/headers",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer on non-default ingress",
|
||||
request: httpRequest("/oauth2/login"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/domene/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me"),
|
||||
want: "/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set and query parameters",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"),
|
||||
want: "/oauth2/login?redirect=/api/me?a=b&c=d",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set on non-default ingress",
|
||||
request: httpRequest("/oauth2/login?redirect=/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/api/me",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/other", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/other",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to relative root takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
|
||||
want: "/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header",
|
||||
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
|
||||
ingress: "https://test.nav.no/domene",
|
||||
want: "/domene/oauth2/login?redirect=/",
|
||||
},
|
||||
{
|
||||
name: "login with cookie referer takes precedence over redirect parameter",
|
||||
request: httpRequest("/oauth2/login?redirect=/other"),
|
||||
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
|
||||
want: "/oauth2/login?redirect=/domene/api/me",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if len(test.ingress) == 0 {
|
||||
test.ingress = "/"
|
||||
}
|
||||
|
||||
retryURI := request.RetryURI(test.request, test.ingress, test.loginCookie)
|
||||
assert.Equal(t, test.want, retryURI)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,45 +8,17 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
var cfg = config.Config{
|
||||
EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES
|
||||
Ingress: "/",
|
||||
OpenID: config.OpenID{
|
||||
Provider: "test",
|
||||
},
|
||||
SessionMaxLifetime: time.Hour,
|
||||
}
|
||||
|
||||
func newHandler(provider openid.Provider) *router.Handler {
|
||||
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
|
||||
sessionStore := session.NewMemory()
|
||||
|
||||
h, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, provider, sessionStore)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
h.CookieOptions = h.CookieOptions.WithSecure(false)
|
||||
return h
|
||||
}
|
||||
|
||||
func TestHandler_Login(t *testing.T) {
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
h := newHandler(idp)
|
||||
h := mock.NewHandler(idp)
|
||||
r := router.New(h)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
@@ -103,13 +75,14 @@ func TestHandler_Login(t *testing.T) {
|
||||
func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
|
||||
h := newHandler(idp)
|
||||
h := mock.NewHandler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback"
|
||||
h.Client = mock.NewClient(idp)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -239,12 +212,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
_, idp := mock.IdentityProviderServer()
|
||||
idp.WithFrontChannelLogoutSupport()
|
||||
|
||||
h := newHandler(idp)
|
||||
h := mock.NewHandler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
h.Client = mock.NewClient(idp)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -313,12 +287,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
func TestHandler_SessionStateRequired(t *testing.T) {
|
||||
idpServer, idp := mock.IdentityProviderServer()
|
||||
idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession")
|
||||
h := newHandler(idp)
|
||||
h := mock.NewHandler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
h.Client = mock.NewClient(idp)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
p := mock.NewTestProvider()
|
||||
h := newHandler(p)
|
||||
h := mock.NewHandler(p)
|
||||
tokens := makeTokens(p)
|
||||
|
||||
t.Run("request without fallback session cookies", func(t *testing.T) {
|
||||
@@ -47,7 +47,7 @@ func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
|
||||
func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
provider := mock.NewTestProvider()
|
||||
h := newHandler(provider)
|
||||
h := mock.NewHandler(provider)
|
||||
|
||||
// request should set session cookies in response
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -82,7 +82,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
|
||||
func TestHandler_DeleteSessionFallback(t *testing.T) {
|
||||
p := mock.NewTestProvider()
|
||||
h := newHandler(p)
|
||||
h := mock.NewHandler(p)
|
||||
tokens := makeTokens(p)
|
||||
|
||||
t.Run("expire cookies if they are set", func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user