refactor(router): begin extraction of openid client

This commit is contained in:
Trong Huu Nguyen
2022-06-21 10:48:54 +02:00
parent d1559f5479
commit 10dddd00bc
23 changed files with 816 additions and 586 deletions

54
pkg/mock/client.go Normal file
View File

@@ -0,0 +1,54 @@
package mock
import (
"time"
"github.com/rs/zerolog"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
)
func Config() *config.Config {
return &config.Config{
EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES
Ingress: "/",
OpenID: config.OpenID{
Provider: "test",
},
SessionMaxLifetime: time.Hour,
}
}
func NewClient(provider openid.Provider) openid.Client {
return openid.NewClient(*Config(), provider)
}
func NewClientWithCfg(cfg *config.Config, provider openid.Provider) openid.Client {
return openid.NewClient(*cfg, provider)
}
func NewHandler(provider openid.Provider) *router.Handler {
cfg := Config()
return NewHandlerWithCfg(cfg, provider)
}
func NewHandlerWithCfg(cfg *config.Config, provider openid.Provider) *router.Handler {
if cfg == nil {
cfg = Config()
}
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
sessionStore := session.NewMemory()
h, err := router.NewHandler(*cfg, crypter, zerolog.Logger{}, provider, sessionStore)
if err != nil {
panic(err)
}
h.CookieOptions = h.CookieOptions.WithSecure(false)
return h
}

95
pkg/openid/client.go Normal file
View File

