refactor(openid/logout): simplify logout logic

As we already clear any local sessions before redirecting to the
Identity Provider, and the callback always redirects to a pre-configured URL,
there isn't really any need to maintain and verify state in the logout
callback.

In other words, the logout callback handler is simply a redirect handler.
This commit is contained in:
Trong Huu Nguyen
2022-07-12 15:09:40 +02:00
parent c321cff4eb
commit 66cf08e602
13 changed files with 66 additions and 287 deletions

View File

@@ -13,7 +13,6 @@ const (
Session = "io.nais.wonderwall.session"
Login = "io.nais.wonderwall.callback"
LoginLegacy = "io.nais.wonderwall.callback.legacy"
Logout = "io.nais.wonderwall.logout"
)
type Cookie struct {

View File

@@ -318,10 +318,9 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
state := query.Get("state")
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
if state == "" || postLogoutRedirectURI == "" {
if postLogoutRedirectURI == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required parameters"))
return
@@ -333,9 +332,6 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req
w.Write([]byte("couldn't parse post_logout_redirect_uri"))
return
}
v := url.Values{}
v.Set("state", state)
u.RawQuery = v.Encode()
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
}

View File

@@ -22,7 +22,7 @@ type Client interface {
Login(r *http.Request) (Login, error)
LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error)
Logout() (Logout, error)
LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error)
LogoutCallback(r *http.Request) LogoutCallback
LogoutFrontchannel(r *http.Request) LogoutFrontchannel
AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error)
@@ -88,13 +88,8 @@ func (c client) Logout() (Logout, error) {
return logout, nil
}
func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) {
logoutCallback, err := NewLogoutCallback(r, cookie)
if err != nil {
return nil, fmt.Errorf("logout/callback: %w", err)
}
return logoutCallback, nil
func (c client) LogoutCallback(r *http.Request) LogoutCallback {
return NewLogoutCallback(c, r)
}
func (c client) LogoutFrontchannel(r *http.Request) LogoutFrontchannel {

View File

@@ -160,7 +160,6 @@ func newLoginCallback(t *testing.T, url string, cookie *openid.LoginCookie) (moc
cfg := idp.OpenIDConfig
cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI
cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI
cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint
loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie)

View File

@@ -3,34 +3,18 @@ package client
import (
"fmt"
"net/url"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/strings"
)
type Logout interface {
CanonicalRedirect() string
Cookie() *openid.LogoutCookie
SingleLogoutURL(idToken string) string
}
type logout struct {
Client
cookie *openid.LogoutCookie
endSessionEndpoint *url.URL
}
func NewLogout(c Client) (Logout, error) {
state, err := strings.GenerateBase64(32)
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
cookie := &openid.LogoutCookie{
State: state,
RedirectTo: c.config().Client().GetPostLogoutRedirectURI(),
}
endSessionEndpoint, err := url.Parse(c.config().Provider().EndSessionEndpoint)
if err != nil {
return nil, fmt.Errorf("parsing end session endpoint: %w", err)
@@ -38,23 +22,13 @@ func NewLogout(c Client) (Logout, error) {
return &logout{
Client: c,
cookie: cookie,
endSessionEndpoint: endSessionEndpoint,
}, nil
}
func (in logout) CanonicalRedirect() string {
return in.cookie.RedirectTo
}
func (in logout) Cookie() *openid.LogoutCookie {
return in.cookie
}
func (in logout) SingleLogoutURL(idToken string) string {
v := in.endSessionEndpoint.Query()
v.Add("post_logout_redirect_uri", in.config().Client().GetLogoutCallbackURI())
v.Add("state", in.cookie.State)
if len(idToken) > 0 {
v.Add("id_token_hint", idToken)

View File

@@ -1,64 +1,31 @@
package client
import (
"fmt"
"net/http"
"net/url"
"github.com/nais/wonderwall/pkg/openid"
)
type LogoutCallback interface {
ValidateRequest() error
PostLogoutRedirectURI() string
}
type logoutCallback struct {
cookie *openid.LogoutCookie
requestParams url.Values
Client
request *http.Request
}
func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) {
if cookie == nil {
return nil, fmt.Errorf("cookie is nil")
}
func NewLogoutCallback(c Client, r *http.Request) LogoutCallback {
return &logoutCallback{
requestParams: r.URL.Query(),
cookie: cookie,
}, nil
Client: c,
request: r,
}
}
func (in logoutCallback) ValidateRequest() error {
if err := in.emptyRedirectError(); err != nil {
return err
func (in logoutCallback) PostLogoutRedirectURI() string {
redirect := in.config().Client().GetPostLogoutRedirectURI()
if len(redirect) == 0 {
return in.config().Wonderwall().Ingress
}
if err := in.stateMismatchError(); err != nil {
return err
}
return nil
}
func (in logoutCallback) emptyRedirectError() error {
if len(in.cookie.RedirectTo) == 0 {
return fmt.Errorf("empty redirect")
}
return nil
}
func (in logoutCallback) stateMismatchError() error {
expectedState := in.cookie.State
actualState := in.requestParams.Get("state")
if len(actualState) <= 0 {
return fmt.Errorf("missing state parameter in request (possible csrf)")
}
if expectedState != actualState {
return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState)
}
return nil
return redirect
}

View File

@@ -6,76 +6,35 @@ import (
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid/client"
)
func TestLogoutCallback_ValidateRequest(t *testing.T) {
t.Run("nil cookie", func(t *testing.T) {
_, err := newLogoutCallback(t, "http://localhost/oauth2/logout/callback?state=some-state", nil)
assert.Error(t, err)
func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
lc, cfg := newLogoutCallback(t)
cfg.ClientConfig.PostLogoutRedirectURI = "http://some-fancy-logout-page"
uri := lc.PostLogoutRedirectURI()
assert.NotEmpty(t, uri)
assert.Equal(t, "http://some-fancy-logout-page", uri)
})
for _, test := range []struct {
name string
url string
cookie *openid.LogoutCookie
wantErr bool
}{
{
name: "valid request",
url: "http://localhost/oauth2/logout/callback?state=some-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: false,
},
{
name: "empty redirect",
url: "http://localhost/oauth2/logout/callback?state=some-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "",
},
wantErr: true,
},
{
name: "empty state",
url: "http://localhost/oauth2/logout/callback",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: true,
},
{
name: "state mismatch",
url: "http://localhost/oauth2/logout/callback?state=some-other-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: true,
},
} {
t.Run(test.name, func(t *testing.T) {
lc, err := newLogoutCallback(t, test.url, test.cookie)
assert.NoError(t, err)
t.Run("empty preconfigured post-logout redirect uri", func(t *testing.T) {
lc, cfg := newLogoutCallback(t)
cfg.ClientConfig.PostLogoutRedirectURI = ""
cfg.WonderwallConfig.Ingress = "http://wonderwall"
err = lc.ValidateRequest()
if test.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
uri := lc.PostLogoutRedirectURI()
assert.NotEmpty(t, uri)
assert.Equal(t, "http://wonderwall", uri)
})
}
func newLogoutCallback(t *testing.T, url string, cookie *openid.LogoutCookie) (client.LogoutCallback, error) {
req, err := http.NewRequest("GET", url, nil)
func newLogoutCallback(t *testing.T) (client.LogoutCallback, mock.Configuration) {
req, err := http.NewRequest("GET", "http://wonderwall/oauth2/logout/callback", nil)
assert.NoError(t, err)
return newTestClient().LogoutCallback(req, cookie)
cfg := mock.NewTestConfiguration(mock.Config())
return newTestClientWithConfig(cfg).LogoutCallback(req), cfg
}

View File

@@ -16,28 +16,9 @@ const (
EndSessionEndpoint = "http://provider/endsession"
)
func TestLogout_CanonicalRedirect(t *testing.T) {
logout := newLogout(t)
canonicalRedirect := logout.CanonicalRedirect()
assert.Equal(t, PostLogoutRedirectURI, canonicalRedirect)
}
func TestLogout_Cookie(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
assert.NotNil(t, cookie)
assert.NotEmpty(t, cookie.State)
assert.NotEmpty(t, cookie.RedirectTo)
}
func TestLogout_SingleLogoutURL(t *testing.T) {
t.Run("with id_token", func(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
state := cookie.State
idToken := "some-id-token"
raw := logout.SingleLogoutURL(idToken)
@@ -46,43 +27,34 @@ func TestLogout_SingleLogoutURL(t *testing.T) {
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 3)
assert.Contains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
t.Run("without id_token", func(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
state := cookie.State
idToken := ""
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 2)
assert.NotContains(t, query, "id_token_hint")
assert.Contains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
t.Run("without id_token", func(t *testing.T) {
logout := newLogout(t)
idToken := ""
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 1)
assert.NotContains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))

View File

@@ -6,8 +6,3 @@ type LoginCookie struct {
CodeVerifier string `json:"code_verifier"`
Referer string `json:"referer"`
}
type LogoutCookie struct {
State string `json:"state"`
RedirectTo string `json:"redirect_to"`
}

View File

@@ -1,23 +1,16 @@
package router
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"github.com/go-redis/redis/v8"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
const (
LogoutCookieLifetime = 5 * time.Minute
)
// Logout triggers self-initiated for the current user
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
var idToken string
@@ -47,31 +40,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
logout, err := h.Client.Logout()
if err != nil {
h.InternalError(w, r, err)
}
err = h.setLogoutCookie(w, logout.Cookie())
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err))
return
}
fields := map[string]interface{}{
"redirect_to": logout.CanonicalRedirect(),
}
logger := logentry.LogEntryWithFields(r.Context(), fields)
logger := logentry.LogEntry(r.Context())
logger.Info().Msg("logout: redirecting to identity provider")
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
}
func (h *Handler) setLogoutCookie(w http.ResponseWriter, logoutCookie *openid.LogoutCookie) error {
logoutCookieJson, err := json.Marshal(logoutCookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}
opts := h.CookieOptions.WithExpiresIn(LogoutCookieLifetime)
value := string(logoutCookieJson)
return cookie.EncryptAndSet(w, cookie.Logout, value, opts, h.Crypter)
}

View File

@@ -1,55 +1,16 @@
package router
import (
"encoding/json"
"fmt"
"net/http"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
logentry "github.com/nais/wonderwall/pkg/router/middleware"
)
// LogoutCallback handles the callback from the self-initiated logout for the current user
func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
cookie.Clear(w, cookie.Logout, h.CookieOptions)
redirect := h.Client.LogoutCallback(r).PostLogoutRedirectURI()
logger := logentry.LogEntry(r.Context())
logoutCookie, err := h.getLogoutCookie(r)
if err != nil {
logger.Warn().Msgf("logout/callback: getting cookie: %+v", err)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)
return
}
logoutCallback, err := h.Client.LogoutCallback(r, logoutCookie)
if err != nil {
h.InternalError(w, r, err)
return
}
if err := logoutCallback.ValidateRequest(); err != nil {
logger.Warn().Msgf("logout/callback: %+v; falling back to ingress", err)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)
return
}
logger.Info().Msgf("logout/callback: redirecting to %s", logoutCookie.RedirectTo)
http.Redirect(w, r, logoutCookie.RedirectTo, http.StatusTemporaryRedirect)
}
func (h *Handler) getLogoutCookie(r *http.Request) (*openid.LogoutCookie, error) {
logoutCookieJson, err := cookie.GetDecrypted(r, cookie.Logout, h.Crypter)
if err != nil {
return nil, err
}
var logoutCookie openid.LogoutCookie
err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie)
if err != nil {
return nil, fmt.Errorf("unmarshalling: %w", err)
}
return &logoutCookie, nil
logger.Info().Msgf("logout/callback: redirecting to %s", redirect)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

View File

@@ -175,8 +175,7 @@ func isRelevantCookie(name string) bool {
switch name {
case cookie.Session,
cookie.Login,
cookie.LoginLegacy,
cookie.Logout:
cookie.LoginLegacy:
return true
}

View File

@@ -133,10 +133,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
cookies = rpClient.Jar.Cookies(logoutURL)
sessionCookie = getCookieFromJar(cookie.Session, cookies)
logoutCookie := getCookieFromJar(cookie.Logout, cookies)
assert.Nil(t, sessionCookie)
assert.NotNil(t, logoutCookie)
// Get endsession endpoint after local logout
location = resp.Header.Get("location")
@@ -147,12 +145,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
expectedState := endsessionParams["state"]
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.OpenIDConfig.Client().GetLogoutCallbackURI()})
assert.NotEmpty(t, endsessionParams["id_token_hint"])
assert.NotEmpty(t, expectedState)
// Follow redirect to endsession endpoint at identity provider
resp, err = rpClient.Get(endsessionURL.String())
@@ -165,12 +161,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
logoutCallbackURI, err := url.Parse(location)
assert.NoError(t, err)
assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI())
logoutCallbackParams := endsessionURL.Query()
actualState := logoutCallbackParams["state"]
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
assert.NotEmpty(t, actualState)
assert.Equal(t, expectedState, actualState)
// Follow redirect back to logout callback
resp, err = rpClient.Get(logoutCallbackURI.String())
@@ -185,10 +177,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
cookies = rpClient.Jar.Cookies(logoutCallbackURI)
sessionCookie = getCookieFromJar(cookie.Session, cookies)
logoutCookie = getCookieFromJar(cookie.Logout, cookies)
assert.Nil(t, sessionCookie)
assert.Nil(t, logoutCookie)
}
func TestHandler_FrontChannelLogout(t *testing.T) {