refactor: use httputil.ReverseProxy for default route

This commit is contained in:
Trong Huu Nguyen
2021-10-14 13:07:57 +02:00
parent 8724e37e0d
commit 5ce7d979c7

View File

@@ -1,67 +1,62 @@
package router
import (
"context"
"io"
"net/http"
"net/http/httputil"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/session"
)
// Default proxies all requests upstream
func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
isAuthenticated := false
// Duplicate the incoming request, and delete any authentication.
upstreamRequest := r.Clone(ctx)
upstreamRequest.Header.Del("authorization")
upstreamRequest.Header.Del("x-pwned-by")
sess, err := h.getSessionFromCookie(w, r)
if err == nil && sess != nil && len(sess.AccessToken) > 0 {
sessionData, err := h.getSessionFromCookie(w, r)
if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 {
// add authentication if session cookie and token checks out
upstreamRequest.Header.Add("authorization", "Bearer "+sess.AccessToken)
upstreamRequest.Header.Add("x-pwned-by", "wonderwall") // todo: request id for tracing
isAuthenticated = true
} else if h.Config.AutoLogin {
r.Header.Add("Referer", r.URL.String())
h.Login(w, r)
return
}
// Request should go to correct host
upstreamRequest.Host = r.Host
upstreamRequest.URL.Host = h.UpstreamHost
upstreamRequest.URL.Scheme = "http"
upstreamRequest.RequestURI = ""
// Attach request body from original request
upstreamRequest.Body = r.Body
defer upstreamRequest.Body.Close()
director := func(upstreamRequest *http.Request) {
modifyRequest(upstreamRequest, r, h.UpstreamHost)
// Make sure requests aren't silently redirected
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
upstreamResponse, err := client.Do(upstreamRequest)
if err != nil {
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte(err.Error()))
return
}
for key, values := range upstreamResponse.Header {
for _, value := range values {
w.Header().Add(key, value)
if isAuthenticated {
withAuthentication(upstreamRequest, sessionData)
}
}
w.WriteHeader(upstreamResponse.StatusCode)
// Forward server's reply downstream
_, err = io.Copy(w, upstreamResponse.Body)
if err != nil {
log.Errorf("proxy data from upstream to client: %s", err)
errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte(err.Error()))
}
reverseProxy := httputil.ReverseProxy{
Director: director,
ErrorHandler: errorHandler,
}
reverseProxy.ServeHTTP(w, r)
}
func modifyRequest(dst, src *http.Request, upstreamHost string) {
// Delete incoming authentication
dst.Header.Del("authorization")
dst.Header.Del("x-pwned-by")
// Instruct http.ReverseProxy to not modify X-Forwarded-For header
dst.Header["X-Forwarded-For"] = nil
// Request should go to correct host
dst.Host = src.Host
dst.URL.Host = upstreamHost
dst.URL.Scheme = "http"
dst.RequestURI = ""
// Attach request body from original request
dst.Body = src.Body
}
func withAuthentication(dst *http.Request, sessionData *session.Data) {
dst.Header.Add("authorization", "Bearer "+sessionData.AccessToken)
dst.Header.Add("x-pwned-by", "wonderwall")
}