incorporate new session storage code

This commit is contained in:
Kim Tore Jensen
2021-08-24 12:58:16 +02:00
parent 15a7c14324
commit 55f26fb54c
7 changed files with 44 additions and 55 deletions

View File

@@ -2,6 +2,7 @@ package main
import (
"context"
"github.com/nais/wonderwall/pkg/session"
"net/http"
"os"
@@ -66,6 +67,7 @@ func run() error {
OauthConfig: oauthConfig,
UpstreamHost: cfg.UpstreamHost,
SecureCookies: true,
Sessions: session.NewMemory(),
IdTokenVerifier: oidc.NewVerifier(
cfg.IDPorten.WellKnown.Issuer,
oidc.NewRemoteKeySet(context.Background(), cfg.IDPorten.WellKnown.JwksURI),
@@ -73,8 +75,6 @@ func run() error {
),
}
handler.Init()
r := router.New(handler)
return http.ListenAndServe(cfg.BindAddress, r)

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/nais/wonderwall/pkg/session"
"gopkg.in/square/go-jose.v2/jwt"
"io"
"net/http"
@@ -44,7 +45,7 @@ type Handler struct {
UpstreamHost string
IdTokenVerifier *oidc.IDTokenVerifier
SecureCookies bool
sessions map[string]session
Sessions session.Store
lock sync.Mutex
}
@@ -55,10 +56,6 @@ type loginParams struct {
nonce string
}
func (h *Handler) Init() {
h.sessions = make(map[string]session)
}
func (h *Handler) LoginURL() (*loginParams, error) {
codeVerifier := make([]byte, 64)
nonce := make([]byte, 32)
@@ -197,11 +194,15 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
h.storeSession(claims.SessionID, session{
token: tokens,
})
// fixme: distributed session store for multi-pod deployments
err = h.Sessions.Write(r.Context(), claims.SessionID, &session.Data{
ID: claims.SessionID,
Token: tokens,
}, SessionMaxLifetime)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
}
@@ -215,10 +216,10 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
upstreamRequest := r.Clone(ctx)
upstreamRequest.Header.Del("authorization")
session, err := h.getSessionFromCookie(r)
if err == nil && session != nil && session.token != nil {
sess, err := h.getSessionFromCookie(r)
if err == nil && sess != nil && sess.Token != nil {
// add authentication if session cookie and token checks out
upstreamRequest.Header.Add("authorization", "Bearer "+session.token.AccessToken)
upstreamRequest.Header.Add("authorization", "Bearer "+sess.Token.AccessToken)
upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing
}
@@ -260,10 +261,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
// Logout triggers self-initiated for the current user
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
session, err := h.getSessionFromCookie(r)
sess, err := h.getSessionFromCookie(r)
if err == nil && session != nil && session.token != nil {
h.deleteSession(session.id)
if err == nil && sess != nil && sess.Token != nil {
err = h.Sessions.Delete(r.Context(), sess.ID)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
h.deleteCookie(w, SessionCookieName)
}
// todo: test logout without credentials
@@ -293,15 +299,15 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
return
}
session, ok := h.sessions[sid]
if !ok {
sess, err := h.Sessions.Read(r.Context(), sid)
if err != nil {
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
// We regard this as a redundant logout and return 200 OK.
return
}
// From here on, check that 'iss' from request matches data found in access token.
tok, err := jwt.ParseSigned(session.token.AccessToken)
tok, err := jwt.ParseSigned(sess.Token.AccessToken)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
@@ -325,7 +331,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
}
// All verified; delete session.
h.deleteSession(sid)
err = h.Sessions.Delete(r.Context(), sid)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func New(handler *Handler) chi.Router {

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
"github.com/nais/wonderwall/pkg/session"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
@@ -62,7 +63,8 @@ var idp = NewIDPorten(clients)
func handler() *router.Handler {
handler := router.Handler{
Config: cfg,
Config: cfg,
Sessions: session.NewMemory(),
OauthConfig: oauth2.Config{
ClientID: "client-id",
ClientSecret: "client-secret",
@@ -77,7 +79,6 @@ func handler() *router.Handler {
UpstreamHost: "",
IdTokenVerifier: nil,
}
handler.Init()
return &handler
}

View File

@@ -2,38 +2,15 @@ package router
import (
"fmt"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/session"
"net/http"
)
type session struct {
id string
token *oauth2.Token
}
func (h *Handler) storeSession(key string, session session) {
session.id = key
h.lock.Lock()
h.sessions[key] = session
h.lock.Unlock()
}
func (h *Handler) deleteSession(key string) {
h.lock.Lock()
delete(h.sessions, key)
h.lock.Unlock()
}
func (h *Handler) getSessionFromCookie(r *http.Request) (*session, error) {
func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
sessionID, err := h.getEncryptedCookie(r, SessionCookieName)
if err != nil {
return nil, fmt.Errorf("no session cookie: %w", err)
}
session, ok := h.sessions[sessionID]
if !ok {
return nil, fmt.Errorf("no token stored for session %s", sessionID)
}
return &session, nil
return h.Sessions.Read(r.Context(), sessionID)
}

View File

@@ -12,9 +12,9 @@ type memorySessionStore struct {
sessions map[string]*Data
}
var _ Session = &memorySessionStore{}
var _ Store = &memorySessionStore{}
func NewMemory() Session {
func NewMemory() Store {
return &memorySessionStore{
sessions: make(map[string]*Data),
}

View File

@@ -10,9 +10,9 @@ type redisSessionStore struct {
client redis.Cmdable
}
var _ Session = &redisSessionStore{}
var _ Store = &redisSessionStore{}
func NewRedis(client redis.Cmdable) Session {
func NewRedis(client redis.Cmdable) Store {
return &redisSessionStore{
client: client,
}

View File

@@ -8,7 +8,7 @@ import (
"time"
)
type Session interface {
type Store interface {
Write(ctx context.Context, key string, value *Data, expiration time.Duration) error
Read(ctx context.Context, key string) (*Data, error)
Delete(ctx context.Context, keys ...string) error