From d0482b349059385dd143d3c76c2673a7f7002441 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Wed, 13 Oct 2021 08:49:53 +0200 Subject: [PATCH] refactor: log session store unavailability, ensure fallback cookies are deleted when no longer needed --- pkg/config/ingress_test.go | 5 ++--- pkg/router/handler_callback.go | 2 +- pkg/router/handler_default.go | 2 +- pkg/router/handler_logout.go | 2 +- pkg/router/session.go | 41 ++++++++++++++++++++-------------- 5 files changed, 29 insertions(+), 23 deletions(-) diff --git a/pkg/config/ingress_test.go b/pkg/config/ingress_test.go index 940ed08..a1057cd 100644 --- a/pkg/config/ingress_test.go +++ b/pkg/config/ingress_test.go @@ -9,9 +9,9 @@ import ( ) func TestParseIngress(t *testing.T) { - for _, test := range []struct{ + for _, test := range []struct { ingress string - want string + want string }{ { ingress: "https://tjenester.nav.no/sykepenger/", @@ -33,7 +33,6 @@ func TestParseIngress(t *testing.T) { ingress: "https://sykepenger-test.nav.no", want: "", }, - } { t.Run(test.ingress, func(t *testing.T) { prefix := config.ParseIngress(test.ingress) diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index cf42a48..5ac5513 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -32,7 +32,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second * 30) + assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err)) return diff --git a/pkg/router/handler_default.go b/pkg/router/handler_default.go index f0fde06..bf556b9 100644 --- a/pkg/router/handler_default.go +++ b/pkg/router/handler_default.go @@ -18,7 +18,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { upstreamRequest.Header.Del("authorization") upstreamRequest.Header.Del("x-pwned-by") - sess, err := h.getSessionFromCookie(r) + sess, err := h.getSessionFromCookie(w, r) if err == nil && sess != nil && len(sess.AccessToken) > 0 { // add authentication if session cookie and token checks out upstreamRequest.Header.Add("authorization", "Bearer "+sess.AccessToken) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 25f6715..1c993e7 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -17,7 +17,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { var idToken string - sess, err := h.getSessionFromCookie(r) + sess, err := h.getSessionFromCookie(w, r) if err == nil && sess != nil { idToken = sess.IDToken err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID)) diff --git a/pkg/router/session.go b/pkg/router/session.go index 585c6d5..32fdf61 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -3,13 +3,16 @@ package router import ( "errors" "fmt" - "github.com/go-redis/redis/v8" - "github.com/lestrrat-go/jwx/jwt" - "github.com/nais/wonderwall/pkg/session" - "github.com/nais/wonderwall/pkg/token" - "golang.org/x/oauth2" "net/http" "time" + + "github.com/go-redis/redis/v8" + "github.com/lestrrat-go/jwx/jwt" + log "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + + "github.com/nais/wonderwall/pkg/session" + "github.com/nais/wonderwall/pkg/token" ) // localSessionID prefixes the given `sid` with the given client ID to prevent key collisions. @@ -21,32 +24,34 @@ func (h *Handler) localSessionID(sid string) string { return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid) } -func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { +func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) { sessionID, err := h.getEncryptedCookie(r, h.GetSessionCookieName()) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) } encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID) - if err != nil { - if errors.Is(err, redis.Nil) { - return nil, fmt.Errorf("session not found in store: %w", err) - } - - fallbackSessionData, err := h.GetSessionFallback(r) + if err == nil { + sessionData, err := encryptedSessionData.Decrypt(h.Crypter) if err != nil { - return nil, fmt.Errorf("fallback session not found: %w", err) + return nil, fmt.Errorf("decrypting session data: %w", err) } - return fallbackSessionData, nil + h.DeleteSessionFallback(w) + return sessionData, nil } - sessionData, err := encryptedSessionData.Decrypt(h.Crypter) + if errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("session not found in store: %w", err) + } + + log.Warnf("get session: store is unavailable; using cookie fallback: %+v", err) + fallbackSessionData, err := h.GetSessionFallback(r) if err != nil { - return nil, fmt.Errorf("decrypting session data: %w", err) + return nil, fmt.Errorf("fallback session not found: %w", err) } - return sessionData, nil + return fallbackSessionData, nil } func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) { @@ -88,9 +93,11 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime) if err == nil { + h.DeleteSessionFallback(w) return nil } + log.Warnf("create session: store is unavailable; using cookie fallback: %+v", err) err = h.SetSessionFallback(w, sessionData, sessionLifetime) if err != nil { return fmt.Errorf("writing session to fallback store: %w", err)