mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
incorporate new session storage code
This commit is contained in:
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
@@ -66,6 +67,7 @@ func run() error {
|
||||
OauthConfig: oauthConfig,
|
||||
UpstreamHost: cfg.UpstreamHost,
|
||||
SecureCookies: true,
|
||||
Sessions: session.NewMemory(),
|
||||
IdTokenVerifier: oidc.NewVerifier(
|
||||
cfg.IDPorten.WellKnown.Issuer,
|
||||
oidc.NewRemoteKeySet(context.Background(), cfg.IDPorten.WellKnown.JwksURI),
|
||||
@@ -73,8 +75,6 @@ func run() error {
|
||||
),
|
||||
}
|
||||
|
||||
handler.Init()
|
||||
|
||||
r := router.New(handler)
|
||||
|
||||
return http.ListenAndServe(cfg.BindAddress, r)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -44,7 +45,7 @@ type Handler struct {
|
||||
UpstreamHost string
|
||||
IdTokenVerifier *oidc.IDTokenVerifier
|
||||
SecureCookies bool
|
||||
sessions map[string]session
|
||||
Sessions session.Store
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
@@ -55,10 +56,6 @@ type loginParams struct {
|
||||
nonce string
|
||||
}
|
||||
|
||||
func (h *Handler) Init() {
|
||||
h.sessions = make(map[string]session)
|
||||
}
|
||||
|
||||
func (h *Handler) LoginURL() (*loginParams, error) {
|
||||
codeVerifier := make([]byte, 64)
|
||||
nonce := make([]byte, 32)
|
||||
@@ -197,11 +194,15 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.storeSession(claims.SessionID, session{
|
||||
token: tokens,
|
||||
})
|
||||
|
||||
// fixme: distributed session store for multi-pod deployments
|
||||
err = h.Sessions.Write(r.Context(), claims.SessionID, &session.Data{
|
||||
ID: claims.SessionID,
|
||||
Token: tokens,
|
||||
}, SessionMaxLifetime)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
@@ -215,10 +216,10 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamRequest := r.Clone(ctx)
|
||||
upstreamRequest.Header.Del("authorization")
|
||||
|
||||
session, err := h.getSessionFromCookie(r)
|
||||
if err == nil && session != nil && session.token != nil {
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
if err == nil && sess != nil && sess.Token != nil {
|
||||
// add authentication if session cookie and token checks out
|
||||
upstreamRequest.Header.Add("authorization", "Bearer "+session.token.AccessToken)
|
||||
upstreamRequest.Header.Add("authorization", "Bearer "+sess.Token.AccessToken)
|
||||
upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing
|
||||
}
|
||||
|
||||
@@ -260,10 +261,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := h.getSessionFromCookie(r)
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
|
||||
if err == nil && session != nil && session.token != nil {
|
||||
h.deleteSession(session.id)
|
||||
if err == nil && sess != nil && sess.Token != nil {
|
||||
err = h.Sessions.Delete(r.Context(), sess.ID)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.deleteCookie(w, SessionCookieName)
|
||||
}
|
||||
// todo: test logout without credentials
|
||||
@@ -293,15 +299,15 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := h.sessions[sid]
|
||||
if !ok {
|
||||
sess, err := h.Sessions.Read(r.Context(), sid)
|
||||
if err != nil {
|
||||
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
|
||||
// We regard this as a redundant logout and return 200 OK.
|
||||
return
|
||||
}
|
||||
|
||||
// From here on, check that 'iss' from request matches data found in access token.
|
||||
tok, err := jwt.ParseSigned(session.token.AccessToken)
|
||||
tok, err := jwt.ParseSigned(sess.Token.AccessToken)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -325,7 +331,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// All verified; delete session.
|
||||
h.deleteSession(sid)
|
||||
err = h.Sessions.Delete(r.Context(), sid)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func New(handler *Handler) chi.Router {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
@@ -62,7 +63,8 @@ var idp = NewIDPorten(clients)
|
||||
|
||||
func handler() *router.Handler {
|
||||
handler := router.Handler{
|
||||
Config: cfg,
|
||||
Config: cfg,
|
||||
Sessions: session.NewMemory(),
|
||||
OauthConfig: oauth2.Config{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
@@ -77,7 +79,6 @@ func handler() *router.Handler {
|
||||
UpstreamHost: "",
|
||||
IdTokenVerifier: nil,
|
||||
}
|
||||
handler.Init()
|
||||
return &handler
|
||||
}
|
||||
|
||||
|
||||
@@ -2,38 +2,15 @@ package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/oauth2"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
id string
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
func (h *Handler) storeSession(key string, session session) {
|
||||
session.id = key
|
||||
h.lock.Lock()
|
||||
h.sessions[key] = session
|
||||
h.lock.Unlock()
|
||||
}
|
||||
|
||||
func (h *Handler) deleteSession(key string) {
|
||||
h.lock.Lock()
|
||||
delete(h.sessions, key)
|
||||
h.lock.Unlock()
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionFromCookie(r *http.Request) (*session, error) {
|
||||
func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
sessionID, err := h.getEncryptedCookie(r, SessionCookieName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
session, ok := h.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no token stored for session %s", sessionID)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
return h.Sessions.Read(r.Context(), sessionID)
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ type memorySessionStore struct {
|
||||
sessions map[string]*Data
|
||||
}
|
||||
|
||||
var _ Session = &memorySessionStore{}
|
||||
var _ Store = &memorySessionStore{}
|
||||
|
||||
func NewMemory() Session {
|
||||
func NewMemory() Store {
|
||||
return &memorySessionStore{
|
||||
sessions: make(map[string]*Data),
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ type redisSessionStore struct {
|
||||
client redis.Cmdable
|
||||
}
|
||||
|
||||
var _ Session = &redisSessionStore{}
|
||||
var _ Store = &redisSessionStore{}
|
||||
|
||||
func NewRedis(client redis.Cmdable) Session {
|
||||
func NewRedis(client redis.Cmdable) Store {
|
||||
return &redisSessionStore{
|
||||
client: client,
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
type Store interface {
|
||||
Write(ctx context.Context, key string, value *Data, expiration time.Duration) error
|
||||
Read(ctx context.Context, key string) (*Data, error)
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
|
||||
Reference in New Issue
Block a user