feat: encrypt session data

Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
Trong Huu Nguyen
2021-09-30 13:47:22 +02:00
parent cf7ca9c5b8
commit 2ec1b7ace9
8 changed files with 127 additions and 32 deletions

View File

@@ -79,10 +79,10 @@ func run() error {
return fmt.Errorf("connecting to redis: %w", err)
}
sessionStore = session.NewRedis(redisClient)
sessionStore = session.NewRedis(redisClient, crypt)
log.Infof("Using Redis as session backing store")
} else {
sessionStore = session.NewMemory()
sessionStore = session.NewMemory(crypt)
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
}

View File

@@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler {
}
crypter := cryptutil.New(encryptionKey)
sessionStore := session.NewMemory()
sessionStore := session.NewMemory(crypter)
handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "")
if err != nil {
@@ -391,4 +391,4 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
}
return nil
}
}

View File

@@ -25,7 +25,12 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
return nil, fmt.Errorf("no session cookie: %w", err)
}
return h.Sessions.Read(r.Context(), sessionID)
sessionData, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
return nil, fmt.Errorf("reading session from store: %w", err)
}
return sessionData, nil
}
func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) {
@@ -58,11 +63,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
return fmt.Errorf("setting session cookie: %w", err)
}
err = h.Sessions.Write(r.Context(), sessionID, &session.Data{
sessionData := &session.Data{
ExternalSessionID: externalSessionID,
OAuth2Token: tokens,
IDTokenSerialized: idToken.Raw,
}, sessionLifetime)
}
err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime)
if err != nil {
return fmt.Errorf("writing session to store: %w", err)
}

View File

