mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-14 12:26:34 +00:00
deduplication; store sessions with name
This commit is contained in:
@@ -42,7 +42,7 @@ type Handler struct {
|
||||
Crypter cryptutil.Crypter
|
||||
UpstreamHost string
|
||||
IdTokenVerifier *oidc.IDTokenVerifier
|
||||
sessions map[string]*oauth2.Token
|
||||
sessions map[string]session
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ type loginParams struct {
|
||||
}
|
||||
|
||||
func (h *Handler) Init() {
|
||||
h.sessions = make(map[string]*oauth2.Token)
|
||||
h.sessions = make(map[string]session)
|
||||
}
|
||||
|
||||
func (h *Handler) LoginURL() (*loginParams, error) {
|
||||
@@ -164,22 +164,24 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
|
||||
}
|
||||
|
||||
token, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...)
|
||||
tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, token, cookies.Nonce)
|
||||
idToken, err := auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, tokens, cookies.Nonce)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
SessionID string `json:"sid"`
|
||||
}
|
||||
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -193,7 +195,9 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.storeSession(claims.SessionID, token)
|
||||
h.storeSession(claims.SessionID, session{
|
||||
token: tokens,
|
||||
})
|
||||
|
||||
// fixme: distributed session store for multi-pod deployments
|
||||
|
||||
@@ -204,28 +208,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
upstreamRequest := r.Clone(ctx)
|
||||
|
||||
// Duplicate the incoming request, and delete any authentication.
|
||||
upstreamRequest := r.Clone(ctx)
|
||||
upstreamRequest.Header.Del("authorization")
|
||||
|
||||
// fixme: let upstream application decide what to do with unauthenticated clients
|
||||
// Get credentials from session cache
|
||||
sessionID, err := h.getEncryptedCookie(r, SessionCookieName)
|
||||
if err != nil {
|
||||
log.Tracef("no session cookie; should redirect to /oauth2/login")
|
||||
http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
token, ok := h.sessions[sessionID]
|
||||
if !ok {
|
||||
log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID)
|
||||
http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
session, err := h.getSessionFromCookie(r)
|
||||
if err == nil && session != nil && session.token != nil {
|
||||
// add authentication if session cookie and token checks out
|
||||
upstreamRequest.Header.Add("authorization", "Bearer "+session.token.AccessToken)
|
||||
upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing
|
||||
}
|
||||
|
||||
// Duplicate the incoming request, and add authentication.
|
||||
upstreamRequest.Header.Add("authorization", "Bearer "+token.AccessToken)
|
||||
upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing
|
||||
// Request should go to correct host
|
||||
// req.Header.Set("host", req.Host)
|
||||
upstreamRequest.Host = h.UpstreamHost // fixme
|
||||
@@ -264,23 +258,13 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, err := h.getEncryptedCookie(r, SessionCookieName)
|
||||
if err != nil {
|
||||
log.Tracef("no session cookie; should redirect to /oauth2/login")
|
||||
http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
session, err := h.getSessionFromCookie(r)
|
||||
|
||||
if err == nil && session != nil && session.token != nil {
|
||||
h.deleteSession(session.id)
|
||||
h.deleteCookie(w, SessionCookieName)
|
||||
}
|
||||
|
||||
_, ok := h.sessions[sessionID]
|
||||
if !ok {
|
||||
log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID)
|
||||
http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
h.deleteSession(sessionID)
|
||||
|
||||
h.deleteCookie(w, SessionCookieName)
|
||||
// todo: test logout without credentials
|
||||
|
||||
u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint)
|
||||
if err != nil {
|
||||
@@ -297,10 +281,12 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func New(handler *Handler) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.With(middleware.DefaultLogger)
|
||||
r.Get("/oauth2/login", handler.Login)
|
||||
r.Get("/oauth2/callback", handler.Callback)
|
||||
r.Get("/oauth2/logout_self", handler.Logout)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.With(middleware.NoCache)
|
||||
r.Get("/login", handler.Login)
|
||||
r.Get("/callback", handler.Callback)
|
||||
r.Get("/logout", handler.Logout)
|
||||
})
|
||||
r.HandleFunc("/*", handler.Default)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (h *Handler) storeSession(key string, token *oauth2.Token) {
|
||||
type session struct {
|
||||
id string
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
func (h *Handler) storeSession(key string, session session) {
|
||||
session.id = key
|
||||
h.lock.Lock()
|
||||
h.sessions[key] = token
|
||||
h.sessions[key] = session
|
||||
h.lock.Unlock()
|
||||
}
|
||||
|
||||
@@ -15,3 +23,17 @@ func (h *Handler) deleteSession(key string) {
|
||||
delete(h.sessions, key)
|
||||
h.lock.Unlock()
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionFromCookie(r *http.Request) (*session, error) {
|
||||
sessionID, err := h.getEncryptedCookie(r, SessionCookieName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
session, ok := h.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no token stored for session %s", sessionID)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user