@@ -0,0 +1,95 @@
package openid
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/config"
)
type Client interface {
Config() config.Config
Provider() Provider
OAuth2Config() *oauth2.Config
Login(r *http.Request) (Login, error)
LoginCallback(r *http.Request) error
Logout(r *http.Request) error
LogoutCallback(r *http.Request) error
AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error)
RefreshGrant(r *http.Request) error
}
type client struct {
cfg config.Config
provider Provider
oauth2Config *oauth2.Config
}
func NewClient(cfg config.Config, provider Provider) Client {
oauth2Config := &oauth2.Config{
ClientID: provider.GetClientConfiguration().GetClientID(),
Endpoint: oauth2.Endpoint{
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
Scopes: provider.GetClientConfiguration().GetScopes(),
}
return &client{
cfg: cfg,
provider: provider,
oauth2Config: oauth2Config,
}
}
func (c client) Config() config.Config {
return c.cfg
}
func (c client) Provider() Provider {
return c.provider
}
func (c client) OAuth2Config() *oauth2.Config {
return c.oauth2Config
}
func (c client) Login(r *http.Request) (Login, error) {
login, err := NewLogin(c, r)
if err != nil {
return nil, fmt.Errorf("login: %w", err)
}
return login, nil
}
func (c client) LoginCallback(r *http.Request) error {
//TODO implement me
panic("implement me")
}
func (c client) Logout(r *http.Request) error {
//TODO implement me
panic("implement me")
}
func (c client) LogoutCallback(r *http.Request) error {
//TODO implement me
panic("implement me")
}
func (c client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) {
return c.oauth2Config.Exchange(ctx, code, opts...)
}
func (c client) RefreshGrant(r *http.Request) error {
//TODO implement me
panic("implement me")
}

View File

@@ -3,19 +3,109 @@ package openid
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/http"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/router/request"
"github.com/nais/wonderwall/pkg/strings"
)
type LoginParameters struct {
const (
LocaleURLParameter = "locale"
SecurityLevelURLParameter = "level"
)
var (
InvalidSecurityLevelError = errors.New("InvalidSecurityLevel")
InvalidLocaleError = errors.New("InvalidLocale")
InvalidLoginParameterError = errors.New("InvalidLoginParameter")
// LoginParameterMapping maps incoming login parameters to OpenID Connect parameters
LoginParameterMapping = map[string]string{
LocaleURLParameter: "ui_locales",
SecurityLevelURLParameter: "acr_values",
}
)
type Login interface {
AuthCodeURL() string
CanonicalRedirect() string
CodeChallenge() string
CodeVerifier() string
Cookie() *LoginCookie
Nonce() string
State() string
}
func NewLogin(c Client, r *http.Request) (Login, error) {
params, err := newLoginParameters(c)
if err != nil {
return nil, fmt.Errorf("generating login parameters: %w", err)
}
url, err := params.authCodeURL(r)
if err != nil {
return nil, fmt.Errorf("generating login url: %w", err)
}
redirect := request.CanonicalRedirectURL(r, c.Config().Ingress)
cookie := params.cookie(redirect)
return login{
authCodeURL: url,
canonicalRedirect: redirect,
cookie: cookie,
params: params,
}, nil
}
type login struct {
authCodeURL string
canonicalRedirect string
cookie *LoginCookie
params *loginParameters
}
func (l login) CodeChallenge() string {
return l.params.CodeChallenge
}
func (l login) CodeVerifier() string {
return l.params.CodeVerifier
}
func (l login) Nonce() string {
return l.params.Nonce
}
func (l login) State() string {
return l.params.State
}
func (l login) AuthCodeURL() string {
return l.authCodeURL
}
func (l login) CanonicalRedirect() string {
return l.canonicalRedirect
}
func (l login) Cookie() *LoginCookie {
return l.cookie
}
type loginParameters struct {
Client
CodeVerifier string
CodeChallenge string
Nonce string
State string
}
func GenerateLoginParameters() (*LoginParameters, error) {
func newLoginParameters(c Client) (*loginParameters, error) {
codeVerifier, err := strings.GenerateBase64(64)
if err != nil {
return nil, fmt.Errorf("creating code verifier: %w", err)
@@ -31,15 +121,100 @@ func GenerateLoginParameters() (*LoginParameters, error) {
return nil, fmt.Errorf("creating state: %w", err)
}
return &LoginParameters{
return &loginParameters{
Client: c,
CodeVerifier: codeVerifier,
CodeChallenge: CodeChallenge(codeVerifier),
CodeChallenge: codeChallenge(codeVerifier),
Nonce: nonce,
State: state,
}, nil
}
func CodeChallenge(codeVerifier string) string {
func (in *loginParameters) authCodeURL(r *http.Request) (string, error) {
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("scope", in.Provider().GetClientConfiguration().GetScopes().String()),
oauth2.SetAuthURLParam("nonce", in.Nonce),
oauth2.SetAuthURLParam("response_mode", "query"),
oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
}
if in.Config().Loginstatus.NeedsResourceIndicator() {
opts = append(opts, oauth2.SetAuthURLParam("resource", in.Config().Loginstatus.ResourceIndicator))
}
opts, err := in.withSecurityLevel(r, opts)
if err != nil {
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
}
opts, err = in.withLocale(r, opts)
if err != nil {
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
}
authCodeUrl := in.OAuth2Config().AuthCodeURL(in.State, opts...)
return authCodeUrl, nil
}
func (in *loginParameters) cookie(redirect string) *LoginCookie {
return &LoginCookie{
State: in.State,
Nonce: in.Nonce,
CodeVerifier: in.CodeVerifier,
Referer: redirect,
}
}
func (in *loginParameters) withLocale(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) {
return withParamMapping(r,
opts,
LocaleURLParameter,
in.Provider().GetClientConfiguration().GetUILocales(),
in.Provider().GetOpenIDConfiguration().UILocalesSupported,
)
}
func (in *loginParameters) withSecurityLevel(r *http.Request, opts []oauth2.AuthCodeOption) ([]oauth2.AuthCodeOption, error) {
return withParamMapping(r,
opts,
SecurityLevelURLParameter,
in.Provider().GetClientConfiguration().GetACRValues(),
in.Provider().GetOpenIDConfiguration().ACRValuesSupported,
)
}
func withParamMapping(r *http.Request, opts []oauth2.AuthCodeOption, param, fallback string, supported Supported) ([]oauth2.AuthCodeOption, error) {
if len(fallback) == 0 {
return opts, nil
}
value, err := LoginURLParameter(r, param, fallback, supported)
if err != nil {
return nil, err
}
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[param], value))
return opts, nil
}
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
// The value must exist in the supplied list of supported values.
func LoginURLParameter(r *http.Request, parameter, fallback string, supported Supported) (string, error) {
value := r.URL.Query().Get(parameter)
if len(value) == 0 {
value = fallback
}
if supported.Contains(value) {
return value, nil
}
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
}
func codeChallenge(codeVerifier string) string {
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
codeVerifierHash := hasher.Sum(nil)

View File

@@ -0,0 +1,21 @@
package openid
type LoginCallback struct {
Client
}
func (in LoginCallback) IdentityProviderError() (bool, error) {
panic("not implemented")
}
func (in LoginCallback) ValidateRequest() error {
panic("not implemented")
}
func (in LoginCallback) RedeemCode() error {
panic("not implemented")
}
func (in LoginCallback) ParseAndValidateTokens() error {
panic("not implemented")
}

205
pkg/openid/login_test.go Normal file
View File

@@ -0,0 +1,205 @@
package openid_test
import (
"errors"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
)
func TestLogin_URL(t *testing.T) {
type loginURLTest struct {
url string
extraParams map[string]string
error error
}
tests := []loginURLTest{
{
url: "http://localhost:1234/oauth2/login?level=Level4",
extraParams: map[string]string{
"acr_values": "Level4",
},
error: nil,
},
{
url: "http://localhost:1234/oauth2/login",
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?level=NoLevel",
error: openid.InvalidSecurityLevelError,
},
{
url: "http://localhost:1234/oauth2/login?locale=nb",
extraParams: map[string]string{
"ui_locales": "nb",
},
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb",
extraParams: map[string]string{
"acr_values": "Level4",
"ui_locales": "nb",
},
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?locale=es",
error: openid.InvalidLocaleError,
},
}
for _, test := range tests {
t.Run(test.url, func(t *testing.T) {
req, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
provider := mock.NewTestProvider()
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
client := mock.NewClient(provider)
result, err := client.Login(req)
if test.error != nil {
assert.True(t, errors.Is(err, test.error))
} else {
assert.NoError(t, err)
parsed, err := url.Parse(result.AuthCodeURL())
assert.NoError(t, err)
query := parsed.Query()
assert.Contains(t, query, "response_type")
assert.Contains(t, query, "client_id")
assert.Contains(t, query, "redirect_uri")
assert.Contains(t, query, "scope")
assert.Contains(t, query, "state")
assert.Contains(t, query, "nonce")
assert.Contains(t, query, "response_mode")
assert.Contains(t, query, "code_challenge")
assert.Contains(t, query, "code_challenge_method")
assert.NotContains(t, query, "resource")
assert.ElementsMatch(t, query["response_type"], []string{"code"})
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
assert.ElementsMatch(t, query["state"], []string{result.State()})
assert.ElementsMatch(t, query["nonce"], []string{result.Nonce()})
assert.ElementsMatch(t, query["response_mode"], []string{"query"})
assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()})
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
if test.extraParams != nil {
for key, value := range test.extraParams {
assert.Contains(t, query, key)
assert.ElementsMatch(t, query[key], []string{value})
}
}
}
})
}
}
func TestLoginURL_WithResourceIndicator(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil)
assert.NoError(t, err)
provider := mock.NewTestProvider()
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
cfg := mock.Config()
cfg.Loginstatus.Enabled = true
cfg.Loginstatus.ResourceIndicator = "https://some-resource"
client := mock.NewClientWithCfg(cfg, provider)
result, err := client.Login(req)
assert.NotEmpty(t, result)
parsed, err := url.Parse(result.AuthCodeURL())
assert.NoError(t, err)
query := parsed.Query()
assert.Contains(t, query, "resource")
assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"})
}
func TestLoginURLParameter(t *testing.T) {
for _, test := range []struct {
name string
parameter string
fallback string
supported openid.Supported
url string
expectErr error
expected string
}{
{
name: "no URL parameter should use fallback value",
url: "http://localhost:8080/oauth2/login",
expected: "valid",
},
{
name: "non-matching URL parameter should be ignored",
url: "http://localhost:8080/oauth2/login?other_param=value2",
expected: "valid",
},
{
name: "matching URL parameter should take precedence",
url: "http://localhost:8080/oauth2/login?param=valid2",
expected: "valid2",
},
{
name: "invalid URL parameter value should return error",
url: "http://localhost:8080/oauth2/login?param=invalid",
expectErr: openid.InvalidLoginParameterError,
},
{
name: "invalid fallback value should return error",
fallback: "invalid",
url: "http://localhost:8080/oauth2/login",
expectErr: openid.InvalidLoginParameterError,
},
{
name: "no supported values should return error",
url: "http://localhost:8080/oauth2/login",
supported: openid.Supported{""},
expectErr: openid.InvalidLoginParameterError,
},
} {
t.Run(test.name, func(t *testing.T) {
r, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
// default test values
parameter := "param"
fallback := "valid"
supported := openid.Supported{"valid", "valid2"}
if len(test.parameter) > 0 {
parameter = test.parameter
}
if len(test.fallback) > 0 {
fallback = test.fallback
}
if len(test.supported) > 0 {
supported = test.supported
}
val, err := openid.LoginURLParameter(r, parameter, fallback, supported)
if test.expectErr == nil {
assert.NoError(t, err)
assert.Equal(t, test.expected, val)
} else {
assert.Error(t, err)
}
})
}
}

