From 55f26fb54c2ff9a1376fecad08eb927d34b4711a Mon Sep 17 00:00:00 2001 From: Kim Tore Jensen Date: Tue, 24 Aug 2021 12:58:16 +0200 Subject: [PATCH] incorporate new session storage code --- cmd/wonderwall/main.go | 4 +-- pkg/router/router.go | 51 ++++++++++++++++++++++++--------------- pkg/router/router_test.go | 5 ++-- pkg/router/session.go | 29 +++------------------- pkg/session/memory.go | 4 +-- pkg/session/redis.go | 4 +-- pkg/session/session.go | 2 +- 7 files changed, 44 insertions(+), 55 deletions(-) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index d146001..fab2543 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "github.com/nais/wonderwall/pkg/session" "net/http" "os" @@ -66,6 +67,7 @@ func run() error { OauthConfig: oauthConfig, UpstreamHost: cfg.UpstreamHost, SecureCookies: true, + Sessions: session.NewMemory(), IdTokenVerifier: oidc.NewVerifier( cfg.IDPorten.WellKnown.Issuer, oidc.NewRemoteKeySet(context.Background(), cfg.IDPorten.WellKnown.JwksURI), @@ -73,8 +75,6 @@ func run() error { ), } - handler.Init() - r := router.New(handler) return http.ListenAndServe(cfg.BindAddress, r) diff --git a/pkg/router/router.go b/pkg/router/router.go index 17d7add..5cb2c93 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "github.com/nais/wonderwall/pkg/session" "gopkg.in/square/go-jose.v2/jwt" "io" "net/http" @@ -44,7 +45,7 @@ type Handler struct { UpstreamHost string IdTokenVerifier *oidc.IDTokenVerifier SecureCookies bool - sessions map[string]session + Sessions session.Store lock sync.Mutex } @@ -55,10 +56,6 @@ type loginParams struct { nonce string } -func (h *Handler) Init() { - h.sessions = make(map[string]session) -} - func (h *Handler) LoginURL() (*loginParams, error) { codeVerifier := make([]byte, 64) nonce := make([]byte, 32) @@ -197,11 +194,15 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - h.storeSession(claims.SessionID, session{ - token: tokens, - }) - - // fixme: distributed session store for multi-pod deployments + err = h.Sessions.Write(r.Context(), claims.SessionID, &session.Data{ + ID: claims.SessionID, + Token: tokens, + }, SessionMaxLifetime) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } @@ -215,10 +216,10 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { upstreamRequest := r.Clone(ctx) upstreamRequest.Header.Del("authorization") - session, err := h.getSessionFromCookie(r) - if err == nil && session != nil && session.token != nil { + sess, err := h.getSessionFromCookie(r) + if err == nil && sess != nil && sess.Token != nil { // add authentication if session cookie and token checks out - upstreamRequest.Header.Add("authorization", "Bearer "+session.token.AccessToken) + upstreamRequest.Header.Add("authorization", "Bearer "+sess.Token.AccessToken) upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing } @@ -260,10 +261,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - session, err := h.getSessionFromCookie(r) + sess, err := h.getSessionFromCookie(r) - if err == nil && session != nil && session.token != nil { - h.deleteSession(session.id) + if err == nil && sess != nil && sess.Token != nil { + err = h.Sessions.Delete(r.Context(), sess.ID) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } h.deleteCookie(w, SessionCookieName) } // todo: test logout without credentials @@ -293,15 +299,15 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { return } - session, ok := h.sessions[sid] - if !ok { + sess, err := h.Sessions.Read(r.Context(), sid) + if err != nil { // Can't remove session because it doesn't exist. Maybe it was garbage collected. // We regard this as a redundant logout and return 200 OK. return } // From here on, check that 'iss' from request matches data found in access token. - tok, err := jwt.ParseSigned(session.token.AccessToken) + tok, err := jwt.ParseSigned(sess.Token.AccessToken) if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) @@ -325,7 +331,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { } // All verified; delete session. - h.deleteSession(sid) + err = h.Sessions.Delete(r.Context(), sid) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } } func New(handler *Handler) chi.Router { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 888bed3..6f6322c 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "github.com/nais/wonderwall/pkg/session" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -62,7 +63,8 @@ var idp = NewIDPorten(clients) func handler() *router.Handler { handler := router.Handler{ - Config: cfg, + Config: cfg, + Sessions: session.NewMemory(), OauthConfig: oauth2.Config{ ClientID: "client-id", ClientSecret: "client-secret", @@ -77,7 +79,6 @@ func handler() *router.Handler { UpstreamHost: "", IdTokenVerifier: nil, } - handler.Init() return &handler } diff --git a/pkg/router/session.go b/pkg/router/session.go index d0d38ae..adbc881 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -2,38 +2,15 @@ package router import ( "fmt" - "golang.org/x/oauth2" + "github.com/nais/wonderwall/pkg/session" "net/http" ) -type session struct { - id string - token *oauth2.Token -} - -func (h *Handler) storeSession(key string, session session) { - session.id = key - h.lock.Lock() - h.sessions[key] = session - h.lock.Unlock() -} - -func (h *Handler) deleteSession(key string) { - h.lock.Lock() - delete(h.sessions, key) - h.lock.Unlock() -} - -func (h *Handler) getSessionFromCookie(r *http.Request) (*session, error) { +func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { sessionID, err := h.getEncryptedCookie(r, SessionCookieName) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) } - session, ok := h.sessions[sessionID] - if !ok { - return nil, fmt.Errorf("no token stored for session %s", sessionID) - } - - return &session, nil + return h.Sessions.Read(r.Context(), sessionID) } diff --git a/pkg/session/memory.go b/pkg/session/memory.go index 19d4121..2e1f91c 100644 --- a/pkg/session/memory.go +++ b/pkg/session/memory.go @@ -12,9 +12,9 @@ type memorySessionStore struct { sessions map[string]*Data } -var _ Session = &memorySessionStore{} +var _ Store = &memorySessionStore{} -func NewMemory() Session { +func NewMemory() Store { return &memorySessionStore{ sessions: make(map[string]*Data), } diff --git a/pkg/session/redis.go b/pkg/session/redis.go index 3d5df19..6963385 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -10,9 +10,9 @@ type redisSessionStore struct { client redis.Cmdable } -var _ Session = &redisSessionStore{} +var _ Store = &redisSessionStore{} -func NewRedis(client redis.Cmdable) Session { +func NewRedis(client redis.Cmdable) Store { return &redisSessionStore{ client: client, } diff --git a/pkg/session/session.go b/pkg/session/session.go index 5078834..e3aa11e 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -8,7 +8,7 @@ import ( "time" ) -type Session interface { +type Store interface { Write(ctx context.Context, key string, value *Data, expiration time.Duration) error Read(ctx context.Context, key string) (*Data, error) Delete(ctx context.Context, keys ...string) error