diff --git a/pkg/router/router.go b/pkg/router/router.go index c5d8fe6..99d6f37 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -42,7 +42,7 @@ type Handler struct { Crypter cryptutil.Crypter UpstreamHost string IdTokenVerifier *oidc.IDTokenVerifier - sessions map[string]*oauth2.Token + sessions map[string]session lock sync.Mutex } @@ -54,7 +54,7 @@ type loginParams struct { } func (h *Handler) Init() { - h.sessions = make(map[string]*oauth2.Token) + h.sessions = make(map[string]session) } func (h *Handler) LoginURL() (*loginParams, error) { @@ -164,22 +164,24 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), } - token, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...) + tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...) if err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) return } - idToken, err := auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, token, cookies.Nonce) + idToken, err := auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, tokens, cookies.Nonce) if err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) return } + var claims struct { SessionID string `json:"sid"` } + if err := idToken.Claims(&claims); err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) @@ -193,7 +195,9 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - h.storeSession(claims.SessionID, token) + h.storeSession(claims.SessionID, session{ + token: tokens, + }) // fixme: distributed session store for multi-pod deployments @@ -204,28 +208,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithCancel(r.Context()) defer cancel() - upstreamRequest := r.Clone(ctx) + // Duplicate the incoming request, and delete any authentication. + upstreamRequest := r.Clone(ctx) upstreamRequest.Header.Del("authorization") - // fixme: let upstream application decide what to do with unauthenticated clients - // Get credentials from session cache - sessionID, err := h.getEncryptedCookie(r, SessionCookieName) - if err != nil { - log.Tracef("no session cookie; should redirect to /oauth2/login") - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) - return - } - token, ok := h.sessions[sessionID] - if !ok { - log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID) - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) - return + session, err := h.getSessionFromCookie(r) + if err == nil && session != nil && session.token != nil { + // add authentication if session cookie and token checks out + upstreamRequest.Header.Add("authorization", "Bearer "+session.token.AccessToken) + upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing } - // Duplicate the incoming request, and add authentication. - upstreamRequest.Header.Add("authorization", "Bearer "+token.AccessToken) - upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing // Request should go to correct host // req.Header.Set("host", req.Host) upstreamRequest.Host = h.UpstreamHost // fixme @@ -264,23 +258,13 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - sessionID, err := h.getEncryptedCookie(r, SessionCookieName) - if err != nil { - log.Tracef("no session cookie; should redirect to /oauth2/login") - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) - return + session, err := h.getSessionFromCookie(r) + + if err == nil && session != nil && session.token != nil { + h.deleteSession(session.id) + h.deleteCookie(w, SessionCookieName) } - - _, ok := h.sessions[sessionID] - if !ok { - log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID) - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) - return - } - - h.deleteSession(sessionID) - - h.deleteCookie(w, SessionCookieName) + // todo: test logout without credentials u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint) if err != nil { @@ -297,10 +281,12 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { func New(handler *Handler) chi.Router { r := chi.NewRouter() - r.With(middleware.DefaultLogger) - r.Get("/oauth2/login", handler.Login) - r.Get("/oauth2/callback", handler.Callback) - r.Get("/oauth2/logout_self", handler.Logout) + r.Route("/oauth2", func(r chi.Router) { + r.With(middleware.NoCache) + r.Get("/login", handler.Login) + r.Get("/callback", handler.Callback) + r.Get("/logout", handler.Logout) + }) r.HandleFunc("/*", handler.Default) return r } diff --git a/pkg/router/session.go b/pkg/router/session.go index d570140..d0d38ae 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -1,12 +1,20 @@ package router import ( + "fmt" "golang.org/x/oauth2" + "net/http" ) -func (h *Handler) storeSession(key string, token *oauth2.Token) { +type session struct { + id string + token *oauth2.Token +} + +func (h *Handler) storeSession(key string, session session) { + session.id = key h.lock.Lock() - h.sessions[key] = token + h.sessions[key] = session h.lock.Unlock() } @@ -15,3 +23,17 @@ func (h *Handler) deleteSession(key string) { delete(h.sessions, key) h.lock.Unlock() } + +func (h *Handler) getSessionFromCookie(r *http.Request) (*session, error) { + sessionID, err := h.getEncryptedCookie(r, SessionCookieName) + if err != nil { + return nil, fmt.Errorf("no session cookie: %w", err) + } + + session, ok := h.sessions[sessionID] + if !ok { + return nil, fmt.Errorf("no token stored for session %s", sessionID) + } + + return &session, nil +}