13
pkg/openid/logout.go Normal file
View File

@@ -0,0 +1,13 @@
package openid
type Logout struct {
Client
}
func (in Logout) URL() string {
panic("not implemented")
}
func (in Logout) Cookie() LogoutCookie {
panic("not implemented")
}

View File

@@ -0,0 +1,9 @@
package openid
type LogoutCallback struct {
Client
}
func (in LogoutCallback) ValidateRequest() (bool, error) {
panic("not implemented")
}

View File

@@ -2,10 +2,8 @@ package router
import (
"net/http"
"sync"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
@@ -16,14 +14,13 @@ import (
)
type Handler struct {
Client openid.Client
Config config.Config
CookieOptions cookie.Options
Crypter crypto.Crypter
OauthConfig oauth2.Config
Loginstatus loginstatus.Client
Provider openid.Provider
Sessions session.Store
lock sync.Mutex
Httplogger zerolog.Logger
}
@@ -34,28 +31,19 @@ func NewHandler(
provider openid.Provider,
sessionStore session.Store,
) (*Handler, error) {
oauthConfig := oauth2.Config{
ClientID: provider.GetClientConfiguration().GetClientID(),
Endpoint: oauth2.Endpoint{
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
},
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
Scopes: provider.GetClientConfiguration().GetScopes(),
}
client := openid.NewClient(cfg, provider)
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)
cookiePath := config.ParseIngress(cfg.Ingress)
cookieOpts := cookie.DefaultOptions().WithPath(cookiePath)
return &Handler{
Client: client,
Config: cfg,
CookieOptions: cookieOpts,
Crypter: crypter,
Httplogger: httplogger,
lock: sync.Mutex{},
Loginstatus: loginstatusClient,
OauthConfig: oauthConfig,
Provider: provider,
Sessions: sessionStore,
}, nil

View File

@@ -112,7 +112,7 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *openid.
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
}
tokens, err = h.OauthConfig.Exchange(ctx, code, opts...)
tokens, err = h.Client.AuthCodeGrant(ctx, code, opts)
if err != nil {
log.Warnf("callback: exchanging authorization code for token; retrying: %+v", err)
return retry.RetryableError(err)
@@ -154,7 +154,7 @@ func logSuccessfulLogin(r *http.Request, tokens *jwt.Tokens, referer string) {
"claims": tokens.Claims(),
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger.Info().Msg("callback: successful login")
}

View File

@@ -2,16 +2,21 @@ package router
import (
_ "embed"
"fmt"
"html/template"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/go-chi/chi/v5/middleware"
"github.com/rs/zerolog"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
"github.com/nais/wonderwall/pkg/router/request"
)
@@ -58,7 +63,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s
errorPage := ErrorPage{
CorrelationID: middleware.GetReqID(r.Context()),
RetryURI: request.RetryURI(r, h.Config.Ingress, loginCookie),
RetryURI: RetryURI(r, h.Config.Ingress, loginCookie),
}
err = errorTemplate.Execute(w, errorPage)
if err != nil {
@@ -97,3 +102,26 @@ func (h *Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error
func (h *Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
h.respondError(w, r, http.StatusUnauthorized, cause, zerolog.WarnLevel)
}
// RetryURI returns a URI that should retry the desired route that failed.
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
// are related to the authentication flow, we default to redirecting back to the handled
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
retryURI := r.URL.Path
prefix := config.ParseIngress(ingress)
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
return prefix + retryURI
}
redirect := request.CanonicalRedirectURL(r, ingress)
if loginCookie != nil && len(loginCookie.Referer) > 0 {
redirect = loginCookie.Referer
}
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
retryURI = retryURI + fmt.Sprintf("?%s=%s", request.RedirectURLParameter, redirect)
return retryURI
}

View File

@@ -0,0 +1,172 @@
package router_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
)
func TestRetryURI(t *testing.T) {
httpRequest := func(url string, referer ...string) *http.Request {
req, _ := http.NewRequest(http.MethodGet, url, nil)
if len(referer) > 0 {
req.Header.Add("Referer", referer[0])
}
return req
}
for _, test := range []struct {
name string
request *http.Request
ingress string
loginCookie *openid.LoginCookie
want string
}{
{
name: "login path",
request: httpRequest("/oauth2/login"),
want: "/oauth2/login?redirect=/",
},
{
name: "callback path",
request: httpRequest("/oauth2/callback"),
want: "/oauth2/login?redirect=/",
},
{
name: "logout path",
request: httpRequest("/oauth2/logout"),
want: "/oauth2/logout",
},
{
name: "front-channel logout path",
request: httpRequest("/oauth2/logout/frontchannel"),
want: "/oauth2/logout/frontchannel",
},
{
name: "login with non-default ingress",
request: httpRequest("/oauth2/login"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/domene",
},
{
name: "logout with non-default ingress",
request: httpRequest("/oauth2/logout"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/logout",
},
{
name: "login with referer",
request: httpRequest("/oauth2/login", "/api/me"),
want: "/oauth2/login?redirect=/api/me",
},
{
name: "login with referer on non-default ingress",
request: httpRequest("/oauth2/login", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/api/me",
},
{
name: "login with root referer",
request: httpRequest("/oauth2/login", "/"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with root referer on non-default ingress",
request: httpRequest("/oauth2/login", "/"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with cookie referer",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: "/"},
want: "/oauth2/login?redirect=/",
},
{
name: "login with empty cookie referer",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: ""},
want: "/oauth2/login?redirect=/",
},
{
name: "login with cookie referer takes precedence over referer header",
request: httpRequest("/oauth2/login", "/api/me"),
loginCookie: &openid.LoginCookie{Referer: "/api/headers"},
want: "/oauth2/login?redirect=/api/headers",
},
{
name: "login with cookie referer on non-default ingress",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/domene/api/me",
},
{
name: "login with redirect parameter set",
request: httpRequest("/oauth2/login?redirect=/api/me"),
want: "/oauth2/login?redirect=/api/me",
},
{
name: "login with redirect parameter set and query parameters",
request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"),
want: "/oauth2/login?redirect=/api/me?a=b&c=d",
},
{
name: "login with redirect parameter set on non-default ingress",
request: httpRequest("/oauth2/login?redirect=/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/api/me",
},
{
name: "login with redirect parameter set takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/other", "/api/me"),
want: "/oauth2/login?redirect=/other",
},
{
name: "login with redirect parameter set to relative root takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with cookie referer takes precedence over redirect parameter",
request: httpRequest("/oauth2/login?redirect=/other"),
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
want: "/oauth2/login?redirect=/domene/api/me",
},
} {
t.Run(test.name, func(t *testing.T) {
if len(test.ingress) == 0 {
test.ingress = "/"
}
retryURI := router.RetryURI(test.request, test.ingress, test.loginCookie)
assert.Equal(t, test.want, retryURI)
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
"github.com/nais/wonderwall/pkg/router/request"
)
const (
@@ -18,44 +17,30 @@ const (
)
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
params, err := openid.GenerateLoginParameters()
login, err := h.Client.Login(r)
if err != nil {
h.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err))
return
}
loginURL, err := h.LoginURL(r, params)
if err != nil {
cause := fmt.Errorf("login: creating login URL: %w", err)
if errors.Is(err, InvalidSecurityLevelError) || errors.Is(err, InvalidLocaleError) {
h.BadRequest(w, r, cause)
if errors.Is(err, openid.InvalidSecurityLevelError) || errors.Is(err, openid.InvalidLocaleError) {
h.BadRequest(w, r, err)
} else {
h.InternalError(w, r, cause)
h.InternalError(w, r, err)
}
return
}
redirect := request.CanonicalRedirectURL(r, h.Config.Ingress)
err = h.setLoginCookies(w, &openid.LoginCookie{
State: params.State,
Nonce: params.Nonce,
CodeVerifier: params.CodeVerifier,
Referer: redirect,
})
err = h.setLoginCookies(w, login.Cookie())
if err != nil {
h.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
return
}
fields := map[string]interface{}{
"redirect_to": redirect,
"redirect_after_login": login.CanonicalRedirect(),
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger.Info().Msg("login: redirecting to identity provider")
http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
}
func (h *Handler) getLoginCookie(r *http.Request) (*openid.LoginCookie, error) {

View File

@@ -17,12 +17,6 @@ import (
// Logout triggers self-initiated for the current user
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
return
}
var idToken string
sessionData, err := h.getSessionFromCookie(w, r)
@@ -37,7 +31,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
fields := map[string]interface{}{
"claims": sessionData.Claims,
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger.Info().Msg("logout: successful local logout")
}
@@ -47,6 +41,12 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
h.Loginstatus.ClearCookie(w, h.CookieOptions)
}
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
return
}
logoutCookie, err := h.logoutCookie()
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err))
@@ -72,7 +72,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
fields := map[string]interface{}{
"redirect_to": logoutCookie.RedirectTo,
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger.Info().Msg("logout: redirecting to identity provider")
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)

View File

@@ -33,7 +33,7 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
actualState := params.Get("state")
if expectedState != actualState {
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState)
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState)
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
return
}

View File

@@ -1,88 +0,0 @@
package router
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/nais/wonderwall/pkg/openid"
request2 "github.com/nais/wonderwall/pkg/router/request"
)
var (
InvalidSecurityLevelError = errors.New("InvalidSecurityLevel")
InvalidLocaleError = errors.New("InvalidLocale")
)
func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (string, error) {
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().AuthorizationEndpoint)
if err != nil {
return "", err
}
v := u.Query()
v.Add("response_type", "code")
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI())
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
v.Add("state", params.State)
v.Add("nonce", params.Nonce)
v.Add("response_mode", "query")
v.Add("code_challenge", params.CodeChallenge)
v.Add("code_challenge_method", "S256")
if h.Config.Loginstatus.NeedsResourceIndicator() {
v.Add("resource", h.Config.Loginstatus.ResourceIndicator)
}
err = h.withSecurityLevel(r, v)
if err != nil {
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
}
err = h.withLocale(r, v)
if err != nil {
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
}
u.RawQuery = v.Encode()
return u.String(), nil
}
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
acrValues := h.Provider.GetClientConfiguration().GetACRValues()
if len(acrValues) == 0 {
return nil
}
fallback := acrValues
supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported
securityLevel, err := request2.LoginURLParameter(r, request2.SecurityLevelURLParameter, fallback, supported)
if err != nil {
return err
}
v.Add("acr_values", securityLevel)
return nil
}
func (h *Handler) withLocale(r *http.Request, v url.Values) error {
uiLocales := h.Provider.GetClientConfiguration().GetUILocales()
if len(uiLocales) == 0 {
return nil
}
fallback := uiLocales
supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported
locale, err := request2.LoginURLParameter(r, request2.LocaleURLParameter, fallback, supported)
if err != nil {
return err
}
v.Add("ui_locales", locale)
return nil
}