@@ -3,45 +3,63 @@ package session
import (
"context"
"fmt"
"github.com/nais/wonderwall/pkg/cryptutil"
"sync"
"time"
)
type memorySessionStore struct {
lock sync.Mutex
sessions map[string]*Data
sessions map[string]*EncryptedData
crypter cryptutil.Crypter
}
var _ Store = &memorySessionStore{}
func NewMemory() Store {
func NewMemory(crypter cryptutil.Crypter) Store {
return &memorySessionStore{
sessions: make(map[string]*Data),
sessions: make(map[string]*EncryptedData),
crypter: crypter,
}
}
func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) {
s.lock.Lock()
defer s.lock.Unlock()
data, ok := s.sessions[key]
if !ok {
return nil, fmt.Errorf("no such session: %s", key)
}
return data, nil
decrypted, err := data.Decrypt(s.crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return decrypted, nil
}
func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error {
s.lock.Lock()
defer s.lock.Unlock()
s.sessions[key] = value
encrypted, err := value.Encrypt(s.crypter)
if err != nil {
return fmt.Errorf("encrypting session data: %w", err)
}
s.sessions[key] = encrypted
return nil
}
func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error {
s.lock.Lock()
defer s.lock.Unlock()
for _, key := range keys {
delete(s.sessions, key)
}
return nil
}

View File

@@ -2,6 +2,8 @@ package session_test
import (
"context"
"github.com/nais/liberator/pkg/keygen"
"github.com/nais/wonderwall/pkg/cryptutil"
"testing"
"time"
@@ -12,6 +14,10 @@ import (
)
func TestMemory(t *testing.T) {
key, err := keygen.Keygen(32)
assert.NoError(t, err)
crypter := cryptutil.New(key)
data := &session.Data{
ExternalSessionID: "myid",
OAuth2Token: &oauth2.Token{
@@ -20,8 +26,8 @@ func TestMemory(t *testing.T) {
IDTokenSerialized: "idtoken",
}
sess := session.NewMemory()
err := sess.Write(context.Background(), "key", data, time.Minute)
sess := session.NewMemory(crypter)
err = sess.Write(context.Background(), "key", data, time.Minute)
assert.NoError(t, err)
result, err := sess.Read(context.Background(), "key")

View File

@@ -2,40 +2,55 @@ package session
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/nais/wonderwall/pkg/cryptutil"
"github.com/nais/wonderwall/pkg/metrics"
"time"
)
type redisSessionStore struct {
client redis.Cmdable
client redis.Cmdable
crypter cryptutil.Crypter
}
var _ Store = &redisSessionStore{}
func NewRedis(client redis.Cmdable) Store {
func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store {
return &redisSessionStore{
client: client,
client: client,
crypter: crypter,
}
}
func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) {
data := &Data{}
encryptedData := &EncryptedData{}
err := metrics.ObserveRedisLatency("Read", func() error {
var err error
status := s.client.Get(ctx, key)
err = status.Scan(data)
err = status.Scan(encryptedData)
return err
})
if err != nil {
return nil, err
}
data, err := encryptedData.Decrypt(s.crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return data, nil
}
func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error {
encryptedData, err := value.Encrypt(s.crypter)
if err != nil {
return err
}
return metrics.ObserveRedisLatency("Write", func() error {
status := s.client.Set(ctx, key, value, expiration)
status := s.client.Set(ctx, key, encryptedData, expiration)
return status.Err()
})
}

View File

@@ -1,9 +1,12 @@
//go:build integration
// +build integration
package session_test
import (
"context"
"github.com/nais/liberator/pkg/keygen"
"github.com/nais/wonderwall/pkg/cryptutil"
"testing"
"time"
@@ -15,6 +18,10 @@ import (
)
func TestRedis(t *testing.T) {
key, err := keygen.Keygen(32)
assert.NoError(t, err)
crypter := cryptutil.New(key)
data := &session.Data{
ExternalSessionID: "myid",
OAuth2Token: &oauth2.Token{
@@ -28,8 +35,8 @@ func TestRedis(t *testing.T) {
Addr: "127.0.0.1:6379",
})
sess := session.NewRedis(client)
err := sess.Write(context.Background(), "key", data, time.Minute)
sess := session.NewRedis(client, crypter)
err = sess.Write(context.Background(), "key", data, time.Minute)
assert.NoError(t, err)
result, err := sess.Read(context.Background(), "key")

View File

@@ -3,7 +3,9 @@ package session
import (
"context"
"encoding"
"encoding/base64"
"encoding/json"
"github.com/nais/wonderwall/pkg/cryptutil"
"time"
"golang.org/x/oauth2"
@@ -15,19 +17,59 @@ type Store interface {
Delete(ctx context.Context, keys ...string) error
}
type EncryptedData struct {
Data string `json:"data"`
}
var _ encoding.BinaryMarshaler = &EncryptedData{}
var _ encoding.BinaryUnmarshaler = &EncryptedData{}
func (in *EncryptedData) MarshalBinary() ([]byte, error) {
return json.Marshal(in)
}
func (in *EncryptedData) UnmarshalBinary(bytes []byte) error {
return json.Unmarshal(bytes, in)
}
func (in *EncryptedData) Decrypt(crypter cryptutil.Crypter) (*Data, error) {
ciphertext, err := base64.StdEncoding.DecodeString(in.Data)
if err != nil {
return nil, err
}
rawData, err := crypter.Decrypt(ciphertext)
if err != nil {
return nil, err
}
var data Data
err = json.Unmarshal(rawData, &data)
if err != nil {
return nil, err
}
return &data, nil
}
type Data struct {
ExternalSessionID string
OAuth2Token *oauth2.Token
IDTokenSerialized string
ExternalSessionID string `json:"external_session_id"`
OAuth2Token *oauth2.Token `json:"oauth2_token"`
IDTokenSerialized string `json:"id_token_serialized"`
}
var _ encoding.BinaryMarshaler = &Data{}
var _ encoding.BinaryUnmarshaler = &Data{}
func (in *Data) Encrypt(crypter cryptutil.Crypter) (*EncryptedData, error) {
bytes, err := json.Marshal(in)
if err != nil {
return nil, err
}
func (data *Data) MarshalBinary() ([]byte, error) {
return json.Marshal(data)
}
ciphertext, err := crypter.Encrypt(bytes)
if err != nil {
return nil, err
}
func (data *Data) UnmarshalBinary(bytes []byte) error {
return json.Unmarshal(bytes, data)
return &EncryptedData{
Data: base64.StdEncoding.EncodeToString(ciphertext),
}, nil
}