mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
feat: encrypt session data
Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
@@ -79,10 +79,10 @@ func run() error {
|
||||
return fmt.Errorf("connecting to redis: %w", err)
|
||||
}
|
||||
|
||||
sessionStore = session.NewRedis(redisClient)
|
||||
sessionStore = session.NewRedis(redisClient, crypt)
|
||||
log.Infof("Using Redis as session backing store")
|
||||
} else {
|
||||
sessionStore = session.NewMemory()
|
||||
sessionStore = session.NewMemory(crypt)
|
||||
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler {
|
||||
}
|
||||
|
||||
crypter := cryptutil.New(encryptionKey)
|
||||
sessionStore := session.NewMemory()
|
||||
sessionStore := session.NewMemory(crypter)
|
||||
|
||||
handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "")
|
||||
if err != nil {
|
||||
@@ -391,4 +391,4 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,12 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
return h.Sessions.Read(r.Context(), sessionID)
|
||||
sessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading session from store: %w", err)
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionLifetime(accessToken string) (time.Duration, error) {
|
||||
@@ -58,11 +63,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
|
||||
return fmt.Errorf("setting session cookie: %w", err)
|
||||
}
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, &session.Data{
|
||||
sessionData := &session.Data{
|
||||
ExternalSessionID: externalSessionID,
|
||||
OAuth2Token: tokens,
|
||||
IDTokenSerialized: idToken.Raw,
|
||||
}, sessionLifetime)
|
||||
}
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing session to store: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,45 +3,63 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type memorySessionStore struct {
|
||||
lock sync.Mutex
|
||||
sessions map[string]*Data
|
||||
sessions map[string]*EncryptedData
|
||||
crypter cryptutil.Crypter
|
||||
}
|
||||
|
||||
var _ Store = &memorySessionStore{}
|
||||
|
||||
func NewMemory() Store {
|
||||
func NewMemory(crypter cryptutil.Crypter) Store {
|
||||
return &memorySessionStore{
|
||||
sessions: make(map[string]*Data),
|
||||
sessions: make(map[string]*EncryptedData),
|
||||
crypter: crypter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
data, ok := s.sessions[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such session: %s", key)
|
||||
}
|
||||
return data, nil
|
||||
|
||||
decrypted, err := data.Decrypt(s.crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.sessions[key] = value
|
||||
|
||||
encrypted, err := value.Encrypt(s.crypter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypting session data: %w", err)
|
||||
}
|
||||
|
||||
s.sessions[key] = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
delete(s.sessions, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package session_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
key, err := keygen.Keygen(32)
|
||||
assert.NoError(t, err)
|
||||
crypter := cryptutil.New(key)
|
||||
|
||||
data := &session.Data{
|
||||
ExternalSessionID: "myid",
|
||||
OAuth2Token: &oauth2.Token{
|
||||
@@ -20,8 +26,8 @@ func TestMemory(t *testing.T) {
|
||||
IDTokenSerialized: "idtoken",
|
||||
}
|
||||
|
||||
sess := session.NewMemory()
|
||||
err := sess.Write(context.Background(), "key", data, time.Minute)
|
||||
sess := session.NewMemory(crypter)
|
||||
err = sess.Write(context.Background(), "key", data, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := sess.Read(context.Background(), "key")
|
||||
|
||||
@@ -2,40 +2,55 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
"time"
|
||||
)
|
||||
|
||||
type redisSessionStore struct {
|
||||
client redis.Cmdable
|
||||
client redis.Cmdable
|
||||
crypter cryptutil.Crypter
|
||||
}
|
||||
|
||||
var _ Store = &redisSessionStore{}
|
||||
|
||||
func NewRedis(client redis.Cmdable) Store {
|
||||
func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store {
|
||||
return &redisSessionStore{
|
||||
client: client,
|
||||
client: client,
|
||||
crypter: crypter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) {
|
||||
data := &Data{}
|
||||
encryptedData := &EncryptedData{}
|
||||
err := metrics.ObserveRedisLatency("Read", func() error {
|
||||
var err error
|
||||
status := s.client.Get(ctx, key)
|
||||
err = status.Scan(data)
|
||||
err = status.Scan(encryptedData)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := encryptedData.Decrypt(s.crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error {
|
||||
encryptedData, err := value.Encrypt(s.crypter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return metrics.ObserveRedisLatency("Write", func() error {
|
||||
status := s.client.Set(ctx, key, value, expiration)
|
||||
status := s.client.Set(ctx, key, encryptedData, expiration)
|
||||
return status.Err()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package session_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +18,10 @@ import (
|
||||
)
|
||||
|
||||
func TestRedis(t *testing.T) {
|
||||
key, err := keygen.Keygen(32)
|
||||
assert.NoError(t, err)
|
||||
crypter := cryptutil.New(key)
|
||||
|
||||
data := &session.Data{
|
||||
ExternalSessionID: "myid",
|
||||
OAuth2Token: &oauth2.Token{
|
||||
@@ -28,8 +35,8 @@ func TestRedis(t *testing.T) {
|
||||
Addr: "127.0.0.1:6379",
|
||||
})
|
||||
|
||||
sess := session.NewRedis(client)
|
||||
err := sess.Write(context.Background(), "key", data, time.Minute)
|
||||
sess := session.NewRedis(client, crypter)
|
||||
err = sess.Write(context.Background(), "key", data, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := sess.Read(context.Background(), "key")
|
||||
|
||||
@@ -3,7 +3,9 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
@@ -15,19 +17,59 @@ type Store interface {
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
}
|
||||
|
||||
type EncryptedData struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
var _ encoding.BinaryMarshaler = &EncryptedData{}
|
||||
var _ encoding.BinaryUnmarshaler = &EncryptedData{}
|
||||
|
||||
func (in *EncryptedData) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(in)
|
||||
}
|
||||
|
||||
func (in *EncryptedData) UnmarshalBinary(bytes []byte) error {
|
||||
return json.Unmarshal(bytes, in)
|
||||
}
|
||||
|
||||
func (in *EncryptedData) Decrypt(crypter cryptutil.Crypter) (*Data, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(in.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawData, err := crypter.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data Data
|
||||
err = json.Unmarshal(rawData, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
ExternalSessionID string
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenSerialized string
|
||||
ExternalSessionID string `json:"external_session_id"`
|
||||
OAuth2Token *oauth2.Token `json:"oauth2_token"`
|
||||
IDTokenSerialized string `json:"id_token_serialized"`
|
||||
}
|
||||
|
||||
var _ encoding.BinaryMarshaler = &Data{}
|
||||
var _ encoding.BinaryUnmarshaler = &Data{}
|
||||
func (in *Data) Encrypt(crypter cryptutil.Crypter) (*EncryptedData, error) {
|
||||
bytes, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (data *Data) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(data)
|
||||
}
|
||||
ciphertext, err := crypter.Encrypt(bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (data *Data) UnmarshalBinary(bytes []byte) error {
|
||||
return json.Unmarshal(bytes, data)
|
||||
return &EncryptedData{
|
||||
Data: base64.StdEncoding.EncodeToString(ciphertext),
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user