View File

@@ -1,117 +0,0 @@
package router_test
import (
"errors"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
)
func TestLoginURL(t *testing.T) {
type loginURLTest struct {
url string
error error
}
tests := []loginURLTest{
{
url: "http://localhost:1234/oauth2/login?level=Level4",
error: nil,
},
{
url: "http://localhost:1234/oauth2/login",
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?level=NoLevel",
error: router.InvalidSecurityLevelError,
},
{
url: "http://localhost:1234/oauth2/login?locale=nb",
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?level=Level4&locale=nb",
error: nil,
},
{
url: "http://localhost:1234/oauth2/login?locale=es",
error: router.InvalidLocaleError,
},
}
for _, test := range tests {
t.Run(test.url, func(t *testing.T) {
req, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
params, err := openid.GenerateLoginParameters()
assert.NoError(t, err)
provider := mock.NewTestProvider()
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
handler := newHandler(provider)
result, err := handler.LoginURL(req, params)
if test.error != nil {
assert.True(t, errors.Is(err, test.error))
} else {
assert.NoError(t, err)
parsed, err := url.Parse(result)
assert.NoError(t, err)
query := parsed.Query()
assert.Contains(t, query, "response_type")
assert.Contains(t, query, "client_id")
assert.Contains(t, query, "redirect_uri")
assert.Contains(t, query, "scope")
assert.Contains(t, query, "state")
assert.Contains(t, query, "nonce")
assert.Contains(t, query, "response_mode")
assert.Contains(t, query, "code_challenge")
assert.Contains(t, query, "code_challenge_method")
assert.NotContains(t, query, "resource")
assert.ElementsMatch(t, query["response_type"], []string{"code"})
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
assert.ElementsMatch(t, query["state"], []string{params.State})
assert.ElementsMatch(t, query["nonce"], []string{params.Nonce})
assert.ElementsMatch(t, query["response_mode"], []string{"query"})
assert.ElementsMatch(t, query["code_challenge"], []string{params.CodeChallenge})
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
}
})
}
}
func TestLoginURL_WithResourceIndicator(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost:1234/oauth2/login", nil)
assert.NoError(t, err)
params, err := openid.GenerateLoginParameters()
assert.NoError(t, err)
provider := mock.NewTestProvider()
provider.OpenIDConfiguration.AuthorizationEndpoint = "https://provider/authorize"
handler := newHandler(provider)
handler.Config.Loginstatus.Enabled = true
handler.Config.Loginstatus.ResourceIndicator = "https://some-resource"
result, err := handler.LoginURL(req, params)
assert.NotEmpty(t, result)
parsed, err := url.Parse(result)
assert.NoError(t, err)
query := parsed.Query()
assert.Contains(t, query, "resource")
assert.ElementsMatch(t, query["resource"], []string{"https://some-resource"})
}

