From 1fdbe75c9e4bd5dfa4bdfdaaade4544f2af07683 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 3 Feb 2023 14:01:38 +0100 Subject: [PATCH] feat(sso/proxy): implement login handler --- pkg/handler/error/error.go | 2 +- pkg/handler/handler_sso_proxy.go | 53 ++++++++++++++++++++++++++++++-- pkg/handler/reverseproxy.go | 2 +- pkg/handler/url/url.go | 50 +++++++++++++++++++----------- pkg/handler/url/url_test.go | 46 +++++++++++++++++++++++++-- pkg/ingress/ingress.go | 16 ++++++++++ 6 files changed, 145 insertions(+), 24 deletions(-) diff --git a/pkg/handler/error/error.go b/pkg/handler/error/error.go index bf1f340..b93319e 100644 --- a/pkg/handler/error/error.go +++ b/pkg/handler/error/error.go @@ -75,7 +75,7 @@ func (h Handler) Retry(r *http.Request, loginCookie *openid.LoginCookie) string redirect = loginCookie.Referer } - return urlpkg.LoginURL(ingressPath, redirect) + return urlpkg.LoginRelative(ingressPath, redirect) } func (h Handler) respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error, level log.Level) { diff --git a/pkg/handler/handler_sso_proxy.go b/pkg/handler/handler_sso_proxy.go index b7ae766..05f283f 100644 --- a/pkg/handler/handler_sso_proxy.go +++ b/pkg/handler/handler_sso_proxy.go @@ -5,8 +5,13 @@ import ( "net/http" urllib "net/url" + log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/handler/url" "github.com/nais/wonderwall/pkg/ingress" + mw "github.com/nais/wonderwall/pkg/middleware" + openidclient "github.com/nais/wonderwall/pkg/openid/client" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/router/paths" ) @@ -30,6 +35,18 @@ func NewSSOProxyHandler(cfg *config.Config) (*SSOProxyHandler, error) { return nil, fmt.Errorf("parsing sso server url: %w", err) } + query := u.Query() + + if len(cfg.OpenID.ACRValues) > 0 { + query.Set(openidclient.SecurityLevelURLParameter, cfg.OpenID.ACRValues) + } + + if len(cfg.OpenID.UILocales) > 0 { + query.Set(openidclient.LocaleURLParameter, cfg.OpenID.UILocales) + } + + u.RawQuery = query.Encode() + return &SSOProxyHandler{ Config: cfg, Ingresses: ingresses, @@ -38,8 +55,40 @@ func NewSSOProxyHandler(cfg *config.Config) (*SSOProxyHandler, error) { } func (s *SSOProxyHandler) Login(w http.ResponseWriter, r *http.Request) { - // TODO redirect to sso-server - panic("implement me") + target := *s.SSOServerURL + targetQuery := target.Query() + + reqQuery := r.URL.Query() + + if reqQuery.Has(openidclient.SecurityLevelURLParameter) { + targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter)) + } + + if reqQuery.Has(openidclient.LocaleURLParameter) { + targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter)) + } + + target.RawQuery = reqQuery.Encode() + + redirect, err := url.Ingress(r) + if err != nil { + redirect = s.Ingresses.Single().NewURL() + } + parsedRedirect, err := urllib.ParseRequestURI(reqQuery.Get(url.RedirectURLParameter)) + if err == nil { + redirect = redirect.JoinPath(parsedRedirect.Path) + } + + ssoServerLoginURL := url.Login(&target, redirect.String()) + + mw.LogEntryFrom(r). + WithFields(log.Fields{ + "redirect_to": ssoServerLoginURL, + "redirect_after_login": redirect.String(), + }). + Info("login: redirecting to sso server") + + http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect) } func (s *SSOProxyHandler) LoginCallback(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/handler/reverseproxy.go b/pkg/handler/reverseproxy.go index cf31bfd..84c5dec 100644 --- a/pkg/handler/reverseproxy.go +++ b/pkg/handler/reverseproxy.go @@ -92,7 +92,7 @@ func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r redirectTarget := r.URL.String() path := src.GetPath(r) - loginUrl := url.LoginURL(path, redirectTarget) + loginUrl := url.LoginRelative(path, redirectTarget) fields := logrus.Fields{ "redirect_after_login": redirectTarget, "redirect_to": loginUrl, diff --git a/pkg/handler/url/url.go b/pkg/handler/url/url.go index 6fb0fa2..220f428 100644 --- a/pkg/handler/url/url.go +++ b/pkg/handler/url/url.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "net/url" - "path" mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/router/paths" @@ -61,17 +60,31 @@ func CanonicalRedirect(r *http.Request) string { return redirect } -func LoginURL(prefix, redirectTarget string) string { - u := new(url.URL) - u.Path = path.Join(prefix, paths.OAuth2, paths.Login) +// Login constructs a URL string that points to the login path for the given target URL. +// The given redirect string should point to the location to be redirected to after login. +func Login(target *url.URL, redirect string) string { + u := target.JoinPath(paths.OAuth2, paths.Login) - v := url.Values{} - v.Set(RedirectURLParameter, redirectTarget) + v := u.Query() + v.Set(RedirectURLParameter, redirect) u.RawQuery = v.Encode() return u.String() } +// LoginRelative constructs the relative URL with an absolute path that points to the application's login path, given an optional path prefix. +// The given redirect string should point to the location to be redirected to after login. +func LoginRelative(prefix, redirect string) string { + u := new(url.URL) + u.Path = prefix + + if prefix == "" { + u.Path = "/" + } + + return Login(u, redirect) +} + func LoginCallbackURL(r *http.Request) (string, error) { return makeCallbackURL(r, paths.LoginCallback) } @@ -81,18 +94,19 @@ func LogoutCallbackURL(r *http.Request) (string, error) { } func makeCallbackURL(r *http.Request, callbackPath string) (string, error) { - match, found := mw.IngressFrom(r.Context()) - if !found { - return "", fmt.Errorf("request host does not match any configured ingresses") + u, err := Ingress(r) + if err != nil { + return "", err } - targetPath := path.Join(match.Path(), paths.OAuth2, callbackPath) - - targetUrl := url.URL{ - Host: match.Host(), - Path: targetPath, - Scheme: match.Scheme, - } - - return targetUrl.String(), nil + return u.JoinPath(paths.OAuth2, callbackPath).String(), nil +} + +func Ingress(r *http.Request) (*url.URL, error) { + ing, found := mw.IngressFrom(r.Context()) + if !found { + return nil, fmt.Errorf("request host does not match any configured ingresses") + } + + return ing.NewURL(), nil } diff --git a/pkg/handler/url/url_test.go b/pkg/handler/url/url_test.go index 43118f1..e3667b3 100644 --- a/pkg/handler/url/url_test.go +++ b/pkg/handler/url/url_test.go @@ -143,7 +143,49 @@ func TestCanonicalRedirect(t *testing.T) { }) } -func TestLoginURL(t *testing.T) { +func TestLogin(t *testing.T) { + for _, test := range []struct { + name string + targetURL string + redirectTarget string + want string + }{ + { + name: "root path", + targetURL: "https://sso.wonderwall", + redirectTarget: "https://test.example.com?some=param&other=param2", + want: "https://sso.wonderwall/oauth2/login?redirect=https%3A%2F%2Ftest.example.com%3Fsome%3Dparam%26other%3Dparam2", + }, + { + name: "with prefix", + targetURL: "https://sso.wonderwall/path", + redirectTarget: "https://test.example.com?some=param&other=param2", + want: "https://sso.wonderwall/path/oauth2/login?redirect=https%3A%2F%2Ftest.example.com%3Fsome%3Dparam%26other%3Dparam2", + }, + { + name: "we need to go deeper", + targetURL: "https://sso.wonderwall/deeper/path", + redirectTarget: "https://test.example.com?some=param&other=param2", + want: "https://sso.wonderwall/deeper/path/oauth2/login?redirect=https%3A%2F%2Ftest.example.com%3Fsome%3Dparam%26other%3Dparam2", + }, + { + name: "relative redirect target", + targetURL: "https://sso.wonderwall", + redirectTarget: "/path?some=param&other=param2", + want: "https://sso.wonderwall/oauth2/login?redirect=%2Fpath%3Fsome%3Dparam%26other%3Dparam2", + }, + } { + t.Run(test.name, func(t *testing.T) { + targetURL, err := url.Parse(test.targetURL) + assert.NoError(t, err) + + loginUrl := urlpkg.Login(targetURL, test.redirectTarget) + assert.Equal(t, test.want, loginUrl) + }) + } +} + +func TestLoginRelative(t *testing.T) { for _, test := range []struct { name string prefix string @@ -176,7 +218,7 @@ func TestLoginURL(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - loginUrl := urlpkg.LoginURL(test.prefix, test.redirectTarget) + loginUrl := urlpkg.LoginRelative(test.prefix, test.redirectTarget) assert.Equal(t, test.want, loginUrl) }) } diff --git a/pkg/ingress/ingress.go b/pkg/ingress/ingress.go index 9b5ef96..4ebd363 100644 --- a/pkg/ingress/ingress.go +++ b/pkg/ingress/ingress.go @@ -98,6 +98,17 @@ func (i *Ingresses) MatchingPath(r *http.Request) string { return result } +func (i *Ingresses) Single() Ingress { + var res Ingress + + for _, v := range i.ingressMap { + res = v + break + } + + return res +} + func mapIngresses(ingresses map[string]Ingress, fn func(i Ingress) string) []string { seen := make(map[string]bool, 0) result := make([]string, 0) @@ -172,3 +183,8 @@ func (i Ingress) Host() string { func (i Ingress) String() string { return i.URL.String() } + +func (i Ingress) NewURL() *url.URL { + u := *i.URL + return &u +}