diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 8b5b69a..e511552 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -36,7 +36,11 @@ func run() error { defer cancel() crypt := crypto.NewCrypter(key) - sessionStore := session.NewStore(cfg) + sessionStore, err := session.NewStore(cfg) + if err != nil { + return err + } + h, err := handler.NewHandler(ctx, cfg, openidConfig, crypt, sessionStore) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) diff --git a/pkg/handler/handler_default.go b/pkg/handler/handler_default.go index fd513a1..b3deb48 100644 --- a/pkg/handler/handler_default.go +++ b/pkg/handler/handler_default.go @@ -12,7 +12,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { logger := mw.LogEntry(r).WithField("request_path", r.URL.Path) isAuthenticated := false - accessToken, ok := h.accessToken(w, r) + accessToken, ok := h.accessToken(r) if ok { // add authentication if session cookie and token checks out isAuthenticated = true @@ -44,8 +44,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx)) } -func (h *Handler) accessToken(w http.ResponseWriter, r *http.Request) (string, bool) { - sessionData, err := h.getSessionFromCookie(w, r) +func (h *Handler) accessToken(r *http.Request) (string, bool) { + sessionData, err := h.getSessionFromCookie(r) if err != nil || sessionData == nil || len(sessionData.AccessToken) == 0 { return "", false } diff --git a/pkg/handler/handler_frontchannellogout.go b/pkg/handler/handler_frontchannellogout.go index cf3a851..d63f7cd 100644 --- a/pkg/handler/handler_frontchannellogout.go +++ b/pkg/handler/handler_frontchannellogout.go @@ -22,7 +22,6 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { logoutFrontchannel := h.Client.LogoutFrontchannel(r) if logoutFrontchannel.MissingSidParameter() { logger.Debug("front-channel logout: sid parameter not set in request; ignoring") - h.DeleteSessionFallback(w, r) w.WriteHeader(http.StatusAccepted) return } @@ -36,7 +35,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { return } - err = h.destroySession(w, r, sessionID) + err = h.destroySession(r, sessionID) if err != nil { logger.Warnf("front-channel logout: destroying session: %+v", err) w.WriteHeader(http.StatusAccepted) diff --git a/pkg/handler/handler_logout.go b/pkg/handler/handler_logout.go index c42ddf5..ce2a4e7 100644 --- a/pkg/handler/handler_logout.go +++ b/pkg/handler/handler_logout.go @@ -24,11 +24,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { var idToken string - sessionData, err := h.getSessionFromCookie(w, r) + sessionData, err := h.getSessionFromCookie(r) if err == nil && sessionData != nil { idToken = sessionData.IDToken - err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID)) + err = h.destroySession(r, h.localSessionID(sessionData.ExternalSessionID)) if err != nil && !errors.Is(err, session.KeyNotFoundError) { h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) return diff --git a/pkg/handler/session.go b/pkg/handler/session.go index 4664198..5880ff9 100644 --- a/pkg/handler/session.go +++ b/pkg/handler/session.go @@ -10,7 +10,6 @@ import ( "github.com/sethvargo/go-retry" "github.com/nais/wonderwall/pkg/cookie" - logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" retrypkg "github.com/nais/wonderwall/pkg/retry" "github.com/nais/wonderwall/pkg/session" @@ -25,7 +24,7 @@ func (h *Handler) localSessionID(sessionID string) string { return fmt.Sprintf("%s:%s:%s", h.OpenIDConfig.Provider().Name(), h.OpenIDConfig.Client().ClientID(), sessionID) } -func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) { +func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { sessionID, err := cookie.GetDecrypted(r, cookie.Session, h.Crypter) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) @@ -33,7 +32,6 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( sessionData, err := h.getSession(r, sessionID) if err == nil { - h.DeleteSessionFallback(w, r) return sessionData, nil } @@ -41,14 +39,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return nil, fmt.Errorf("session not found in store: %w", err) } - logentry.LogEntry(r).Warnf("get session: store is unavailable: %+v; using cookie fallback", err) - - fallbackSessionData, err := h.GetSessionFallback(w, r) - if err != nil { - return nil, fmt.Errorf("getting fallback session: %w", err) - } - - return fallbackSessionData, nil + return nil, fmt.Errorf("get session: store is unavailable: %+v", err) } func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) { @@ -130,22 +121,14 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * return nil } - if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err == nil { - h.DeleteSessionFallback(w, r) - return nil - } - - logentry.LogEntry(r).Warnf("create session: store is unavailable: %+v; using cookie fallback", err) - - err = h.SetSessionFallback(w, r, sessionData, sessionLifetime) - if err != nil { - return fmt.Errorf("writing session to fallback store: %w", err) + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { + return fmt.Errorf("create session: store is unavailable: %+v", err) } return nil } -func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error { +func (h *Handler) destroySession(r *http.Request, sessionID string) error { retryable := func(ctx context.Context) error { err := h.Sessions.Delete(r.Context(), sessionID) if err == nil { @@ -164,6 +147,5 @@ func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, session return err } - h.DeleteSessionFallback(w, r) return nil } diff --git a/pkg/handler/session_fallback.go b/pkg/handler/session_fallback.go deleted file mode 100644 index dd57ad6..0000000 --- a/pkg/handler/session_fallback.go +++ /dev/null @@ -1,27 +0,0 @@ -package handler - -import ( - "net/http" - "time" - - "github.com/nais/wonderwall/pkg/session" -) - -func (h *Handler) SetSessionFallback(w http.ResponseWriter, r *http.Request, data *session.Data, expiresIn time.Duration) error { - store := h.cookieStore(w, r) - return store.Write(data, expiresIn) -} - -func (h *Handler) GetSessionFallback(w http.ResponseWriter, r *http.Request) (*session.Data, error) { - store := h.cookieStore(w, r) - return store.Read(r.Context()) -} - -func (h *Handler) DeleteSessionFallback(w http.ResponseWriter, r *http.Request) { - store := h.cookieStore(w, r) - store.Delete() -} - -func (h *Handler) cookieStore(w http.ResponseWriter, r *http.Request) session.CookieStore { - return session.NewCookie(w, r, h.Crypter, h.Provider, h.CookieOptions.WithPath(h.Path(r))) -} diff --git a/pkg/handler/session_fallback_test.go b/pkg/handler/session_fallback_test.go deleted file mode 100644 index e2eb4a7..0000000 --- a/pkg/handler/session_fallback_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package handler_test - -import ( - "context" - "encoding/base64" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/lestrrat-go/jwx/v2/jwa" - jwtlib "github.com/lestrrat-go/jwx/v2/jwt" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/handler" - "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/session" -) - -func TestHandler_GetSessionFallback(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - tokens := makeTokens(idp.Provider) - rpHandler := idp.RelyingPartyHandler - - t.Run("request without fallback session cookies", func(t *testing.T) { - r := idp.GetRequest("/") - w := httptest.NewRecorder() - _, err := rpHandler.GetSessionFallback(w, r) - assert.Error(t, err) - }) - - t.Run("request with fallback session cookies", func(t *testing.T) { - r := makeRequestWithFallbackCookies(t, idp, tokens) - w := httptest.NewRecorder() - sessionData, err := rpHandler.GetSessionFallback(w, r) - assert.NoError(t, err) - assert.Equal(t, "sid", sessionData.ExternalSessionID) - assert.Equal(t, tokens.AccessToken, sessionData.AccessToken) - assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken) - assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID) - assert.Empty(t, sessionData.RefreshToken) - }) -} - -func TestHandler_SetSessionFallback(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - tokens := makeTokens(idp.Provider) - rpHandler := idp.RelyingPartyHandler - - // request should set session cookies in response - writer := httptest.NewRecorder() - r := idp.GetRequest("/") - expiresIn := time.Minute - data := session.NewData("sid", tokens, nil) - err := rpHandler.SetSessionFallback(writer, r, data, expiresIn) - assert.NoError(t, err) - - cookies := writer.Result().Cookies() - - for _, test := range []struct { - cookieName string - want string - }{ - { - cookieName: "wonderwall-1", - want: "sid", - }, - { - cookieName: "wonderwall-2", - want: tokens.IDToken.GetSerialized(), - }, - { - cookieName: "wonderwall-3", - want: tokens.AccessToken, - }, - } { - assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies) - } -} - -func TestHandler_DeleteSessionFallback(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpHandler := idp.RelyingPartyHandler - tokens := makeTokens(idp.Provider) - - t.Run("expire cookies if they are set", func(t *testing.T) { - r := makeRequestWithFallbackCookies(t, idp, tokens) - writer := httptest.NewRecorder() - rpHandler.DeleteSessionFallback(writer, r) - cookies := writer.Result().Cookies() - - assert.NotEmpty(t, cookies) - assert.Len(t, cookies, 3) - - assertCookieExpired(t, "wonderwall-1", cookies) - assertCookieExpired(t, "wonderwall-2", cookies) - assertCookieExpired(t, "wonderwall-3", cookies) - }) - - t.Run("skip expiring cookies if they are not set", func(t *testing.T) { - writer := httptest.NewRecorder() - r := idp.GetRequest("/") - rpHandler.DeleteSessionFallback(writer, r) - cookies := writer.Result().Cookies() - - assert.Empty(t, cookies) - }) -} - -func makeRequestWithFallbackCookies(t *testing.T, idp *mock.IdentityProvider, tokens *openid.Tokens) *http.Request { - writer := httptest.NewRecorder() - r := mock.NewGetRequest("/", idp.OpenIDConfig) - - expiresIn := time.Minute - data := session.NewData("sid", tokens, nil) - err := idp.RelyingPartyHandler.SetSessionFallback(writer, r, data, expiresIn) - assert.NoError(t, err) - - cookies := writer.Result().Cookies() - - externalSessionIDCookie := getCookieFromJar("wonderwall-1", cookies) - assert.NotNil(t, externalSessionIDCookie) - idTokenCookie := getCookieFromJar("wonderwall-2", cookies) - assert.NotNil(t, idTokenCookie) - accessTokenCookie := getCookieFromJar("wonderwall-3", cookies) - assert.NotNil(t, accessTokenCookie) - - r.AddCookie(externalSessionIDCookie) - r.AddCookie(idTokenCookie) - r.AddCookie(accessTokenCookie) - - return r -} - -func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie) { - expired := getCookieFromJar(cookieName, cookies) - assert.NotNil(t, expired) - assert.Less(t, expired.MaxAge, 0) - assert.True(t, expired.Expires.Before(time.Now())) - assert.Empty(t, expired.Value) -} - -func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) { - desiredCookie := getCookieFromJar(cookieName, cookies) - assert.NotNil(t, desiredCookie) - - ciphertext, err := base64.StdEncoding.DecodeString(desiredCookie.Value) - assert.NoError(t, err) - - plainbytes, err := h.Crypter.Decrypt(ciphertext) - assert.NoError(t, err) - assert.Equal(t, expectedValue, string(plainbytes)) -} - -func makeTokens(provider *mock.TestProvider) *openid.Tokens { - jwks := *provider.PrivateJwkSet() - jwksPublic, err := provider.GetPublicJwkSet(context.TODO()) - if err != nil { - log.Fatalf("getting public jwk set: %+v", err) - } - - signer, ok := jwks.Key(0) - if !ok { - log.Fatalf("getting signer") - } - - idToken := jwtlib.New() - idToken.Set("jti", "id-token-jti") - - signedIdToken, err := jwtlib.Sign(idToken, jwtlib.WithKey(jwa.RS256, signer)) - if err != nil { - log.Fatalf("signing id_token: %+v", err) - } - - parsedIdToken, err := jwtlib.Parse(signedIdToken, jwtlib.WithKeySet(*jwksPublic)) - if err != nil { - log.Fatalf("parsing signed id_token: %+v", err) - } - - accessToken := "some-access-token" - - return &openid.Tokens{ - IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken), - AccessToken: accessToken, - } -} diff --git a/pkg/session/cookie.go b/pkg/session/cookie.go deleted file mode 100644 index bd5cfb7..0000000 --- a/pkg/session/cookie.go +++ /dev/null @@ -1,152 +0,0 @@ -package session - -import ( - "context" - "errors" - "fmt" - "net/http" - "time" - - "golang.org/x/oauth2" - - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/jwt" - "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/openid/provider" -) - -const ( - ExternalIDCookieName = "wonderwall-1" - IDTokenCookieName = "wonderwall-2" - AccessTokenCookieName = "wonderwall-3" -) - -type CookieStore interface { - Write(data *Data, expiration time.Duration) error - Read(ctx context.Context) (*Data, error) - Delete() -} - -type cookieSessionStore struct { - req *http.Request - rw http.ResponseWriter - crypter crypto.Crypter - provider provider.Provider - cookieOpts cookie.Options -} - -var _ CookieStore = &cookieSessionStore{} - -func NewCookie(rw http.ResponseWriter, req *http.Request, crypter crypto.Crypter, provider provider.Provider, opts cookie.Options) CookieStore { - return &cookieSessionStore{ - req: req, - rw: rw, - crypter: crypter, - provider: provider, - cookieOpts: opts, - } -} - -func (c *cookieSessionStore) Write(data *Data, expiration time.Duration) error { - opts := c.cookieOpts.WithExpiresIn(expiration) - - err := c.setCookie(ExternalIDCookieName, data.ExternalSessionID, opts) - if err != nil { - return fmt.Errorf("setting session id fallback cookie: %w", err) - } - - err = c.setCookie(IDTokenCookieName, data.IDToken, opts) - if err != nil { - return fmt.Errorf("setting session id_token fallback cookie: %w", err) - } - - err = c.setCookie(AccessTokenCookieName, data.AccessToken, opts) - if err != nil { - return fmt.Errorf("setting session access_token fallback cookie: %w", err) - } - - return nil -} - -func (c *cookieSessionStore) Read(ctx context.Context) (*Data, error) { - externalSessionID, err := c.getValue(ExternalIDCookieName) - if err != nil { - return nil, fmt.Errorf("reading session ID from fallback cookie: %w", err) - } - - idToken, err := c.getValue(IDTokenCookieName) - if err != nil { - return nil, fmt.Errorf("reading id_token from fallback cookie: %w", err) - } - - accessToken, err := c.getValue(AccessTokenCookieName) - if err != nil { - return nil, fmt.Errorf("reading access_token from fallback cookie: %w", err) - } - - jwkSet, err := c.provider.GetPublicJwkSet(ctx) - if err != nil { - return nil, fmt.Errorf("callback: getting jwks: %w", err) - } - - // TODO: currently a placeholder fallback value, should fetch from metadata cookie - expiry := time.Now().Add(time.Hour) - - // attempt to get expiry from access_token if it is a JWT - parsedAccessToken, err := jwt.Parse(accessToken, *jwkSet) - if err == nil { - expiry = parsedAccessToken.Expiration() - } - - // TODO: set refresh token and metadata - rawTokens := &oauth2.Token{ - AccessToken: accessToken, - TokenType: "Bearer", - RefreshToken: "", - Expiry: expiry, - } - rawTokens = rawTokens.WithExtra(map[string]interface{}{ - "id_token": idToken, - }) - - tokens, err := openid.NewTokens(rawTokens, *jwkSet) - if err != nil { - // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt - _, _ = c.provider.RefreshPublicJwkSet(ctx) - return nil, fmt.Errorf("parsing tokens: %w", err) - } - - return NewData(externalSessionID, tokens, nil), nil -} - -func (c *cookieSessionStore) Delete() { - for _, name := range c.allCookieNames() { - c.deleteIfNotFound(name) - } -} - -func (c *cookieSessionStore) allCookieNames() []string { - return []string{ - ExternalIDCookieName, - IDTokenCookieName, - AccessTokenCookieName, - } -} - -func (c *cookieSessionStore) deleteIfNotFound(cookieName string) { - _, err := c.req.Cookie(cookieName) - if errors.Is(err, http.ErrNoCookie) { - return - } - - cookie.Clear(c.rw, cookieName, c.cookieOpts) -} - -func (c *cookieSessionStore) setCookie(name, value string, opts cookie.Options) error { - return cookie.EncryptAndSet(c.rw, name, value, opts, c.crypter) -} - -func (c *cookieSessionStore) getValue(name string) (string, error) { - return cookie.GetDecrypted(c.req, name, c.crypter) -} diff --git a/pkg/session/session.go b/pkg/session/session.go index 64d1d60..d9ee092 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -3,6 +3,7 @@ package session import ( "context" "errors" + "fmt" "time" log "github.com/sirupsen/logrus" @@ -20,15 +21,15 @@ type Store interface { Delete(ctx context.Context, keys ...string) error } -func NewStore(cfg *config.Config) Store { +func NewStore(cfg *config.Config) (Store, error) { if len(cfg.Redis.Address) == 0 { log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") - return NewMemory() + return NewMemory(), nil } redisClient, err := cfg.Redis.Client() if err != nil { - log.Fatalf("Failed to configure Redis: %v", err) + return nil, fmt.Errorf("failed to create Redis Client: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) @@ -36,10 +37,10 @@ func NewStore(cfg *config.Config) Store { err = redisClient.Ping(ctx).Err() if err != nil { - log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err) + return nil, fmt.Errorf("failed to connect to configured Redis: %w", err) } else { log.Infof("Using Redis as session backing store") } - return NewRedis(redisClient) + return NewRedis(redisClient), nil }