feat: add handler for logout callbacks

This commit is contained in:
Trong Huu Nguyen
2022-05-09 11:50:19 +02:00
parent b3dfa54768
commit 32dd80b5da
21 changed files with 262 additions and 51 deletions

View File

@@ -45,6 +45,7 @@ Wonderwall exposes and owns these endpoints (which means they will never be prox
| `/oauth2/login` | Initiates the OpenID Connect Authorization Code flow |
| `/oauth2/callback` | Handles the callback from the identity provider |
| `/oauth2/logout` | Initiates local and global/single-logout |
| `/oauth2/logout/callback` | Handles the logout callback from the identity provider |
| `/oauth2/logout/frontchannel` | Handles global logout request (initiated by identity provider on behalf of another client) |
## Usage

View File

@@ -10,7 +10,8 @@ import (
type TestClientConfiguration struct {
ClientID string
ClientJWK jwk.Key
RedirectURI string
CallbackURI string
LogoutCallbackURI string
PostLogoutRedirectURI string
Scopes scopes.Scopes
ACRValues string
@@ -18,8 +19,8 @@ type TestClientConfiguration struct {
WellKnownURL string
}
func (c TestClientConfiguration) GetRedirectURI() string {
return c.RedirectURI
func (c TestClientConfiguration) GetCallbackURI() string {
return c.CallbackURI
}
func (c TestClientConfiguration) GetClientID() string {
@@ -30,6 +31,10 @@ func (c TestClientConfiguration) GetClientJWK() jwk.Key {
return c.ClientJWK
}
func (c TestClientConfiguration) GetLogoutCallbackURI() string {
return c.LogoutCallbackURI
}
func (c TestClientConfiguration) GetPostLogoutRedirectURI() string {
return c.PostLogoutRedirectURI
}
@@ -59,7 +64,8 @@ func clientConfiguration() TestClientConfiguration {
return TestClientConfiguration{
ClientID: "client_id",
ClientJWK: key,
RedirectURI: "http://localhost/callback",
CallbackURI: "http://localhost/callback",
LogoutCallbackURI: "http://localhost/logout/callback",
WellKnownURL: "",
UILocales: "nb",
ACRValues: "Level4",

View File

@@ -223,3 +223,27 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(token)
}
func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
state := query.Get("state")
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
if state == "" || postLogoutRedirectURI == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required parameters"))
return
}
u, err := url.Parse(postLogoutRedirectURI)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("couldn't parse post_logout_redirect_uri"))
return
}
v := url.Values{}
v.Set("state", state)
u.RawQuery = v.Encode()
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
}

View File

@@ -26,5 +26,6 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router {
r.Get("/authorize", ip.Authorize)
r.Post("/token", ip.Token)
r.Get("/jwks", ip.Jwks)
r.Get("/endsession", ip.EndSession)
return r
}

View File

@@ -10,7 +10,8 @@ type Configuration interface {
GetClientID() string
GetClientJWK() jwk.Key
GetPostLogoutRedirectURI() string
GetRedirectURI() string
GetCallbackURI() string
GetLogoutCallbackURI() string
GetScopes() scopes.Scopes
GetACRValues() string
GetUILocales() string

View File

@@ -9,12 +9,13 @@ import (
type OpenIDConfig struct {
config.OpenID
clientJwk jwk.Key
redirectURI string
clientJwk jwk.Key
callbackURI string
logoutCallbackURI string
}
func (in *OpenIDConfig) GetRedirectURI() string {
return in.redirectURI
func (in *OpenIDConfig) GetCallbackURI() string {
return in.callbackURI
}
func (in *OpenIDConfig) GetClientID() string {
@@ -25,6 +26,10 @@ func (in *OpenIDConfig) GetClientJWK() jwk.Key {
return in.clientJwk
}
func (in *OpenIDConfig) GetLogoutCallbackURI() string {
return in.logoutCallbackURI
}
func (in *OpenIDConfig) GetPostLogoutRedirectURI() string {
return in.PostLogoutRedirectURI
}
@@ -45,10 +50,11 @@ func (in *OpenIDConfig) GetWellKnownURL() string {
return in.WellKnownURL
}
func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, redirectURI string) *OpenIDConfig {
func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, callbackURI, logoutCallbackURI string) *OpenIDConfig {
return &OpenIDConfig{
OpenID: cfg.OpenID,
clientJwk: clientJwk,
redirectURI: redirectURI,
OpenID: cfg.OpenID,
clientJwk: clientJwk,
callbackURI: callbackURI,
logoutCallbackURI: logoutCallbackURI,
}
}

View File

@@ -6,3 +6,8 @@ type LoginCookie struct {
CodeVerifier string `json:"code_verifier"`
Referer string `json:"referer"`
}
type LogoutCookie struct {
State string `json:"state"`
RedirectTo string `json:"redirect_to"`
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid/clients"
"github.com/nais/wonderwall/pkg/router/paths"
)
const (
@@ -91,12 +92,17 @@ func NewProvider(ctx context.Context, cfg *config.Config) (Provider, error) {
return nil, fmt.Errorf("missing required config %s", config.Ingress)
}
redirectURI, err := RedirectURI(ingress)
callbackURI, err := RedirectURI(ingress, paths.Callback)
if err != nil {
return nil, fmt.Errorf("creating redirect URI from ingress: %w", err)
return nil, fmt.Errorf("creating callback URI from ingress: %w", err)
}
openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, redirectURI)
logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback)
if err != nil {
return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err)
}
openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, callbackURI, logoutCallbackURI)
var clientConfig clients.Configuration
switch cfg.OpenID.Provider {
case config.ProviderIDPorten:
@@ -161,7 +167,8 @@ func printConfigs(clientCfg clients.Configuration, openIdCfg Configuration) {
log.Infof("acr values: '%s'", clientCfg.GetACRValues())
log.Infof("client id: '%s'", clientCfg.GetClientID())
log.Infof("post-logout redirect uri: '%s'", clientCfg.GetPostLogoutRedirectURI())
log.Infof("redirect uri: '%s'", clientCfg.GetRedirectURI())
log.Infof("callback uri: '%s'", clientCfg.GetCallbackURI())
log.Infof("logout callback uri: '%s'", clientCfg.GetLogoutCallbackURI())
log.Infof("scopes: '%s'", clientCfg.GetScopes())
log.Infof("ui locales: '%s'", clientCfg.GetUILocales())

View File

@@ -8,7 +8,7 @@ import (
"github.com/nais/wonderwall/pkg/router/paths"
)
func RedirectURI(ingress string) (string, error) {
func RedirectURI(ingress, redirectPath string) (string, error) {
if len(ingress) == 0 {
return "", fmt.Errorf("ingress cannot be empty")
}
@@ -18,6 +18,6 @@ func RedirectURI(ingress string) (string, error) {
return "", err
}
base.Path = path.Join(base.Path, paths.OAuth2, paths.Callback)
base.Path = path.Join(base.Path, paths.OAuth2, redirectPath)
return base.String(), nil
}

View File

@@ -7,36 +7,47 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router/paths"
)
func TestRedirectURI(t *testing.T) {
for _, test := range []struct {
input string
path string
want string
err error
}{
{
input: "https://nav.no/dagpenger",
path: paths.Callback,
want: "https://nav.no/dagpenger/oauth2/callback",
},
{
input: "https://nav.no/dagpenger/soknad",
path: paths.Callback,
want: "https://nav.no/dagpenger/soknad/oauth2/callback",
},
{
input: "https://nav.no",
path: paths.Callback,
want: "https://nav.no/oauth2/callback",
},
{
input: "https://nav.no/",
path: paths.Callback,
want: "https://nav.no/oauth2/callback",
},
{
input: "https://nav.no/",
path: paths.LogoutCallback,
want: "https://nav.no/oauth2/logout/callback",
},
{
input: "",
err: fmt.Errorf("ingress cannot be empty"),
},
} {
actual, err := openid.RedirectURI(test.input)
actual, err := openid.RedirectURI(test.input, test.path)
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {

View File

@@ -10,6 +10,7 @@ const (
SessionCookieName = "io.nais.wonderwall.session"
LoginCookieName = "io.nais.wonderwall.callback"
LoginLegacyCookieName = "io.nais.wonderwall.callback.legacy"
LogoutCookieName = "io.nais.wonderwall.logout"
)
func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, opts cookie.Options) error {

View File

@@ -40,7 +40,7 @@ func NewHandler(
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
},
RedirectURL: provider.GetClientConfiguration().GetRedirectURI(),
RedirectURL: provider.GetClientConfiguration().GetCallbackURI(),
Scopes: provider.GetClientConfiguration().GetScopes(),
}
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)

View File

@@ -1,13 +1,17 @@
package router
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
"github.com/go-redis/redis/v8"
"github.com/nais/wonderwall/pkg/router/request"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
"github.com/nais/wonderwall/pkg/strings"
)
// Logout triggers self-initiated for the current user
@@ -24,12 +28,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
if err != nil {
if err != nil && !errors.Is(err, redis.Nil) {
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}
log.WithField("claims", sessionData.Claims).Infof("logout: successful logout")
fields := map[string]interface{}{
"claims": sessionData.Claims,
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger.Info().Msg("logout: successful local logout")
}
h.deleteCookie(w, SessionCookieName, h.CookieOptions)
@@ -38,18 +46,63 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
h.Loginstatus.ClearCookie(w, h.CookieOptions)
}
v := u.Query()
postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI())
if len(postLogoutURI) > 0 {
v.Add("post_logout_redirect_uri", postLogoutURI)
logoutCookie, err := h.logoutCookie()
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err))
return
}
err = h.setLogoutCookie(w, logoutCookie)
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err))
return
}
v := u.Query()
v.Add("post_logout_redirect_uri", h.Provider.GetClientConfiguration().GetLogoutCallbackURI())
v.Add("state", logoutCookie.State)
if len(idToken) > 0 {
v.Add("id_token_hint", idToken)
}
u.RawQuery = v.Encode()
fields := map[string]interface{}{
"redirect_to": logoutCookie.RedirectTo,
}
logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger()
logger.Info().Msg("logout: redirecting to identity provider")
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
}
func (h *Handler) logoutCookie() (*openid.LogoutCookie, error) {
state, err := strings.GenerateBase64(32)
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
return &openid.LogoutCookie{
State: state,
RedirectTo: h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI(),
}, nil
}
func (h *Handler) setLogoutCookie(w http.ResponseWriter, logoutCookie *openid.LogoutCookie) error {
logoutCookieJson, err := json.Marshal(logoutCookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}
opts := h.CookieOptions.
WithExpiresIn(LogoutCookieLifetime)
value := string(logoutCookieJson)
err = h.setEncryptedCookie(w, LogoutCookieName, value, opts)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,63 @@
package router
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
const (
LogoutCookieLifetime = 5 * time.Minute
)
// LogoutCallback handles the callback from the self-initiated logout for the current user
func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
h.deleteCookie(w, LogoutCookieName, h.CookieOptions)
logger := logentry.LogEntry(r.Context())
logoutCookie, err := h.getLogoutCookie(r)
if err != nil {
logger.Warn().Msgf("logout/callback: getting cookie: %+v", err)
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
return
}
params := r.URL.Query()
expectedState := logoutCookie.State
actualState := params.Get("state")
if expectedState != actualState {
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState)
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
return
}
if len(logoutCookie.RedirectTo) == 0 {
logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress")
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
return
}
logger.Info().Msgf("logout/callback: redirecting to %s", logoutCookie.RedirectTo)
http.Redirect(w, r, logoutCookie.RedirectTo, http.StatusTemporaryRedirect)
}
func (h *Handler) getLogoutCookie(r *http.Request) (*openid.LogoutCookie, error) {
logoutCookieJson, err := h.getDecryptedCookie(r, LogoutCookieName)
if err != nil {
return nil, err
}
var logoutCookie openid.LogoutCookie
err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie)
if err != nil {
return nil, fmt.Errorf("unmarshalling: %w", err)
}
return &logoutCookie, nil
}

