From 2ec1b7ace93b31374a74917a74c0d33f320aa4c0 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 30 Sep 2021 13:47:22 +0200 Subject: [PATCH] feat: encrypt session data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Sindre Rødseth Hansen --- cmd/wonderwall/main.go | 4 +-- pkg/router/router_test.go | 4 +-- pkg/router/session.go | 13 ++++++-- pkg/session/memory.go | 28 ++++++++++++++--- pkg/session/memory_test.go | 10 ++++-- pkg/session/redis.go | 27 +++++++++++++---- pkg/session/redis_test.go | 11 +++++-- pkg/session/session.go | 62 ++++++++++++++++++++++++++++++++------ 8 files changed, 127 insertions(+), 32 deletions(-) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 4aaa0f2..0577705 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -79,10 +79,10 @@ func run() error { return fmt.Errorf("connecting to redis: %w", err) } - sessionStore = session.NewRedis(redisClient) + sessionStore = session.NewRedis(redisClient, crypt) log.Infof("Using Redis as session backing store") } else { - sessionStore = session.NewMemory() + sessionStore = session.NewMemory(crypt) log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index ab291a3..da038dd 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler { } crypter := cryptutil.New(encryptionKey) - sessionStore := session.NewMemory() + sessionStore := session.NewMemory(crypter) handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "") if err != nil { @@ -391,4 +391,4 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { } return nil -} \ No newline at end of file +} diff --git a/pkg/router/session.go b/pkg/router/session.go index e8a3169..b32dcdf 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -25,7 +25,12 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { return nil, fmt.Errorf("no session cookie: %w", err) } - return h.Sessions.Read(r.Context(), sessionID) + sessionData, err := h.Sessions.Read(r.Context(), sessionID) + if err != nil { + return nil, fmt.Errorf("reading session from store: %w", err) + } + + return sessionData, nil } func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) { @@ -58,11 +63,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external return fmt.Errorf("setting session cookie: %w", err) } - err = h.Sessions.Write(r.Context(), sessionID, &session.Data{ + sessionData := &session.Data{ ExternalSessionID: externalSessionID, OAuth2Token: tokens, IDTokenSerialized: idToken.Raw, - }, sessionLifetime) + } + + err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime) if err != nil { return fmt.Errorf("writing session to store: %w", err) } diff --git a/pkg/session/memory.go b/pkg/session/memory.go index 2e1f91c..71b13c3 100644 --- a/pkg/session/memory.go +++ b/pkg/session/memory.go @@ -3,45 +3,63 @@ package session import ( "context" "fmt" + "github.com/nais/wonderwall/pkg/cryptutil" "sync" "time" ) type memorySessionStore struct { lock sync.Mutex - sessions map[string]*Data + sessions map[string]*EncryptedData + crypter cryptutil.Crypter } var _ Store = &memorySessionStore{} -func NewMemory() Store { +func NewMemory(crypter cryptutil.Crypter) Store { return &memorySessionStore{ - sessions: make(map[string]*Data), + sessions: make(map[string]*EncryptedData), + crypter: crypter, } } func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) { s.lock.Lock() defer s.lock.Unlock() + data, ok := s.sessions[key] if !ok { return nil, fmt.Errorf("no such session: %s", key) } - return data, nil + + decrypted, err := data.Decrypt(s.crypter) + if err != nil { + return nil, fmt.Errorf("decrypting session data: %w", err) + } + + return decrypted, nil } func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error { s.lock.Lock() defer s.lock.Unlock() - s.sessions[key] = value + + encrypted, err := value.Encrypt(s.crypter) + if err != nil { + return fmt.Errorf("encrypting session data: %w", err) + } + + s.sessions[key] = encrypted return nil } func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error { s.lock.Lock() defer s.lock.Unlock() + for _, key := range keys { delete(s.sessions, key) } + return nil } diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index 5efefcf..672d94d 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -2,6 +2,8 @@ package session_test import ( "context" + "github.com/nais/liberator/pkg/keygen" + "github.com/nais/wonderwall/pkg/cryptutil" "testing" "time" @@ -12,6 +14,10 @@ import ( ) func TestMemory(t *testing.T) { + key, err := keygen.Keygen(32) + assert.NoError(t, err) + crypter := cryptutil.New(key) + data := &session.Data{ ExternalSessionID: "myid", OAuth2Token: &oauth2.Token{ @@ -20,8 +26,8 @@ func TestMemory(t *testing.T) { IDTokenSerialized: "idtoken", } - sess := session.NewMemory() - err := sess.Write(context.Background(), "key", data, time.Minute) + sess := session.NewMemory(crypter) + err = sess.Write(context.Background(), "key", data, time.Minute) assert.NoError(t, err) result, err := sess.Read(context.Background(), "key") diff --git a/pkg/session/redis.go b/pkg/session/redis.go index cf901b1..42ff07d 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -2,40 +2,55 @@ package session import ( "context" + "fmt" "github.com/go-redis/redis/v8" + "github.com/nais/wonderwall/pkg/cryptutil" "github.com/nais/wonderwall/pkg/metrics" "time" ) type redisSessionStore struct { - client redis.Cmdable + client redis.Cmdable + crypter cryptutil.Crypter } var _ Store = &redisSessionStore{} -func NewRedis(client redis.Cmdable) Store { +func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store { return &redisSessionStore{ - client: client, + client: client, + crypter: crypter, } } func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) { - data := &Data{} + encryptedData := &EncryptedData{} err := metrics.ObserveRedisLatency("Read", func() error { var err error status := s.client.Get(ctx, key) - err = status.Scan(data) + err = status.Scan(encryptedData) return err }) if err != nil { return nil, err } + + data, err := encryptedData.Decrypt(s.crypter) + if err != nil { + return nil, fmt.Errorf("decrypting session data: %w", err) + } + return data, nil } func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error { + encryptedData, err := value.Encrypt(s.crypter) + if err != nil { + return err + } + return metrics.ObserveRedisLatency("Write", func() error { - status := s.client.Set(ctx, key, value, expiration) + status := s.client.Set(ctx, key, encryptedData, expiration) return status.Err() }) } diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index 593f42a..2f8be9e 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -1,9 +1,12 @@ +//go:build integration // +build integration package session_test import ( "context" + "github.com/nais/liberator/pkg/keygen" + "github.com/nais/wonderwall/pkg/cryptutil" "testing" "time" @@ -15,6 +18,10 @@ import ( ) func TestRedis(t *testing.T) { + key, err := keygen.Keygen(32) + assert.NoError(t, err) + crypter := cryptutil.New(key) + data := &session.Data{ ExternalSessionID: "myid", OAuth2Token: &oauth2.Token{ @@ -28,8 +35,8 @@ func TestRedis(t *testing.T) { Addr: "127.0.0.1:6379", }) - sess := session.NewRedis(client) - err := sess.Write(context.Background(), "key", data, time.Minute) + sess := session.NewRedis(client, crypter) + err = sess.Write(context.Background(), "key", data, time.Minute) assert.NoError(t, err) result, err := sess.Read(context.Background(), "key") diff --git a/pkg/session/session.go b/pkg/session/session.go index 8a63a9d..37dfaa7 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -3,7 +3,9 @@ package session import ( "context" "encoding" + "encoding/base64" "encoding/json" + "github.com/nais/wonderwall/pkg/cryptutil" "time" "golang.org/x/oauth2" @@ -15,19 +17,59 @@ type Store interface { Delete(ctx context.Context, keys ...string) error } +type EncryptedData struct { + Data string `json:"data"` +} + +var _ encoding.BinaryMarshaler = &EncryptedData{} +var _ encoding.BinaryUnmarshaler = &EncryptedData{} + +func (in *EncryptedData) MarshalBinary() ([]byte, error) { + return json.Marshal(in) +} + +func (in *EncryptedData) UnmarshalBinary(bytes []byte) error { + return json.Unmarshal(bytes, in) +} + +func (in *EncryptedData) Decrypt(crypter cryptutil.Crypter) (*Data, error) { + ciphertext, err := base64.StdEncoding.DecodeString(in.Data) + if err != nil { + return nil, err + } + + rawData, err := crypter.Decrypt(ciphertext) + if err != nil { + return nil, err + } + + var data Data + err = json.Unmarshal(rawData, &data) + if err != nil { + return nil, err + } + + return &data, nil +} + type Data struct { - ExternalSessionID string - OAuth2Token *oauth2.Token - IDTokenSerialized string + ExternalSessionID string `json:"external_session_id"` + OAuth2Token *oauth2.Token `json:"oauth2_token"` + IDTokenSerialized string `json:"id_token_serialized"` } -var _ encoding.BinaryMarshaler = &Data{} -var _ encoding.BinaryUnmarshaler = &Data{} +func (in *Data) Encrypt(crypter cryptutil.Crypter) (*EncryptedData, error) { + bytes, err := json.Marshal(in) + if err != nil { + return nil, err + } -func (data *Data) MarshalBinary() ([]byte, error) { - return json.Marshal(data) -} + ciphertext, err := crypter.Encrypt(bytes) + if err != nil { + return nil, err + } -func (data *Data) UnmarshalBinary(bytes []byte) error { - return json.Unmarshal(bytes, data) + return &EncryptedData{ + Data: base64.StdEncoding.EncodeToString(ciphertext), + }, nil }