mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-08 09:27:12 +00:00
feat: add handler for logout callbacks
This commit is contained in:
@@ -45,6 +45,7 @@ Wonderwall exposes and owns these endpoints (which means they will never be prox
|
||||
| `/oauth2/login` | Initiates the OpenID Connect Authorization Code flow |
|
||||
| `/oauth2/callback` | Handles the callback from the identity provider |
|
||||
| `/oauth2/logout` | Initiates local and global/single-logout |
|
||||
| `/oauth2/logout/callback` | Handles the logout callback from the identity provider |
|
||||
| `/oauth2/logout/frontchannel` | Handles global logout request (initiated by identity provider on behalf of another client) |
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
type TestClientConfiguration struct {
|
||||
ClientID string
|
||||
ClientJWK jwk.Key
|
||||
RedirectURI string
|
||||
CallbackURI string
|
||||
LogoutCallbackURI string
|
||||
PostLogoutRedirectURI string
|
||||
Scopes scopes.Scopes
|
||||
ACRValues string
|
||||
@@ -18,8 +19,8 @@ type TestClientConfiguration struct {
|
||||
WellKnownURL string
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetRedirectURI() string {
|
||||
return c.RedirectURI
|
||||
func (c TestClientConfiguration) GetCallbackURI() string {
|
||||
return c.CallbackURI
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetClientID() string {
|
||||
@@ -30,6 +31,10 @@ func (c TestClientConfiguration) GetClientJWK() jwk.Key {
|
||||
return c.ClientJWK
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetLogoutCallbackURI() string {
|
||||
return c.LogoutCallbackURI
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetPostLogoutRedirectURI() string {
|
||||
return c.PostLogoutRedirectURI
|
||||
}
|
||||
@@ -59,7 +64,8 @@ func clientConfiguration() TestClientConfiguration {
|
||||
return TestClientConfiguration{
|
||||
ClientID: "client_id",
|
||||
ClientJWK: key,
|
||||
RedirectURI: "http://localhost/callback",
|
||||
CallbackURI: "http://localhost/callback",
|
||||
LogoutCallbackURI: "http://localhost/logout/callback",
|
||||
WellKnownURL: "",
|
||||
UILocales: "nb",
|
||||
ACRValues: "Level4",
|
||||
|
||||
@@ -223,3 +223,27 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
|
||||
|
||||
if state == "" || postLogoutRedirectURI == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required parameters"))
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse post_logout_redirect_uri"))
|
||||
return
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("state", state)
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
@@ -26,5 +26,6 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router {
|
||||
r.Get("/authorize", ip.Authorize)
|
||||
r.Post("/token", ip.Token)
|
||||
r.Get("/jwks", ip.Jwks)
|
||||
r.Get("/endsession", ip.EndSession)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@ type Configuration interface {
|
||||
GetClientID() string
|
||||
GetClientJWK() jwk.Key
|
||||
GetPostLogoutRedirectURI() string
|
||||
GetRedirectURI() string
|
||||
GetCallbackURI() string
|
||||
GetLogoutCallbackURI() string
|
||||
GetScopes() scopes.Scopes
|
||||
GetACRValues() string
|
||||
GetUILocales() string
|
||||
|
||||
@@ -9,12 +9,13 @@ import (
|
||||
|
||||
type OpenIDConfig struct {
|
||||
config.OpenID
|
||||
clientJwk jwk.Key
|
||||
redirectURI string
|
||||
clientJwk jwk.Key
|
||||
callbackURI string
|
||||
logoutCallbackURI string
|
||||
}
|
||||
|
||||
func (in *OpenIDConfig) GetRedirectURI() string {
|
||||
return in.redirectURI
|
||||
func (in *OpenIDConfig) GetCallbackURI() string {
|
||||
return in.callbackURI
|
||||
}
|
||||
|
||||
func (in *OpenIDConfig) GetClientID() string {
|
||||
@@ -25,6 +26,10 @@ func (in *OpenIDConfig) GetClientJWK() jwk.Key {
|
||||
return in.clientJwk
|
||||
}
|
||||
|
||||
func (in *OpenIDConfig) GetLogoutCallbackURI() string {
|
||||
return in.logoutCallbackURI
|
||||
}
|
||||
|
||||
func (in *OpenIDConfig) GetPostLogoutRedirectURI() string {
|
||||
return in.PostLogoutRedirectURI
|
||||
}
|
||||
@@ -45,10 +50,11 @@ func (in *OpenIDConfig) GetWellKnownURL() string {
|
||||
return in.WellKnownURL
|
||||
}
|
||||
|
||||
func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, redirectURI string) *OpenIDConfig {
|
||||
func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, callbackURI, logoutCallbackURI string) *OpenIDConfig {
|
||||
return &OpenIDConfig{
|
||||
OpenID: cfg.OpenID,
|
||||
clientJwk: clientJwk,
|
||||
redirectURI: redirectURI,
|
||||
OpenID: cfg.OpenID,
|
||||
clientJwk: clientJwk,
|
||||
callbackURI: callbackURI,
|
||||
logoutCallbackURI: logoutCallbackURI,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,3 +6,8 @@ type LoginCookie struct {
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Referer string `json:"referer"`
|
||||
}
|
||||
|
||||
type LogoutCookie struct {
|
||||
State string `json:"state"`
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/openid/clients"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -91,12 +92,17 @@ func NewProvider(ctx context.Context, cfg *config.Config) (Provider, error) {
|
||||
return nil, fmt.Errorf("missing required config %s", config.Ingress)
|
||||
}
|
||||
|
||||
redirectURI, err := RedirectURI(ingress)
|
||||
callbackURI, err := RedirectURI(ingress, paths.Callback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating redirect URI from ingress: %w", err)
|
||||
return nil, fmt.Errorf("creating callback URI from ingress: %w", err)
|
||||
}
|
||||
|
||||
openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, redirectURI)
|
||||
logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err)
|
||||
}
|
||||
|
||||
openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, callbackURI, logoutCallbackURI)
|
||||
var clientConfig clients.Configuration
|
||||
switch cfg.OpenID.Provider {
|
||||
case config.ProviderIDPorten:
|
||||
@@ -161,7 +167,8 @@ func printConfigs(clientCfg clients.Configuration, openIdCfg Configuration) {
|
||||
log.Infof("acr values: '%s'", clientCfg.GetACRValues())
|
||||
log.Infof("client id: '%s'", clientCfg.GetClientID())
|
||||
log.Infof("post-logout redirect uri: '%s'", clientCfg.GetPostLogoutRedirectURI())
|
||||
log.Infof("redirect uri: '%s'", clientCfg.GetRedirectURI())
|
||||
log.Infof("callback uri: '%s'", clientCfg.GetCallbackURI())
|
||||
log.Infof("logout callback uri: '%s'", clientCfg.GetLogoutCallbackURI())
|
||||
log.Infof("scopes: '%s'", clientCfg.GetScopes())
|
||||
log.Infof("ui locales: '%s'", clientCfg.GetUILocales())
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
|
||||
func RedirectURI(ingress string) (string, error) {
|
||||
func RedirectURI(ingress, redirectPath string) (string, error) {
|
||||
if len(ingress) == 0 {
|
||||
return "", fmt.Errorf("ingress cannot be empty")
|
||||
}
|
||||
@@ -18,6 +18,6 @@ func RedirectURI(ingress string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
base.Path = path.Join(base.Path, paths.OAuth2, paths.Callback)
|
||||
base.Path = path.Join(base.Path, paths.OAuth2, redirectPath)
|
||||
return base.String(), nil
|
||||
}
|
||||
|
||||
@@ -7,36 +7,47 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
)
|
||||
|
||||
func TestRedirectURI(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
input string
|
||||
path string
|
||||
want string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
input: "https://nav.no/dagpenger",
|
||||
path: paths.Callback,
|
||||
want: "https://nav.no/dagpenger/oauth2/callback",
|
||||
},
|
||||
{
|
||||
input: "https://nav.no/dagpenger/soknad",
|
||||
path: paths.Callback,
|
||||
want: "https://nav.no/dagpenger/soknad/oauth2/callback",
|
||||
},
|
||||
{
|
||||
input: "https://nav.no",
|
||||
path: paths.Callback,
|
||||
want: "https://nav.no/oauth2/callback",
|
||||
},
|
||||
{
|
||||
input: "https://nav.no/",
|
||||
path: paths.Callback,
|
||||
want: "https://nav.no/oauth2/callback",
|
||||
},
|
||||
{
|
||||
input: "https://nav.no/",
|
||||
path: paths.LogoutCallback,
|
||||
want: "https://nav.no/oauth2/logout/callback",
|
||||
},
|
||||
{
|
||||
input: "",
|
||||
err: fmt.Errorf("ingress cannot be empty"),
|
||||
},
|
||||
} {
|
||||
actual, err := openid.RedirectURI(test.input)
|
||||
actual, err := openid.RedirectURI(test.input, test.path)
|
||||
if test.err != nil {
|
||||
assert.EqualError(t, err, test.err.Error())
|
||||
} else {
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
SessionCookieName = "io.nais.wonderwall.session"
|
||||
LoginCookieName = "io.nais.wonderwall.callback"
|
||||
LoginLegacyCookieName = "io.nais.wonderwall.callback.legacy"
|
||||
LogoutCookieName = "io.nais.wonderwall.logout"
|
||||
)
|
||||
|
||||
func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, opts cookie.Options) error {
|
||||
|
||||
@@ -40,7 +40,7 @@ func NewHandler(
|
||||
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
|
||||
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
|
||||
},
|
||||
RedirectURL: provider.GetClientConfiguration().GetRedirectURI(),
|
||||
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
|
||||
Scopes: provider.GetClientConfiguration().GetScopes(),
|
||||
}
|
||||
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/router/request"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
logentry "github.com/nais/wonderwall/pkg/router/middleware"
|
||||
"github.com/nais/wonderwall/pkg/strings"
|
||||
)
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
@@ -24,12 +28,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil && sessionData != nil {
|
||||
idToken = sessionData.IDToken
|
||||
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.WithField("claims", sessionData.Claims).Infof("logout: successful logout")
|
||||
fields := map[string]interface{}{
|
||||
"claims": sessionData.Claims,
|
||||
}
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger.Info().Msg("logout: successful local logout")
|
||||
}
|
||||
|
||||
h.deleteCookie(w, SessionCookieName, h.CookieOptions)
|
||||
@@ -38,18 +46,63 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
h.Loginstatus.ClearCookie(w, h.CookieOptions)
|
||||
}
|
||||
|
||||
v := u.Query()
|
||||
|
||||
postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI())
|
||||
if len(postLogoutURI) > 0 {
|
||||
v.Add("post_logout_redirect_uri", postLogoutURI)
|
||||
logoutCookie, err := h.logoutCookie()
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = h.setLogoutCookie(w, logoutCookie)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
v := u.Query()
|
||||
v.Add("post_logout_redirect_uri", h.Provider.GetClientConfiguration().GetLogoutCallbackURI())
|
||||
v.Add("state", logoutCookie.State)
|
||||
|
||||
if len(idToken) > 0 {
|
||||
v.Add("id_token_hint", idToken)
|
||||
}
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"redirect_to": logoutCookie.RedirectTo,
|
||||
}
|
||||
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
|
||||
logger.Info().Msg("logout: redirecting to identity provider")
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) logoutCookie() (*openid.LogoutCookie, error) {
|
||||
state, err := strings.GenerateBase64(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating state: %w", err)
|
||||
}
|
||||
|
||||
return &openid.LogoutCookie{
|
||||
State: state,
|
||||
RedirectTo: h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) setLogoutCookie(w http.ResponseWriter, logoutCookie *openid.LogoutCookie) error {
|
||||
logoutCookieJson, err := json.Marshal(logoutCookie)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling login cookie: %w", err)
|
||||
}
|
||||
|
||||
opts := h.CookieOptions.
|
||||
WithExpiresIn(LogoutCookieLifetime)
|
||||
value := string(logoutCookieJson)
|
||||
|
||||
err = h.setEncryptedCookie(w, LogoutCookieName, value, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
63
pkg/router/handler_logout_callback.go
Normal file
63
pkg/router/handler_logout_callback.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
logentry "github.com/nais/wonderwall/pkg/router/middleware"
|
||||
)
|
||||
|
||||
const (
|
||||
LogoutCookieLifetime = 5 * time.Minute
|
||||
)
|
||||
|
||||
// LogoutCallback handles the callback from the self-initiated logout for the current user
|
||||
func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
h.deleteCookie(w, LogoutCookieName, h.CookieOptions)
|
||||
|
||||
logger := logentry.LogEntry(r.Context())
|
||||
|
||||
logoutCookie, err := h.getLogoutCookie(r)
|
||||
if err != nil {
|
||||
logger.Warn().Msgf("logout/callback: getting cookie: %+v", err)
|
||||
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
params := r.URL.Query()
|
||||
expectedState := logoutCookie.State
|
||||
actualState := params.Get("state")
|
||||
|
||||
if expectedState != actualState {
|
||||
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState)
|
||||
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
if len(logoutCookie.RedirectTo) == 0 {
|
||||
logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress")
|
||||
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info().Msgf("logout/callback: redirecting to %s", logoutCookie.RedirectTo)
|
||||
http.Redirect(w, r, logoutCookie.RedirectTo, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) getLogoutCookie(r *http.Request) (*openid.LogoutCookie, error) {
|
||||
logoutCookieJson, err := h.getDecryptedCookie(r, LogoutCookieName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var logoutCookie openid.LogoutCookie
|
||||
err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshalling: %w", err)
|
||||
}
|
||||
|
||||
return &logoutCookie, nil
|
||||
}
|
||||
@@ -24,7 +24,7 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (str
|
||||
v := u.Query()
|
||||
v.Add("response_type", "code")
|
||||
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
|
||||
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI())
|
||||
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI())
|
||||
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
|
||||
v.Add("state", params.State)
|
||||
v.Add("nonce", params.Nonce)
|
||||
|
||||
@@ -81,7 +81,7 @@ func TestLoginURL(t *testing.T) {
|
||||
|
||||
assert.ElementsMatch(t, query["response_type"], []string{"code"})
|
||||
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
|
||||
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.RedirectURI})
|
||||
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
|
||||
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
|
||||
assert.ElementsMatch(t, query["state"], []string{params.State})
|
||||
assert.ElementsMatch(t, query["nonce"], []string{params.Nonce})
|
||||
|
||||
@@ -5,5 +5,6 @@ const (
|
||||
Login = "/login"
|
||||
Callback = "/callback"
|
||||
Logout = "/logout"
|
||||
LogoutCallback = "/logout/callback"
|
||||
FrontChannelLogout = "/logout/frontchannel"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package request
|
||||
|
||||
const (
|
||||
LocaleURLParameter = "locale"
|
||||
PostLogoutRedirectURIParameter = "post_logout_redirect_uri"
|
||||
RedirectURLParameter = "redirect"
|
||||
SecurityLevelURLParameter = "level"
|
||||
LocaleURLParameter = "locale"
|
||||
RedirectURLParameter = "redirect"
|
||||
SecurityLevelURLParameter = "level"
|
||||
)
|
||||
|
||||
@@ -93,15 +93,6 @@ func LoginURLParameter(r *http.Request, parameter, fallback string, supported op
|
||||
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
|
||||
}
|
||||
|
||||
func PostLogoutRedirectURI(r *http.Request, fallback string) string {
|
||||
value := r.URL.Query().Get(PostLogoutRedirectURIParameter)
|
||||
|
||||
if len(value) > 0 {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func refererPath(r *http.Request) string {
|
||||
if len(r.Referer()) == 0 {
|
||||
return ""
|
||||
|
||||
@@ -25,6 +25,7 @@ func New(handler *Handler) chi.Router {
|
||||
r.Get(paths.Callback, handler.Callback)
|
||||
r.Get(paths.Logout, handler.Logout)
|
||||
r.Get(paths.FrontChannelLogout, handler.FrontChannelLogout)
|
||||
r.Get(paths.LogoutCallback, handler.LogoutCallback)
|
||||
})
|
||||
r.HandleFunc("/*", handler.Default)
|
||||
return r
|
||||
|
||||
@@ -81,7 +81,7 @@ func TestHandler_Login(t *testing.T) {
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetACRValues(), u.Query().Get("acr_values"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetUILocales(), u.Query().Get("ui_locales"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), u.Query().Get("client_id"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetRedirectURI(), u.Query().Get("redirect_uri"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetCallbackURI(), u.Query().Get("redirect_uri"))
|
||||
assert.NotEmpty(t, u.Query().Get("state"))
|
||||
assert.NotEmpty(t, u.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, u.Query().Get("code_challenge"))
|
||||
@@ -106,8 +106,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback"
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -175,8 +176,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
|
||||
cookies = client.Jar.Cookies(logoutURL)
|
||||
sessionCookie = getCookieFromJar(router.SessionCookieName, cookies)
|
||||
logoutCookie := getCookieFromJar(router.LogoutCookieName, cookies)
|
||||
|
||||
assert.Nil(t, sessionCookie)
|
||||
assert.NotNil(t, logoutCookie)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
location = resp.Header.Get("location")
|
||||
@@ -187,11 +190,48 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
endsessionParams := endsessionURL.Query()
|
||||
|
||||
expectedState := endsessionParams["state"]
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetPostLogoutRedirectURI()})
|
||||
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetLogoutCallbackURI()})
|
||||
assert.NotEmpty(t, endsessionParams["id_token_hint"])
|
||||
assert.NotEmpty(t, expectedState)
|
||||
|
||||
// Follow redirect to endsession endpoint at identity provider
|
||||
resp, err = client.Get(endsessionURL.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get post-logout redirect URI after successful logout at identity provider
|
||||
location = resp.Header.Get("location")
|
||||
logoutCallbackURI, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, logoutCallbackURI.String(), idp.ClientConfiguration.GetLogoutCallbackURI())
|
||||
logoutCallbackParams := endsessionURL.Query()
|
||||
actualState := logoutCallbackParams["state"]
|
||||
|
||||
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
|
||||
assert.NotEmpty(t, actualState)
|
||||
assert.Equal(t, expectedState, actualState)
|
||||
|
||||
// Follow redirect back to logout callback
|
||||
resp, err = client.Get(logoutCallbackURI.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// Get post-logout redirect URI after redirect back to logout callback
|
||||
location = resp.Header.Get("location")
|
||||
postLogoutRedirectURI, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, idp.ClientConfiguration.GetPostLogoutRedirectURI(), postLogoutRedirectURI.String())
|
||||
|
||||
cookies = client.Jar.Cookies(logoutCallbackURI)
|
||||
sessionCookie = getCookieFromJar(router.SessionCookieName, cookies)
|
||||
logoutCookie = getCookieFromJar(router.LogoutCookieName, cookies)
|
||||
|
||||
assert.Nil(t, sessionCookie)
|
||||
assert.Nil(t, logoutCookie)
|
||||
}
|
||||
|
||||
func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
@@ -202,7 +242,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
@@ -276,7 +316,7 @@ func TestHandler_SessionStateRequired(t *testing.T) {
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
|
||||
Reference in New Issue
Block a user