mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-07 00:46:56 +00:00
The use of these headers in upstreams may be risky, especially if Wonderwall is accidentally misconfigured or disabled, or if requests are sent directly to the upstream, circumventing Wonderwall. We should prefer using a signed token or similar that can be verified by the upstreams.
235 lines
6.3 KiB
Go
235 lines
6.3 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
urllib "net/url"
|
|
"strings"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"github.com/nais/wonderwall/pkg/handler/acr"
|
|
"github.com/nais/wonderwall/pkg/handler/autologin"
|
|
mw "github.com/nais/wonderwall/pkg/middleware"
|
|
"github.com/nais/wonderwall/pkg/server"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
"github.com/nais/wonderwall/pkg/url"
|
|
)
|
|
|
|
type ReverseProxySource interface {
|
|
GetAcrHandler() *acr.Handler
|
|
GetAutoLogin() *autologin.AutoLogin
|
|
GetPath(r *http.Request) string
|
|
GetSession(r *http.Request) (*session.Session, error)
|
|
}
|
|
|
|
// ReverseProxy wraps an httputil.ReverseProxy together with the settings used
// by ReverseProxy.Handler.
type ReverseProxy struct {
	// Embedded proxy; its methods (such as ServeHTTP, used by Handler) are promoted.
	*httputil.ReverseProxy
	// EnableAccessLogs makes Handler log authenticated requests that
	// isRelevantAccessLog considers relevant.
	EnableAccessLogs bool
}
|
|
|
|
func NewUpstreamProxy(upstream *urllib.URL, enableAccessLogs bool) *ReverseProxy {
|
|
rp := NewReverseProxy(upstream, true)
|
|
rp.EnableAccessLogs = enableAccessLogs
|
|
return rp
|
|
}
|
|
|
|
// NewReverseProxy builds a ReverseProxy that forwards requests to upstream.
//
// For every outbound request the proxy:
//   - copies the inbound Forwarded, X-Forwarded-For, X-Forwarded-Host and
//     X-Forwarded-Proto headers, which httputil.ReverseProxy strips when a
//     Rewrite hook is used;
//   - points the request at upstream via ProxyRequest.SetURL;
//   - keeps the inbound Host header when preserveInboundHostHeader is true;
//   - sets "Authorization: Bearer <token>" when mw.AccessTokenFrom finds an
//     access token on the inbound request's context.
//
// Proxy failures are answered with 502 Bad Gateway and logged as warnings,
// except context.Canceled, which gets the non-standard status 499 and no log line.
func NewReverseProxy(upstream *urllib.URL, preserveInboundHostHeader bool) *ReverseProxy {
	// Rewrite-based proxies drop these from the outbound request; carry them over.
	forwardedHeaders := []string{"Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto"}

	onError := func(w http.ResponseWriter, req *http.Request, proxyErr error) {
		entry := mw.LogEntryFrom(req)

		if errors.Is(proxyErr, context.Canceled) {
			w.WriteHeader(499)
			return
		}

		entry.Warnf("reverseproxy: proxy error: %+v", proxyErr)
		w.WriteHeader(http.StatusBadGateway)
	}

	rewrite := func(pr *httputil.ProxyRequest) {
		for _, name := range forwardedHeaders {
			pr.Out.Header[name] = pr.In.Header[name]
		}

		// SetURL also resets the outbound Host, so any override must come after it.
		pr.SetURL(upstream)
		if preserveInboundHostHeader {
			pr.Out.Host = pr.In.Host
		}

		if token, found := mw.AccessTokenFrom(pr.In.Context()); found {
			pr.Out.Header.Set("authorization", "Bearer "+token)
		}
	}

	return &ReverseProxy{
		ReverseProxy: &httputil.ReverseProxy{
			ErrorHandler: onError,
			ErrorLog:     log.New(logrusErrorWriter{}, "reverseproxy: ", 0),
			Rewrite:      rewrite,
			Transport:    server.DefaultTransport(),
		},
	}
}
|
|
|
|
// Handler authenticates r from its session and either starts the autologin
// flow or proxies the request upstream.
//
// A request counts as authenticated when src yields a session with an access
// token and the acr handler accepts that session. Session lookup failures are
// logged at a level that depends on the error kind. When autologin demands a
// login, handleAutologin answers the request and nothing is proxied. Otherwise
// the request is proxied; for authenticated requests the access token is put
// on the context, where the proxy's Rewrite hook can pick it up.
func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r *http.Request) {
	entry := mw.LogEntryFrom(r)

	sess, accessToken, sessErr := getSessionWithValidToken(src, r)
	authenticated := sessErr == nil

	if !authenticated {
		// Order matters: the more specific session errors are checked first.
		switch {
		case errors.Is(sessErr, context.Canceled):
			entry.Debugf("default: unauthenticated: %+v (client disconnected before we could respond)", sessErr)
		case errors.Is(sessErr, session.ErrInvalidExternal):
			entry.Warnf("default: unauthenticated: %+v", sessErr)
		case errors.Is(sessErr, session.ErrNotFound):
			entry.Debugf("default: unauthenticated: %+v", sessErr)
		case errors.Is(sessErr, session.ErrInvalid):
			entry.Infof("default: unauthenticated: %+v", sessErr)
		default:
			entry.Errorf("default: unauthenticated: unexpected error: %+v", sessErr)
		}
	}

	ctx := r.Context()

	if sess != nil {
		if sid := sess.ExternalSessionID(); sid != "" {
			entry = entry.WithField("sid", sid)
		}
	}

	if acrErr := src.GetAcrHandler().Validate(sess); acrErr != nil {
		authenticated = false
		entry.Infof("default: unauthenticated: acr: %+v; checking for autologin...", acrErr)
	}

	if src.GetAutoLogin().NeedsLogin(r, authenticated) {
		handleAutologin(src, w, r, entry)
		return
	}

	if authenticated {
		ctx = mw.WithAccessToken(ctx, accessToken)

		if rp.EnableAccessLogs && isRelevantAccessLog(r) {
			entry.Info("default: authenticated request")
		}
	}

	rp.ServeHTTP(w, r.WithContext(ctx))
}
|
|
|
|
func getSessionWithValidToken(src ReverseProxySource, r *http.Request) (*session.Session, string, error) {
|
|
sess, err := src.GetSession(r)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
accessToken, err := sess.AccessToken()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return sess, accessToken, nil
|
|
}
|
|
|
|
// handleAutologin answers a request that autologin has decided must log in.
//
// Navigation requests (see isNavigationRequest) are redirected with 302 Found
// to the login endpoint, with the current request URL as the post-login
// redirect target.
//
// Other requests cannot usefully follow a redirect, so they receive
// 401 Unauthorized with the login URL in the Location header. Their redirect
// target is the Referer, or src.GetPath(r) when no Referer is present. The body
// is a small JSON object when the client accepts "*/*" or "application/json",
// and plain text otherwise.
func handleAutologin(src ReverseProxySource, w http.ResponseWriter, r *http.Request, logger *logrus.Entry) {
	path := src.GetPath(r)

	loginURL := func(redirectTarget, message string) string {
		// we don't validate/clean the redirect target as this is done by the login handler anyway
		loginURL := url.LoginRelative(path, redirectTarget)

		logger.WithFields(logrus.Fields{
			"redirect_after_login": redirectTarget,
			"login_url":            loginURL,
		}).Infof("default: unauthenticated: autologin: %s", message)

		return loginURL
	}

	if isNavigationRequest(r) {
		target := r.URL.String()
		location := loginURL(target, "navigation request detected; redirecting to login...")
		http.Redirect(w, r, location, http.StatusFound)
		return
	}

	// not a navigation request, so we can't respond with 3xx to redirect
	target := r.Referer()
	if target == "" {
		target = path
	}

	location := loginURL(target, "non-navigation request detected; responding with 401 and Location header")

	contentType := "text/plain; charset=utf-8"
	body := []byte("unauthenticated, please log in")
	if accepts(r, "*/*", "application/json") {
		contentType = "application/json"
		body = []byte(`{"error": "unauthenticated, please log in"}`)
	}

	// All headers must be set before WriteHeader: changes made to the header map
	// afterwards are not sent to the client.
	header := w.Header()
	header.Set("Location", location)
	header.Set("Content-Type", contentType)
	w.WriteHeader(http.StatusUnauthorized)

	// The status line and headers are already out; if the client has gone away
	// there is nothing useful left to do with a write error.
	_, _ = w.Write(body)
}
|
|
|
|
// isRelevantAccessLog reports whether r deserves an access-log line.
// Non-GET requests always do; GET requests only when they are top-level
// navigations, as judged by isNavigationRequest.
func isRelevantAccessLog(r *http.Request) bool {
	return r.Method != http.MethodGet || isNavigationRequest(r)
}
|
|
|
|
// isNavigationRequest reports whether r looks like a top-level browser
// navigation. Only GET requests qualify. When both Sec-Fetch-Mode and
// Sec-Fetch-Dest are present, the request must be "navigate" / "document".
// Otherwise (e.g. a browser without fetch metadata) it falls back to checking
// whether the client accepts text/html.
func isNavigationRequest(r *http.Request) bool {
	// navigation requests are assumed to always be GET requests
	if r.Method != http.MethodGet {
		return false
	}

	mode, dest := r.Header.Get("Sec-Fetch-Mode"), r.Header.Get("Sec-Fetch-Dest")
	if mode == "" || dest == "" {
		// fetch metadata unavailable; fall back to content negotiation
		return accepts(r, "text/html")
	}

	return mode == "navigate" && dest == "document"
}
|
|
|
|
// accepts reports whether any media range listed in r's Accept header(s)
// equals one of the accepted values. Parameters such as ";q=0.9" are ignored,
// and the comparison is case-insensitive.
//
// Note that no wildcard expansion is done: "*/*" in the header only matches
// when "*/*" itself is one of the accepted values.
func accepts(r *http.Request, accepted ...string) bool {
	// a request may carry several Accept headers, each a comma-separated list
	for _, header := range r.Header.Values("Accept") {
		for _, mediaRange := range strings.Split(header, ",") {
			// Drop parameters before trimming: the Accept grammar allows optional
			// whitespace before ';' (e.g. "text/html ; q=0.9").
			mediaType, _, _ := strings.Cut(mediaRange, ";")
			mediaType = strings.TrimSpace(mediaType)

			for _, accept := range accepted {
				if strings.EqualFold(mediaType, accept) {
					return true
				}
			}
		}
	}

	return false
}
|
|
|
|
type logrusErrorWriter struct{}
|
|
|
|
func (w logrusErrorWriter) Write(p []byte) (n int, err error) {
|
|
logrus.Warnf("%s", string(p))
|
|
return len(p), nil
|
|
}
|