refactor: remove cookie session fallback store

The implementation is error-prone and difficult to maintain.
We instead just assume that the backing session store is highly
available.
This commit is contained in:
Trong Huu Nguyen
2022-08-16 09:42:48 +02:00
parent 5a50ba7c3a
commit ae8028cc96
9 changed files with 22 additions and 412 deletions

View File

@@ -36,7 +36,11 @@ func run() error {
defer cancel()
crypt := crypto.NewCrypter(key)
sessionStore := session.NewStore(cfg)
sessionStore, err := session.NewStore(cfg)
if err != nil {
return err
}
h, err := handler.NewHandler(ctx, cfg, openidConfig, crypt, sessionStore)
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)

View File

@@ -12,7 +12,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntry(r).WithField("request_path", r.URL.Path)
isAuthenticated := false
accessToken, ok := h.accessToken(w, r)
accessToken, ok := h.accessToken(r)
if ok {
// add authentication if session cookie and token checks out
isAuthenticated = true
@@ -44,8 +44,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
}
func (h *Handler) accessToken(w http.ResponseWriter, r *http.Request) (string, bool) {
sessionData, err := h.getSessionFromCookie(w, r)
func (h *Handler) accessToken(r *http.Request) (string, bool) {
sessionData, err := h.getSessionFromCookie(r)
if err != nil || sessionData == nil || len(sessionData.AccessToken) == 0 {
return "", false
}

View File

@@ -22,7 +22,6 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
logoutFrontchannel := h.Client.LogoutFrontchannel(r)
if logoutFrontchannel.MissingSidParameter() {
logger.Debug("front-channel logout: sid parameter not set in request; ignoring")
h.DeleteSessionFallback(w, r)
w.WriteHeader(http.StatusAccepted)
return
}
@@ -36,7 +35,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
return
}
err = h.destroySession(w, r, sessionID)
err = h.destroySession(r, sessionID)
if err != nil {
logger.Warnf("front-channel logout: destroying session: %+v", err)
w.WriteHeader(http.StatusAccepted)

View File

@@ -24,11 +24,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
var idToken string
sessionData, err := h.getSessionFromCookie(w, r)
sessionData, err := h.getSessionFromCookie(r)
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
err = h.destroySession(r, h.localSessionID(sessionData.ExternalSessionID))
if err != nil && !errors.Is(err, session.KeyNotFoundError) {
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return

View File

@@ -10,7 +10,6 @@ import (
"github.com/sethvargo/go-retry"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
retrypkg "github.com/nais/wonderwall/pkg/retry"
"github.com/nais/wonderwall/pkg/session"
@@ -25,7 +24,7 @@ func (h *Handler) localSessionID(sessionID string) string {
return fmt.Sprintf("%s:%s:%s", h.OpenIDConfig.Provider().Name(), h.OpenIDConfig.Client().ClientID(), sessionID)
}
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
sessionID, err := cookie.GetDecrypted(r, cookie.Session, h.Crypter)
if err != nil {
return nil, fmt.Errorf("no session cookie: %w", err)
@@ -33,7 +32,6 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
sessionData, err := h.getSession(r, sessionID)
if err == nil {
h.DeleteSessionFallback(w, r)
return sessionData, nil
}
@@ -41,14 +39,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return nil, fmt.Errorf("session not found in store: %w", err)
}
logentry.LogEntry(r).Warnf("get session: store is unavailable: %+v; using cookie fallback", err)
fallbackSessionData, err := h.GetSessionFallback(w, r)
if err != nil {
return nil, fmt.Errorf("getting fallback session: %w", err)
}
return fallbackSessionData, nil
return nil, fmt.Errorf("get session: store is unavailable: %+v", err)
}
func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) {
@@ -130,22 +121,14 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
return nil
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err == nil {
h.DeleteSessionFallback(w, r)
return nil
}
logentry.LogEntry(r).Warnf("create session: store is unavailable: %+v; using cookie fallback", err)
err = h.SetSessionFallback(w, r, sessionData, sessionLifetime)
if err != nil {
return fmt.Errorf("writing session to fallback store: %w", err)
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return fmt.Errorf("create session: store is unavailable: %+v", err)
}
return nil
}
func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error {
func (h *Handler) destroySession(r *http.Request, sessionID string) error {
retryable := func(ctx context.Context) error {
err := h.Sessions.Delete(r.Context(), sessionID)
if err == nil {
@@ -164,6 +147,5 @@ func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, session
return err
}
h.DeleteSessionFallback(w, r)
return nil
}

View File

@@ -1,27 +0,0 @@
package handler
import (
"net/http"
"time"
"github.com/nais/wonderwall/pkg/session"
)
func (h *Handler) SetSessionFallback(w http.ResponseWriter, r *http.Request, data *session.Data, expiresIn time.Duration) error {
store := h.cookieStore(w, r)
return store.Write(data, expiresIn)
}
func (h *Handler) GetSessionFallback(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
store := h.cookieStore(w, r)
return store.Read(r.Context())
}
func (h *Handler) DeleteSessionFallback(w http.ResponseWriter, r *http.Request) {
store := h.cookieStore(w, r)
store.Delete()
}
func (h *Handler) cookieStore(w http.ResponseWriter, r *http.Request) session.CookieStore {
return session.NewCookie(w, r, h.Crypter, h.Provider, h.CookieOptions.WithPath(h.Path(r)))
}

View File

@@ -1,197 +0,0 @@
package handler_test
import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
)
func TestHandler_GetSessionFallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
tokens := makeTokens(idp.Provider)
rpHandler := idp.RelyingPartyHandler
t.Run("request without fallback session cookies", func(t *testing.T) {
r := idp.GetRequest("/")
w := httptest.NewRecorder()
_, err := rpHandler.GetSessionFallback(w, r)
assert.Error(t, err)
})
t.Run("request with fallback session cookies", func(t *testing.T) {
r := makeRequestWithFallbackCookies(t, idp, tokens)
w := httptest.NewRecorder()
sessionData, err := rpHandler.GetSessionFallback(w, r)
assert.NoError(t, err)
assert.Equal(t, "sid", sessionData.ExternalSessionID)
assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
assert.Empty(t, sessionData.RefreshToken)
})
}
func TestHandler_SetSessionFallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
tokens := makeTokens(idp.Provider)
rpHandler := idp.RelyingPartyHandler
// request should set session cookies in response
writer := httptest.NewRecorder()
r := idp.GetRequest("/")
expiresIn := time.Minute
data := session.NewData("sid", tokens, nil)
err := rpHandler.SetSessionFallback(writer, r, data, expiresIn)
assert.NoError(t, err)
cookies := writer.Result().Cookies()
for _, test := range []struct {
cookieName string
want string
}{
{
cookieName: "wonderwall-1",
want: "sid",
},
{
cookieName: "wonderwall-2",
want: tokens.IDToken.GetSerialized(),
},
{
cookieName: "wonderwall-3",
want: tokens.AccessToken,
},
} {
assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies)
}
}
func TestHandler_DeleteSessionFallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpHandler := idp.RelyingPartyHandler
tokens := makeTokens(idp.Provider)
t.Run("expire cookies if they are set", func(t *testing.T) {
r := makeRequestWithFallbackCookies(t, idp, tokens)
writer := httptest.NewRecorder()
rpHandler.DeleteSessionFallback(writer, r)
cookies := writer.Result().Cookies()
assert.NotEmpty(t, cookies)
assert.Len(t, cookies, 3)
assertCookieExpired(t, "wonderwall-1", cookies)
assertCookieExpired(t, "wonderwall-2", cookies)
assertCookieExpired(t, "wonderwall-3", cookies)
})
t.Run("skip expiring cookies if they are not set", func(t *testing.T) {
writer := httptest.NewRecorder()
r := idp.GetRequest("/")
rpHandler.DeleteSessionFallback(writer, r)
cookies := writer.Result().Cookies()
assert.Empty(t, cookies)
})
}
func makeRequestWithFallbackCookies(t *testing.T, idp *mock.IdentityProvider, tokens *openid.Tokens) *http.Request {
writer := httptest.NewRecorder()
r := mock.NewGetRequest("/", idp.OpenIDConfig)
expiresIn := time.Minute
data := session.NewData("sid", tokens, nil)
err := idp.RelyingPartyHandler.SetSessionFallback(writer, r, data, expiresIn)
assert.NoError(t, err)
cookies := writer.Result().Cookies()
externalSessionIDCookie := getCookieFromJar("wonderwall-1", cookies)
assert.NotNil(t, externalSessionIDCookie)
idTokenCookie := getCookieFromJar("wonderwall-2", cookies)
assert.NotNil(t, idTokenCookie)
accessTokenCookie := getCookieFromJar("wonderwall-3", cookies)
assert.NotNil(t, accessTokenCookie)
r.AddCookie(externalSessionIDCookie)
r.AddCookie(idTokenCookie)
r.AddCookie(accessTokenCookie)
return r
}
func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie) {
expired := getCookieFromJar(cookieName, cookies)
assert.NotNil(t, expired)
assert.Less(t, expired.MaxAge, 0)
assert.True(t, expired.Expires.Before(time.Now()))
assert.Empty(t, expired.Value)
}
func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
desiredCookie := getCookieFromJar(cookieName, cookies)
assert.NotNil(t, desiredCookie)
ciphertext, err := base64.StdEncoding.DecodeString(desiredCookie.Value)
assert.NoError(t, err)
plainbytes, err := h.Crypter.Decrypt(ciphertext)
assert.NoError(t, err)
assert.Equal(t, expectedValue, string(plainbytes))
}
func makeTokens(provider *mock.TestProvider) *openid.Tokens {
jwks := *provider.PrivateJwkSet()
jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
if err != nil {
log.Fatalf("getting public jwk set: %+v", err)
}
signer, ok := jwks.Key(0)
if !ok {
log.Fatalf("getting signer")
}
idToken := jwtlib.New()
idToken.Set("jti", "id-token-jti")
signedIdToken, err := jwtlib.Sign(idToken, jwtlib.WithKey(jwa.RS256, signer))
if err != nil {
log.Fatalf("signing id_token: %+v", err)
}
parsedIdToken, err := jwtlib.Parse(signedIdToken, jwtlib.WithKeySet(*jwksPublic))
if err != nil {
log.Fatalf("parsing signed id_token: %+v", err)
}
accessToken := "some-access-token"
return &openid.Tokens{
IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken),
AccessToken: accessToken,
}
}

