mirror of
https://github.com/nais/wonderwall.git
synced 2026-02-14 17:49:54 +00:00
235 lines
7.2 KiB
Go
235 lines
7.2 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
urllib "net/url"
|
|
|
|
httpinternal "github.com/nais/wonderwall/internal/http"
|
|
"github.com/nais/wonderwall/internal/o11y/otel"
|
|
"github.com/nais/wonderwall/pkg/handler/acr"
|
|
"github.com/nais/wonderwall/pkg/handler/autologin"
|
|
mw "github.com/nais/wonderwall/pkg/middleware"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
"github.com/nais/wonderwall/pkg/url"
|
|
"github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
type ReverseProxySource interface {
|
|
GetAcrHandler() *acr.Handler
|
|
GetAutoLogin() *autologin.AutoLogin
|
|
GetPath(r *http.Request) string
|
|
GetSession(r *http.Request) (*session.Session, error)
|
|
}
|
|
|
|
// ReverseProxy wraps httputil.ReverseProxy with the wonderwall-specific
// switches consulted by Handler.
type ReverseProxy struct {
	*httputil.ReverseProxy

	// EnableAccessLogs makes Handler log authenticated requests that
	// isRelevantAccessLog accepts.
	//
	// IncludeIDToken makes Handler store the session's ID token in the
	// request context, so that the Rewrite hook installed by NewReverseProxy
	// forwards it upstream as X-Wonderwall-Id-Token.
	EnableAccessLogs, IncludeIDToken bool
}
|
|
|
|
func NewUpstreamProxy(upstream *urllib.URL, enableAccessLogs bool, includeIDToken bool) *ReverseProxy {
|
|
rp := NewReverseProxy(upstream, true)
|
|
rp.EnableAccessLogs = enableAccessLogs
|
|
rp.IncludeIDToken = includeIDToken
|
|
return rp
|
|
}
|
|
|
|
func NewReverseProxy(upstream *urllib.URL, preserveInboundHostHeader bool) *ReverseProxy {
|
|
rp := &httputil.ReverseProxy{
|
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
logger := mw.LogEntryFrom(r).WithFields(httpinternal.Attributes(r))
|
|
|
|
if errors.Is(err, context.Canceled) {
|
|
w.WriteHeader(499)
|
|
} else {
|
|
logger.Warnf("reverseproxy: proxy error: %+v", err)
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}
|
|
},
|
|
ErrorLog: log.New(logrusErrorWriter{}, "reverseproxy: ", 0),
|
|
Rewrite: func(r *httputil.ProxyRequest) {
|
|
// preserve inbound Forwarded and X-Forwarded-* headers that is stripped when using Rewrite
|
|
// this presumes that we're behind a trusted reverse proxy (e.g. gateway or ingress controller)
|
|
r.Out.Header["Forwarded"] = r.In.Header["Forwarded"]
|
|
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
|
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
|
|
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
|
|
r.SetURL(upstream)
|
|
|
|
if preserveInboundHostHeader {
|
|
// preserve the inbound request's Host header
|
|
r.Out.Host = r.In.Host
|
|
}
|
|
|
|
accessToken, ok := mw.AccessTokenFrom(r.In.Context())
|
|
if ok {
|
|
r.Out.Header.Set("authorization", "Bearer "+accessToken)
|
|
}
|
|
|
|
idToken, ok := mw.IDTokenFrom(r.In.Context())
|
|
if ok {
|
|
r.Out.Header.Set("X-Wonderwall-Id-Token", idToken)
|
|
} else {
|
|
// remove the header if it was set by the client
|
|
r.Out.Header.Del("X-Wonderwall-Id-Token")
|
|
}
|
|
},
|
|
Transport: httpinternal.Transport(),
|
|
}
|
|
return &ReverseProxy{
|
|
ReverseProxy: rp,
|
|
}
|
|
}
|
|
|
|
func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r *http.Request) {
|
|
r, span := otel.StartSpanFromRequest(r, "ReverseProxy")
|
|
defer span.End()
|
|
|
|
logger := mw.LogEntryFrom(r).WithFields(httpinternal.Attributes(r))
|
|
|
|
unauthenticatedEvent := func(level logrus.Level, errType string, err error) {
|
|
logger.WithError(err).Logf(level, "unauthenticated: %+v", err)
|
|
otel.AddErrorEvent(span, "unauthenticated", errType, err)
|
|
}
|
|
|
|
isAuthenticated := false
|
|
sess, accessToken, err := getSessionWithValidToken(src, r)
|
|
switch {
|
|
case err == nil:
|
|
// add authentication if session checks out
|
|
isAuthenticated = true
|
|
case errors.Is(err, context.Canceled):
|
|
unauthenticatedEvent(logrus.DebugLevel, "context.Canceled", fmt.Errorf("client disconnected before we could respond: %w", err))
|
|
case errors.Is(err, session.ErrInvalidExternal):
|
|
unauthenticatedEvent(logrus.WarnLevel, "session.ErrInvalidExternal", err)
|
|
case errors.Is(err, session.ErrNotFound):
|
|
unauthenticatedEvent(logrus.DebugLevel, "session.ErrNotFound", err)
|
|
case errors.Is(err, session.ErrInvalid):
|
|
unauthenticatedEvent(logrus.InfoLevel, "session.ErrInvalid", err)
|
|
default:
|
|
unauthenticatedEvent(logrus.ErrorLevel, "unexpected", fmt.Errorf("unexpected error: %w", err))
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
|
|
ctx := r.Context()
|
|
if sess != nil {
|
|
if sid := sess.ExternalSessionID(); sid != "" {
|
|
logger = logger.WithField("sid", sid)
|
|
}
|
|
}
|
|
|
|
err = src.GetAcrHandler().Validate(sess)
|
|
if err != nil {
|
|
isAuthenticated = false
|
|
logger.Infof("default: unauthenticated: acr: %+v; checking for autologin...", err)
|
|
}
|
|
|
|
span.SetAttributes(attribute.Bool("session.authenticated", isAuthenticated))
|
|
|
|
if src.GetAutoLogin().NeedsLogin(r, isAuthenticated) {
|
|
span.SetAttributes(attribute.Bool("proxy.needs_autologin", true))
|
|
handleAutologin(src, w, r, logger)
|
|
return
|
|
}
|
|
|
|
if isAuthenticated {
|
|
ctx = mw.WithAccessToken(ctx, accessToken)
|
|
span.SetAttributes(attribute.Bool("proxy.with_access_token", true))
|
|
if rp.IncludeIDToken && sess != nil {
|
|
ctx = mw.WithIDToken(ctx, sess.IDToken())
|
|
span.SetAttributes(attribute.Bool("proxy.with_id_token", true))
|
|
}
|
|
|
|
if rp.EnableAccessLogs && isRelevantAccessLog(r) {
|
|
logger.Info("default: authenticated request")
|
|
}
|
|
}
|
|
|
|
ctx, span = otel.StartSpan(ctx, "ReverseProxy.ServeHTTP")
|
|
defer span.End()
|
|
rp.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
|
|
func getSessionWithValidToken(src ReverseProxySource, r *http.Request) (*session.Session, string, error) {
|
|
sess, err := src.GetSession(r)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
accessToken, err := sess.AccessToken()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return sess, accessToken, nil
|
|
}
|
|
|
|
func handleAutologin(src ReverseProxySource, w http.ResponseWriter, r *http.Request, logger *logrus.Entry) {
|
|
span := trace.SpanFromContext(r.Context())
|
|
path := src.GetPath(r)
|
|
|
|
loginURL := func(redirectTarget, message string) string {
|
|
// we don't validate/clean the redirect target as this is done by the login handler anyway
|
|
loginURL := url.LoginRelative(path, redirectTarget)
|
|
|
|
logger.WithFields(logrus.Fields{
|
|
"redirect_after_login": redirectTarget,
|
|
"login_url": loginURL,
|
|
}).Infof("default: unauthenticated: autologin: %s", message)
|
|
span.SetAttributes(attribute.String("autologin.redirect_after", redirectTarget))
|
|
span.SetAttributes(attribute.String("autologin.login_url", loginURL))
|
|
|
|
return loginURL
|
|
}
|
|
|
|
if httpinternal.IsNavigationRequest(r) {
|
|
target := r.URL.String()
|
|
location := loginURL(target, "navigation request detected; redirecting to login...")
|
|
http.Redirect(w, r, location, http.StatusFound)
|
|
return
|
|
}
|
|
|
|
// not a navigation request, so we can't respond with 3xx to redirect
|
|
target := r.Referer()
|
|
if target == "" {
|
|
target = path
|
|
}
|
|
|
|
location := loginURL(target, "non-navigation request detected; responding with 401 and Location header")
|
|
w.Header().Set("Location", location)
|
|
|
|
if httpinternal.Accepts(r, "*/*", "application/json") {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error": "unauthenticated, please log in"}`))
|
|
} else {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte("unauthenticated, please log in"))
|
|
}
|
|
}
|
|
|
|
func isRelevantAccessLog(r *http.Request) bool {
|
|
if r.Method == http.MethodGet {
|
|
// only log GET requests that are navigation requests
|
|
return httpinternal.IsNavigationRequest(r)
|
|
}
|
|
|
|
// all other methods are relevant
|
|
return true
|
|
}
|
|
|
|
type logrusErrorWriter struct{}
|
|
|
|
func (w logrusErrorWriter) Write(p []byte) (n int, err error) {
|
|
logrus.Warnf("%s", string(p))
|
|
return len(p), nil
|
|
}
|