View File

@@ -24,7 +24,7 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (str
v := u.Query()
v.Add("response_type", "code")
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI())
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI())
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
v.Add("state", params.State)
v.Add("nonce", params.Nonce)

View File

@@ -81,7 +81,7 @@ func TestLoginURL(t *testing.T) {
assert.ElementsMatch(t, query["response_type"], []string{"code"})
assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID})
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.RedirectURI})
assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI})
assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()})
assert.ElementsMatch(t, query["state"], []string{params.State})
assert.ElementsMatch(t, query["nonce"], []string{params.Nonce})

View File

@@ -5,5 +5,6 @@ const (
Login = "/login"
Callback = "/callback"
Logout = "/logout"
LogoutCallback = "/logout/callback"
FrontChannelLogout = "/logout/frontchannel"
)

View File

@@ -1,8 +1,7 @@
package request
const (
LocaleURLParameter = "locale"
PostLogoutRedirectURIParameter = "post_logout_redirect_uri"
RedirectURLParameter = "redirect"
SecurityLevelURLParameter = "level"
LocaleURLParameter = "locale"
RedirectURLParameter = "redirect"
SecurityLevelURLParameter = "level"
)

View File

@@ -93,15 +93,6 @@ func LoginURLParameter(r *http.Request, parameter, fallback string, supported op
return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value)
}
func PostLogoutRedirectURI(r *http.Request, fallback string) string {
value := r.URL.Query().Get(PostLogoutRedirectURIParameter)
if len(value) > 0 {
return value
}
return fallback
}
func refererPath(r *http.Request) string {
if len(r.Referer()) == 0 {
return ""

View File

@@ -25,6 +25,7 @@ func New(handler *Handler) chi.Router {
r.Get(paths.Callback, handler.Callback)
r.Get(paths.Logout, handler.Logout)
r.Get(paths.FrontChannelLogout, handler.FrontChannelLogout)
r.Get(paths.LogoutCallback, handler.LogoutCallback)
})
r.HandleFunc("/*", handler.Default)
return r

View File

@@ -81,7 +81,7 @@ func TestHandler_Login(t *testing.T) {
assert.Equal(t, idp.GetClientConfiguration().GetACRValues(), u.Query().Get("acr_values"))
assert.Equal(t, idp.GetClientConfiguration().GetUILocales(), u.Query().Get("ui_locales"))
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), u.Query().Get("client_id"))
assert.Equal(t, idp.GetClientConfiguration().GetRedirectURI(), u.Query().Get("redirect_uri"))
assert.Equal(t, idp.GetClientConfiguration().GetCallbackURI(), u.Query().Get("redirect_uri"))
assert.NotEmpty(t, u.Query().Get("state"))
assert.NotEmpty(t, u.Query().Get("nonce"))
assert.NotEmpty(t, u.Query().Get("code_challenge"))
@@ -106,8 +106,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback"
jar, err := cookiejar.New(nil)
assert.NoError(t, err)
@@ -175,8 +176,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
cookies = client.Jar.Cookies(logoutURL)
sessionCookie = getCookieFromJar(router.SessionCookieName, cookies)
logoutCookie := getCookieFromJar(router.LogoutCookieName, cookies)
assert.Nil(t, sessionCookie)
assert.NotNil(t, logoutCookie)
// Get endsession endpoint after local logout
location = resp.Header.Get("location")
@@ -187,11 +190,48 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
expectedState := endsessionParams["state"]
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetPostLogoutRedirectURI()})
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetLogoutCallbackURI()})
assert.NotEmpty(t, endsessionParams["id_token_hint"])
assert.NotEmpty(t, expectedState)
// Follow redirect to endsession endpoint at identity provider
resp, err = client.Get(endsessionURL.String())
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
defer resp.Body.Close()
// Get post-logout redirect URI after successful logout at identity provider
location = resp.Header.Get("location")
logoutCallbackURI, err := url.Parse(location)
assert.NoError(t, err)
assert.Contains(t, logoutCallbackURI.String(), idp.ClientConfiguration.GetLogoutCallbackURI())
logoutCallbackParams := endsessionURL.Query()
actualState := logoutCallbackParams["state"]
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
assert.NotEmpty(t, actualState)
assert.Equal(t, expectedState, actualState)
// Follow redirect back to logout callback
resp, err = client.Get(logoutCallbackURI.String())
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Get post-logout redirect URI after redirect back to logout callback
location = resp.Header.Get("location")
postLogoutRedirectURI, err := url.Parse(location)
assert.NoError(t, err)
assert.Equal(t, idp.ClientConfiguration.GetPostLogoutRedirectURI(), postLogoutRedirectURI.String())
cookies = client.Jar.Cookies(logoutCallbackURI)
sessionCookie = getCookieFromJar(router.SessionCookieName, cookies)
logoutCookie = getCookieFromJar(router.LogoutCookieName, cookies)
assert.Nil(t, sessionCookie)
assert.Nil(t, logoutCookie)
}
func TestHandler_FrontChannelLogout(t *testing.T) {
@@ -202,7 +242,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
jar, err := cookiejar.New(nil)
@@ -276,7 +316,7 @@ func TestHandler_SessionStateRequired(t *testing.T) {
r := router.New(h)
server := httptest.NewServer(r)
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback"
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
jar, err := cookiejar.New(nil)