View File

@@ -1,152 +0,0 @@
package session
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/jwt"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/provider"
)
const (
ExternalIDCookieName = "wonderwall-1"
IDTokenCookieName = "wonderwall-2"
AccessTokenCookieName = "wonderwall-3"
)
type CookieStore interface {
Write(data *Data, expiration time.Duration) error
Read(ctx context.Context) (*Data, error)
Delete()
}
type cookieSessionStore struct {
req *http.Request
rw http.ResponseWriter
crypter crypto.Crypter
provider provider.Provider
cookieOpts cookie.Options
}
var _ CookieStore = &cookieSessionStore{}
func NewCookie(rw http.ResponseWriter, req *http.Request, crypter crypto.Crypter, provider provider.Provider, opts cookie.Options) CookieStore {
return &cookieSessionStore{
req: req,
rw: rw,
crypter: crypter,
provider: provider,
cookieOpts: opts,
}
}
func (c *cookieSessionStore) Write(data *Data, expiration time.Duration) error {
opts := c.cookieOpts.WithExpiresIn(expiration)
err := c.setCookie(ExternalIDCookieName, data.ExternalSessionID, opts)
if err != nil {
return fmt.Errorf("setting session id fallback cookie: %w", err)
}
err = c.setCookie(IDTokenCookieName, data.IDToken, opts)
if err != nil {
return fmt.Errorf("setting session id_token fallback cookie: %w", err)
}
err = c.setCookie(AccessTokenCookieName, data.AccessToken, opts)
if err != nil {
return fmt.Errorf("setting session access_token fallback cookie: %w", err)
}
return nil
}
func (c *cookieSessionStore) Read(ctx context.Context) (*Data, error) {
externalSessionID, err := c.getValue(ExternalIDCookieName)
if err != nil {
return nil, fmt.Errorf("reading session ID from fallback cookie: %w", err)
}
idToken, err := c.getValue(IDTokenCookieName)
if err != nil {
return nil, fmt.Errorf("reading id_token from fallback cookie: %w", err)
}
accessToken, err := c.getValue(AccessTokenCookieName)
if err != nil {
return nil, fmt.Errorf("reading access_token from fallback cookie: %w", err)
}
jwkSet, err := c.provider.GetPublicJwkSet(ctx)
if err != nil {
return nil, fmt.Errorf("callback: getting jwks: %w", err)
}
// TODO: currently a placeholder fallback value, should fetch from metadata cookie
expiry := time.Now().Add(time.Hour)
// attempt to get expiry from access_token if it is a JWT
parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet)
if err == nil {
expiry = parsedAccessToken.Expiration()
}
// TODO: set refresh token and metadata
rawTokens := &oauth2.Token{
AccessToken: accessToken,
TokenType: "Bearer",
RefreshToken: "",
Expiry: expiry,
}
rawTokens = rawTokens.WithExtra(map[string]interface{}{
"id_token": idToken,
})
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
if err != nil {
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
_, _ = c.provider.RefreshPublicJwkSet(ctx)
return nil, fmt.Errorf("parsing tokens: %w", err)
}
return NewData(externalSessionID, tokens, nil), nil
}
func (c *cookieSessionStore) Delete() {
for _, name := range c.allCookieNames() {
c.deleteIfNotFound(name)
}
}
func (c *cookieSessionStore) allCookieNames() []string {
return []string{
ExternalIDCookieName,
IDTokenCookieName,
AccessTokenCookieName,
}
}
func (c *cookieSessionStore) deleteIfNotFound(cookieName string) {
_, err := c.req.Cookie(cookieName)
if errors.Is(err, http.ErrNoCookie) {
return
}
cookie.Clear(c.rw, cookieName, c.cookieOpts)
}
func (c *cookieSessionStore) setCookie(name, value string, opts cookie.Options) error {
return cookie.EncryptAndSet(c.rw, name, value, opts, c.crypter)
}
func (c *cookieSessionStore) getValue(name string) (string, error) {
return cookie.GetDecrypted(c.req, name, c.crypter)
}

View File

@@ -3,6 +3,7 @@ package session
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
@@ -20,15 +21,15 @@ type Store interface {
Delete(ctx context.Context, keys ...string) error
}
func NewStore(cfg *config.Config) Store {
func NewStore(cfg *config.Config) (Store, error) {
if len(cfg.Redis.Address) == 0 {
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
return NewMemory()
return NewMemory(), nil
}
redisClient, err := cfg.Redis.Client()
if err != nil {
log.Fatalf("Failed to configure Redis: %v", err)
return nil, fmt.Errorf("failed to create Redis Client: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
@@ -36,10 +37,10 @@ func NewStore(cfg *config.Config) Store {
err = redisClient.Ping(ctx).Err()
if err != nil {
log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err)
return nil, fmt.Errorf("failed to connect to configured Redis: %w", err)
} else {
log.Infof("Using Redis as session backing store")
}
return NewRedis(redisClient)
return NewRedis(redisClient), nil
}