refactor: log session store unavailability, ensure fallback cookies are deleted when no longer needed

This commit is contained in:
Trong Huu Nguyen
2021-10-13 08:49:53 +02:00
parent f7f476db87
commit d0482b3490
5 changed files with 29 additions and 23 deletions

View File

@@ -9,9 +9,9 @@ import (
)
func TestParseIngress(t *testing.T) {
for _, test := range []struct{
for _, test := range []struct {
ingress string
want string
want string
}{
{
ingress: "https://tjenester.nav.no/sykepenger/",
@@ -33,7 +33,6 @@ func TestParseIngress(t *testing.T) {
ingress: "https://sykepenger-test.nav.no",
want: "",
},
} {
t.Run(test.ingress, func(t *testing.T) {
prefix := config.ParseIngress(test.ingress)

View File

@@ -32,7 +32,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second * 30)
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err))
return

View File

@@ -18,7 +18,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
upstreamRequest.Header.Del("authorization")
upstreamRequest.Header.Del("x-pwned-by")
sess, err := h.getSessionFromCookie(r)
sess, err := h.getSessionFromCookie(w, r)
if err == nil && sess != nil && len(sess.AccessToken) > 0 {
// add authentication if session cookie and token checks out
upstreamRequest.Header.Add("authorization", "Bearer "+sess.AccessToken)

View File

@@ -17,7 +17,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
var idToken string
sess, err := h.getSessionFromCookie(r)
sess, err := h.getSessionFromCookie(w, r)
if err == nil && sess != nil {
idToken = sess.IDToken
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))

View File

@@ -3,13 +3,16 @@ package router
import (
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/lestrrat-go/jwx/jwt"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
"golang.org/x/oauth2"
"net/http"
"time"
"github.com/go-redis/redis/v8"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
)
// localSessionID prefixes the given `sid` with the given client ID to prevent key collisions.
@@ -21,32 +24,34 @@ func (h *Handler) localSessionID(sid string) string {
return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid)
}
func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
sessionID, err := h.getEncryptedCookie(r, h.GetSessionCookieName())
if err != nil {
return nil, fmt.Errorf("no session cookie: %w", err)
}
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("session not found in store: %w", err)
}
fallbackSessionData, err := h.GetSessionFallback(r)
if err == nil {
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
if err != nil {
return nil, fmt.Errorf("fallback session not found: %w", err)
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return fallbackSessionData, nil
h.DeleteSessionFallback(w)
return sessionData, nil
}
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
if errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("session not found in store: %w", err)
}
log.Warnf("get session: store is unavailable; using cookie fallback: %+v", err)
fallbackSessionData, err := h.GetSessionFallback(r)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
return nil, fmt.Errorf("fallback session not found: %w", err)
}
return sessionData, nil
return fallbackSessionData, nil
}
func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) {
@@ -88,9 +93,11 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
if err == nil {
h.DeleteSessionFallback(w)
return nil
}
log.Warnf("create session: store is unavailable; using cookie fallback: %+v", err)
err = h.SetSessionFallback(w, sessionData, sessionLifetime)
if err != nil {
return fmt.Errorf("writing session to fallback store: %w", err)