diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 0577705..4aaa0f2 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -79,10 +79,10 @@ func run() error { return fmt.Errorf("connecting to redis: %w", err) } - sessionStore = session.NewRedis(redisClient, crypt) + sessionStore = session.NewRedis(redisClient) log.Infof("Using Redis as session backing store") } else { - sessionStore = session.NewMemory(crypt) + sessionStore = session.NewMemory() log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") } diff --git a/pkg/router/router.go b/pkg/router/router.go index 5f98856..1d8a463 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -372,22 +372,20 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { sessionID := h.localSessionID(sid) - sess, err := h.Sessions.Read(r.Context(), sessionID) - if err != nil { + // From here on, check that 'iss' from request matches data found in access token. + accessToken, err := h.getAccessTokenFromSession(r, sessionID) + if accessToken == nil { // Can't remove session because it doesn't exist. Maybe it was garbage collected. // We regard this as a redundant logout and return 200 OK. return } - - // From here on, check that 'iss' from request matches data found in access token. - tok, err := jwt.Parse([]byte(sess.OAuth2Token.AccessToken)) if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) return } - err = jwt.Validate(tok, jwt.WithClaimValue("iss", iss)) + err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss)) if err != nil { w.WriteHeader(http.StatusBadRequest) return diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index da038dd..5fb8222 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler { } crypter := cryptutil.New(encryptionKey) - sessionStore := session.NewMemory(crypter) + sessionStore := session.NewMemory() handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "") if err != nil { diff --git a/pkg/router/session.go b/pkg/router/session.go index b32dcdf..6a5769d 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -25,11 +25,16 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { return nil, fmt.Errorf("no session cookie: %w", err) } - sessionData, err := h.Sessions.Read(r.Context(), sessionID) + encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID) if err != nil { return nil, fmt.Errorf("reading session from store: %w", err) } + sessionData, err := encryptedSessionData.Decrypt(h.Crypter) + if err != nil { + return nil, fmt.Errorf("decrypting session data: %w", err) + } + return sessionData, nil } @@ -69,10 +74,36 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external IDTokenSerialized: idToken.Raw, } - err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime) + encryptedSessionData, err := sessionData.Encrypt(h.Crypter) + if err != nil { + return fmt.Errorf("encrypting session data: %w", err) + } + + err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime) if err != nil { return fmt.Errorf("writing session to store: %w", err) } return nil } + +func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) { + encryptedSession, err := h.Sessions.Read(r.Context(), sessionID) + if err != nil { + // Session not found; ignoring + return nil, nil + } + + sessionData, err := encryptedSession.Decrypt(h.Crypter) + if err != nil { + // Can't decrypt, likely not our session; ignoring + return nil, nil + } + + accessToken, err := jwt.Parse([]byte(sessionData.OAuth2Token.AccessToken)) + if err != nil { + return nil, fmt.Errorf("parsing session access token: %w", err) + } + + return accessToken, nil +} diff --git a/pkg/session/memory.go b/pkg/session/memory.go index 71b13c3..4615e72 100644 --- a/pkg/session/memory.go +++ b/pkg/session/memory.go @@ -3,7 +3,6 @@ package session import ( "context" "fmt" - "github.com/nais/wonderwall/pkg/cryptutil" "sync" "time" ) @@ -11,19 +10,17 @@ import ( type memorySessionStore struct { lock sync.Mutex sessions map[string]*EncryptedData - crypter cryptutil.Crypter } var _ Store = &memorySessionStore{} -func NewMemory(crypter cryptutil.Crypter) Store { +func NewMemory() Store { return &memorySessionStore{ sessions: make(map[string]*EncryptedData), - crypter: crypter, } } -func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) { +func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData, error) { s.lock.Lock() defer s.lock.Unlock() @@ -32,24 +29,14 @@ func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) return nil, fmt.Errorf("no such session: %s", key) } - decrypted, err := data.Decrypt(s.crypter) - if err != nil { - return nil, fmt.Errorf("decrypting session data: %w", err) - } - - return decrypted, nil + return data, nil } -func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error { +func (s *memorySessionStore) Write(_ context.Context, key string, value *EncryptedData, expiration time.Duration) error { s.lock.Lock() defer s.lock.Unlock() - encrypted, err := value.Encrypt(s.crypter) - if err != nil { - return fmt.Errorf("encrypting session data: %w", err) - } - - s.sessions[key] = encrypted + s.sessions[key] = value return nil } diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index 672d94d..f0d0366 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -26,13 +26,20 @@ func TestMemory(t *testing.T) { IDTokenSerialized: "idtoken", } - sess := session.NewMemory(crypter) - err = sess.Write(context.Background(), "key", data, time.Minute) + encryptedData, err := data.Encrypt(crypter) + assert.NoError(t, err) + + sess := session.NewMemory() + err = sess.Write(context.Background(), "key", encryptedData, time.Minute) assert.NoError(t, err) result, err := sess.Read(context.Background(), "key") assert.NoError(t, err) - assert.Equal(t, data, result) + assert.Equal(t, encryptedData, result) + + decrypted, err := result.Decrypt(crypter) + assert.NoError(t, err) + assert.Equal(t, data, decrypted) err = sess.Delete(context.Background(), "key") diff --git a/pkg/session/redis.go b/pkg/session/redis.go index 42ff07d..ac672ef 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -2,28 +2,24 @@ package session import ( "context" - "fmt" "github.com/go-redis/redis/v8" - "github.com/nais/wonderwall/pkg/cryptutil" "github.com/nais/wonderwall/pkg/metrics" "time" ) type redisSessionStore struct { client redis.Cmdable - crypter cryptutil.Crypter } var _ Store = &redisSessionStore{} -func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store { +func NewRedis(client redis.Cmdable) Store { return &redisSessionStore{ client: client, - crypter: crypter, } } -func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) { +func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) { encryptedData := &EncryptedData{} err := metrics.ObserveRedisLatency("Read", func() error { var err error @@ -35,22 +31,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) return nil, err } - data, err := encryptedData.Decrypt(s.crypter) - if err != nil { - return nil, fmt.Errorf("decrypting session data: %w", err) - } - - return data, nil + return encryptedData, nil } -func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error { - encryptedData, err := value.Encrypt(s.crypter) - if err != nil { - return err - } - +func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error { return metrics.ObserveRedisLatency("Write", func() error { - status := s.client.Set(ctx, key, encryptedData, expiration) + status := s.client.Set(ctx, key, value, expiration) return status.Err() }) } diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index 2f8be9e..e6f236f 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -30,18 +30,25 @@ func TestRedis(t *testing.T) { IDTokenSerialized: "idtoken", } + encryptedData, err := data.Encrypt(crypter) + assert.NoError(t, err) + client := redis.NewClient(&redis.Options{ Network: "tcp", Addr: "127.0.0.1:6379", }) - sess := session.NewRedis(client, crypter) - err = sess.Write(context.Background(), "key", data, time.Minute) + sess := session.NewRedis(client) + err = sess.Write(context.Background(), "key", encryptedData, time.Minute) assert.NoError(t, err) result, err := sess.Read(context.Background(), "key") assert.NoError(t, err) - assert.Equal(t, data, result) + assert.Equal(t, encryptedData, result) + + decrypted, err := result.Decrypt(crypter) + assert.NoError(t, err) + assert.Equal(t, data, decrypted) err = sess.Delete(context.Background(), "key") diff --git a/pkg/session/session.go b/pkg/session/session.go index 37dfaa7..4a74ff2 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -12,8 +12,8 @@ import ( ) type Store interface { - Write(ctx context.Context, key string, value *Data, expiration time.Duration) error - Read(ctx context.Context, key string) (*Data, error) + Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error + Read(ctx context.Context, key string) (*EncryptedData, error) Delete(ctx context.Context, keys ...string) error }