View File

@@ -52,6 +52,10 @@ func LogEntry(ctx context.Context) zerolog.Logger {
return httplog.NewLogger("wonderwall")
}
func LogEntryWithFields(ctx context.Context, fields any) zerolog.Logger {
return LogEntry(ctx).With().Fields(fields).Logger()
}
type requestLogger struct {
Logger zerolog.Logger
}

View File

@@ -1,7 +0,0 @@
package request
const (
LocaleURLParameter = "locale"
RedirectURLParameter = "redirect"
SecurityLevelURLParameter = "level"
)

View File

@@ -1,19 +1,14 @@
package request
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router/paths"
)
var (
InvalidLoginParameterError = errors.New("InvalidLoginParameter")
const (
RedirectURLParameter = "redirect"
)
// CanonicalRedirectURL constructs a redirect URL that points back to the application.
@@ -78,22 +73,6 @@ func parseRedirectParam(r *http.Request) (string, bool) {
return redirectParamURLString, true
}
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
// The value must exist in the supplied list of supported values.
func LoginURLParameter(r *http.Request, parameter, fallback string, supported openid.Supported) (string, error) {
value := r.URL.Query().Get(parameter)
if len(value) == 0 {
value = fallback
}
if supported.Contains(value) {
return value, nil
}
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
}
func refererPath(r *http.Request) string {
if len(r.Referer()) == 0 {
return ""
@@ -109,26 +88,3 @@ func refererPath(r *http.Request) string {
referer.Host = ""
return referer.String()
}
// RetryURI returns a URI that should retry the desired route that failed.
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
// are related to the authentication flow, we default to redirecting back to the handled
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
func RetryURI(r *http.Request, ingress string, loginCookie *openid.LoginCookie) string {
retryURI := r.URL.Path
prefix := config.ParseIngress(ingress)
if strings.HasSuffix(retryURI, paths.OAuth2+paths.Logout) || strings.HasSuffix(retryURI, paths.OAuth2+paths.FrontChannelLogout) {
return prefix + retryURI
}
redirect := CanonicalRedirectURL(r, ingress)
if loginCookie != nil && len(loginCookie.Referer) > 0 {
redirect = loginCookie.Referer
}
retryURI = fmt.Sprintf(prefix + paths.OAuth2 + paths.Login)
retryURI = retryURI + fmt.Sprintf("?%s=%s", RedirectURLParameter, redirect)
return retryURI
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router/request"
)
@@ -138,240 +137,3 @@ func TestCanonicalRedirectURL(t *testing.T) {
}
})
}
func TestLoginURLParameter(t *testing.T) {
for _, test := range []struct {
name string
parameter string
fallback string
supported openid.Supported
url string
expectErr error
expected string
}{
{
name: "no URL parameter should use fallback value",
url: "http://localhost:8080/oauth2/login",
expected: "valid",
},
{
name: "non-matching URL parameter should be ignored",
url: "http://localhost:8080/oauth2/login?other_param=value2",
expected: "valid",
},
{
name: "matching URL parameter should take precedence",
url: "http://localhost:8080/oauth2/login?param=valid2",
expected: "valid2",
},
{
name: "invalid URL parameter value should return error",
url: "http://localhost:8080/oauth2/login?param=invalid",
expectErr: request.InvalidLoginParameterError,
},
{
name: "invalid fallback value should return error",
fallback: "invalid",
url: "http://localhost:8080/oauth2/login",
expectErr: request.InvalidLoginParameterError,
},
{
name: "no supported values should return error",
url: "http://localhost:8080/oauth2/login",
supported: openid.Supported{""},
expectErr: request.InvalidLoginParameterError,
},
} {
t.Run(test.name, func(t *testing.T) {
r, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
// default test values
parameter := "param"
fallback := "valid"
supported := openid.Supported{"valid", "valid2"}
if len(test.parameter) > 0 {
parameter = test.parameter
}
if len(test.fallback) > 0 {
fallback = test.fallback
}
if len(test.supported) > 0 {
supported = test.supported
}
val, err := request.LoginURLParameter(r, parameter, fallback, supported)
if test.expectErr == nil {
assert.NoError(t, err)
assert.Equal(t, test.expected, val)
} else {
assert.Error(t, err)
}
})
}
}
func TestRetryURI(t *testing.T) {
httpRequest := func(url string, referer ...string) *http.Request {
req, _ := http.NewRequest(http.MethodGet, url, nil)
if len(referer) > 0 {
req.Header.Add("Referer", referer[0])
}
return req
}
for _, test := range []struct {
name string
request *http.Request
ingress string
loginCookie *openid.LoginCookie
want string
}{
{
name: "login path",
request: httpRequest("/oauth2/login"),
want: "/oauth2/login?redirect=/",
},
{
name: "callback path",
request: httpRequest("/oauth2/callback"),
want: "/oauth2/login?redirect=/",
},
{
name: "logout path",
request: httpRequest("/oauth2/logout"),
want: "/oauth2/logout",
},
{
name: "front-channel logout path",
request: httpRequest("/oauth2/logout/frontchannel"),
want: "/oauth2/logout/frontchannel",
},
{
name: "login with non-default ingress",
request: httpRequest("/oauth2/login"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/domene",
},
{
name: "logout with non-default ingress",
request: httpRequest("/oauth2/logout"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/logout",
},
{
name: "login with referer",
request: httpRequest("/oauth2/login", "/api/me"),
want: "/oauth2/login?redirect=/api/me",
},
{
name: "login with referer on non-default ingress",
request: httpRequest("/oauth2/login", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/api/me",
},
{
name: "login with root referer",
request: httpRequest("/oauth2/login", "/"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with root referer on non-default ingress",
request: httpRequest("/oauth2/login", "/"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with cookie referer",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: "/"},
want: "/oauth2/login?redirect=/",
},
{
name: "login with empty cookie referer",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: ""},
want: "/oauth2/login?redirect=/",
},
{
name: "login with cookie referer takes precedence over referer header",
request: httpRequest("/oauth2/login", "/api/me"),
loginCookie: &openid.LoginCookie{Referer: "/api/headers"},
want: "/oauth2/login?redirect=/api/headers",
},
{
name: "login with cookie referer on non-default ingress",
request: httpRequest("/oauth2/login"),
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/domene/api/me",
},
{
name: "login with redirect parameter set",
request: httpRequest("/oauth2/login?redirect=/api/me"),
want: "/oauth2/login?redirect=/api/me",
},
{
name: "login with redirect parameter set and query parameters",
request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"),
want: "/oauth2/login?redirect=/api/me?a=b&c=d",
},
{
name: "login with redirect parameter set on non-default ingress",
request: httpRequest("/oauth2/login?redirect=/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/api/me",
},
{
name: "login with redirect parameter set takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/other", "/api/me"),
want: "/oauth2/login?redirect=/other",
},
{
name: "login with redirect parameter set to relative root takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=/", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
want: "/oauth2/login?redirect=/",
},
{
name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header",
request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"),
ingress: "https://test.nav.no/domene",
want: "/domene/oauth2/login?redirect=/",
},
{
name: "login with cookie referer takes precedence over redirect parameter",
request: httpRequest("/oauth2/login?redirect=/other"),
loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"},
want: "/oauth2/login?redirect=/domene/api/me",
},
} {
t.Run(test.name, func(t *testing.T) {
if len(test.ingress) == 0 {
test.ingress = "/"
}
retryURI := request.RetryURI(test.request, test.ingress, test.loginCookie)
assert.Equal(t, test.want, retryURI)
})
}
}

View File

@@ -8,45 +8,17 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
)
var cfg = config.Config{
EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES
Ingress: "/",
OpenID: config.OpenID{
Provider: "test",
},
SessionMaxLifetime: time.Hour,
}
func newHandler(provider openid.Provider) *router.Handler {
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
sessionStore := session.NewMemory()
h, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, provider, sessionStore)
if err != nil {
panic(err)
}
h.CookieOptions = h.CookieOptions.WithSecure(false)
return h
}
func TestHandler_Login(t *testing.T) {
idpserver, idp := mock.IdentityProviderServer()
h := newHandler(idp)
h := mock.NewHandler(idp)
r := router.New(h)
jar, err := cookiejar.New(nil)
@@ -103,13 +75,14 @@ func TestHandler_Login(t *testing.T) {
func TestHandler_Callback_and_Logout(t *testing.T) {
idpserver, idp := mock.IdentityProviderServer()
h := newHandler(idp)
h := mock.NewHandler(idp)
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback"
h.Client = mock.NewClient(idp)
jar, err := cookiejar.New(nil)
assert.NoError(t, err)
@@ -239,12 +212,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
_, idp := mock.IdentityProviderServer()
idp.WithFrontChannelLogoutSupport()
h := newHandler(idp)
h := mock.NewHandler(idp)
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
h.Client = mock.NewClient(idp)
jar, err := cookiejar.New(nil)
assert.NoError(t, err)
@@ -313,12 +287,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
func TestHandler_SessionStateRequired(t *testing.T) {
idpServer, idp := mock.IdentityProviderServer()
idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession")
h := newHandler(idp)
h := mock.NewHandler(idp)
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
h.Client = mock.NewClient(idp)
jar, err := cookiejar.New(nil)
assert.NoError(t, err)

View File

@@ -21,7 +21,7 @@ import (
func TestHandler_GetSessionFallback(t *testing.T) {
p := mock.NewTestProvider()
h := newHandler(p)
h := mock.NewHandler(p)
tokens := makeTokens(p)
t.Run("request without fallback session cookies", func(t *testing.T) {
@@ -47,7 +47,7 @@ func TestHandler_GetSessionFallback(t *testing.T) {
func TestHandler_SetSessionFallback(t *testing.T) {
provider := mock.NewTestProvider()
h := newHandler(provider)
h := mock.NewHandler(provider)
// request should set session cookies in response
writer := httptest.NewRecorder()
@@ -82,7 +82,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
func TestHandler_DeleteSessionFallback(t *testing.T) {
p := mock.NewTestProvider()
h := newHandler(p)
h := mock.NewHandler(p)
tokens := makeTokens(p)
t.Run("expire cookies if they are set", func(t *testing.T) {