mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-21 15:52:54 +00:00
refactor: log session store unavailability, ensure fallback cookies are deleted when no longer needed
This commit is contained in:
@@ -9,9 +9,9 @@ import (
|
||||
)
|
||||
|
||||
func TestParseIngress(t *testing.T) {
|
||||
for _, test := range []struct{
|
||||
for _, test := range []struct {
|
||||
ingress string
|
||||
want string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
ingress: "https://tjenester.nav.no/sykepenger/",
|
||||
@@ -33,7 +33,6 @@ func TestParseIngress(t *testing.T) {
|
||||
ingress: "https://sykepenger-test.nav.no",
|
||||
want: "",
|
||||
},
|
||||
|
||||
} {
|
||||
t.Run(test.ingress, func(t *testing.T) {
|
||||
prefix := config.ParseIngress(test.ingress)
|
||||
|
||||
@@ -32,7 +32,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second * 30)
|
||||
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err))
|
||||
return
|
||||
|
||||
@@ -18,7 +18,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamRequest.Header.Del("authorization")
|
||||
upstreamRequest.Header.Del("x-pwned-by")
|
||||
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
sess, err := h.getSessionFromCookie(w, r)
|
||||
if err == nil && sess != nil && len(sess.AccessToken) > 0 {
|
||||
// add authentication if session cookie and token checks out
|
||||
upstreamRequest.Header.Add("authorization", "Bearer "+sess.AccessToken)
|
||||
|
||||
@@ -17,7 +17,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var idToken string
|
||||
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
sess, err := h.getSessionFromCookie(w, r)
|
||||
if err == nil && sess != nil {
|
||||
idToken = sess.IDToken
|
||||
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))
|
||||
|
||||
@@ -3,13 +3,16 @@ package router
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
// localSessionID prefixes the given `sid` with the given client ID to prevent key collisions.
|
||||
@@ -21,32 +24,34 @@ func (h *Handler) localSessionID(sid string) string {
|
||||
return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid)
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
|
||||
sessionID, err := h.getEncryptedCookie(r, h.GetSessionCookieName())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("session not found in store: %w", err)
|
||||
}
|
||||
|
||||
fallbackSessionData, err := h.GetSessionFallback(r)
|
||||
if err == nil {
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback session not found: %w", err)
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return fallbackSessionData, nil
|
||||
h.DeleteSessionFallback(w)
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("session not found in store: %w", err)
|
||||
}
|
||||
|
||||
log.Warnf("get session: store is unavailable; using cookie fallback: %+v", err)
|
||||
fallbackSessionData, err := h.GetSessionFallback(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
return nil, fmt.Errorf("fallback session not found: %w", err)
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
return fallbackSessionData, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) {
|
||||
@@ -88,9 +93,11 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
||||
if err == nil {
|
||||
h.DeleteSessionFallback(w)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warnf("create session: store is unavailable; using cookie fallback: %+v", err)
|
||||
err = h.SetSessionFallback(w, sessionData, sessionLifetime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing session to fallback store: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user