From 1f58b5ae153538d5fa81ff4e48e3aaf039c0fa93 Mon Sep 17 00:00:00 2001 From: Kent Daleng Date: Tue, 24 Aug 2021 09:59:34 +0200 Subject: [PATCH] write callback test Co-Authored-By: Trong Huu Nguyen --- go.mod | 1 + go.sum | 1 + pkg/router/idporten_mock_server_test.go | 140 ++++++++++++++++++++++-- pkg/router/router_test.go | 107 ++++++++++++------ 4 files changed, 202 insertions(+), 47 deletions(-) diff --git a/go.mod b/go.mod index 1e12c8f..0e0d0ad 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/coreos/go-oidc v2.1.0+incompatible github.com/go-chi/chi v1.5.4 + github.com/google/uuid v1.1.2 github.com/lestrrat-go/jwx v1.2.5 github.com/nais/liberator v0.0.0-20210809103005-edb0141d646d github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect diff --git a/go.sum b/go.sum index c651aeb..d8aa25f 100644 --- a/go.sum +++ b/go.sum @@ -254,6 +254,7 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= diff --git a/pkg/router/idporten_mock_server_test.go b/pkg/router/idporten_mock_server_test.go index d5c0d3a..e5c5b03 100644 --- a/pkg/router/idporten_mock_server_test.go +++ b/pkg/router/idporten_mock_server_test.go @@ -4,10 +4,16 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "fmt" "net/http" + "net/url" + "time" "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" ) type IDPorten struct { @@ -50,26 +56,132 @@ type TokenJSON struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token"` - ExpiresIn int32 `json:"expires_in"` + ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } -func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) { +func (ip *IDPorten) signToken(token jwt.Token) (string, error) { + signer, ok := ip.Keys.Get(0) + if !ok { + return "", fmt.Errorf("could not get signer") + } + + signedToken, err := jwt.Sign(token, jwa.RS256, signer) + if err != nil { + return "", err + } + + return string(signedToken), nil +} + +func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("malformed payload?")) + return + } + + code := r.PostForm.Get("code") + + if len(code) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing code")) + return + } + + auth, ok := ip.Codes[code] + if !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("no matching code")) + return + } + + expires := int64(1200) + + sub := uuid.New().String() + sid := uuid.New().String() + + accessToken := jwt.New() + accessToken.Set("sub", sub) + accessToken.Set("iss", cfg.WellKnown.Issuer) + accessToken.Set("acr", auth.AcrLevel) + accessToken.Set("iat", time.Now().Unix()) + accessToken.Set("exp", time.Now().Unix()+expires) + signedAccessToken, err := ip.signToken(accessToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not sign access token: " + err.Error())) + return + } + + idToken := jwt.New() + idToken.Set("sub", sub) + idToken.Set("iss", cfg.WellKnown.Issuer) + idToken.Set("aud", cfg.ClientID) + idToken.Set("locale", auth.Locale) + idToken.Set("nonce", auth.Nonce) + idToken.Set("acr", auth.AcrLevel) + idToken.Set("iat", time.Now().Unix()) + idToken.Set("exp", time.Now().Unix()+expires) + idToken.Set("sid", sid) + + signedIdToken, err := ip.signToken(idToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not sign access token: " + err.Error())) + return + } + + ip.Sessions[sid] = cfg.ClientID // fixme: generate valid access token and id token; sign them with the correct key token := &TokenJSON{ - AccessToken: "access-token", - TokenType: "token-type", - RefreshToken: "refresh-token", - IDToken: "id-token", - ExpiresIn: 1200, + AccessToken: string(signedAccessToken), + TokenType: "Bearer", + IDToken: string(signedIdToken), + ExpiresIn: expires, } w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(token) } -func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) { +func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + state := query.Get("state") + redirect := query.Get("redirect_uri") + acrLevel := query.Get("acr_values") + codeChallenge := query.Get("code_challenge") + locale := query.Get("ui_locales") + nonce := query.Get("nonce") + if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing required fields")) + return + } + + code := uuid.New().String() + ip.Codes[code] = AuthRequest{ + AcrLevel: acrLevel, + CodeChallenge: codeChallenge, + Locale: locale, + Nonce: nonce, + } + + u, err := url.Parse(redirect) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("couldn't parse redirect uri")) + return + } + v := url.Values{} + v.Set("code", code) + v.Set("state", state) + + u.RawQuery = v.Encode() + + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } func (ip *IDPorten) EndSession(w http.ResponseWriter, r *http.Request) { @@ -77,7 +189,13 @@ func (ip *IDPorten) EndSession(w http.ResponseWriter, r *http.Request) { } func (ip *IDPorten) Jwks(w http.ResponseWriter, r *http.Request) { - + publicSet, err := jwk.PublicSetOf(ip.Keys) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not create public set: " + err.Error())) + return + } + json.NewEncoder(w).Encode(publicSet) } func (ip *IDPorten) WellKnown(w http.ResponseWriter, r *http.Request) { @@ -86,8 +204,8 @@ func (ip *IDPorten) WellKnown(w http.ResponseWriter, r *http.Request) { func idportenRouter(ip *IDPorten) chi.Router { r := chi.NewRouter() - r.Post("/authorize", ip.Authorize) - r.Get("/token", ip.Token) + r.Get("/authorize", ip.Authorize) + r.Post("/token", ip.Token) r.Get("/endsession", ip.EndSession) r.Get("/jwks", ip.Jwks) r.Get("/.well-known/openid-configuration", ip.WellKnown) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 1214290..74500a8 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -1,7 +1,8 @@ package router_test import ( - "encoding/base64" + "context" + "fmt" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -10,6 +11,7 @@ import ( "golang.org/x/oauth2" + "github.com/coreos/go-oidc" "github.com/nais/wonderwall/pkg/cryptutil" "github.com/stretchr/testify/assert" @@ -18,10 +20,12 @@ import ( "github.com/nais/wonderwall/pkg/router" ) +const clientID = "clientid" + var encryptionKey = []byte(`G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`) // 256 bits AES var cfg = config.IDPorten{ - ClientID: "clientid", + ClientID: clientID, ClientJWK: ` { "kty": "RSA", @@ -37,7 +41,7 @@ var cfg = config.IDPorten{ "x5t": "9rJ_0ziKoGNjSS_l11hn0yQxEqg" } `, - RedirectURI: "http://localhost/redirect", + RedirectURI: "http://localhost/callback", WellKnownURL: "", WellKnown: config.IDPortenWellKnown{ Issuer: "issuer", @@ -48,8 +52,13 @@ var cfg = config.IDPorten{ PostLogoutRedirectURI: "", } +var clients = map[string]string{ + clientID: "http://localhost/oauth2/logout/frontchannel", +} +var idp = NewIDPorten(clients) + func handler() *router.Handler { - return &router.Handler{ + handler := router.Handler{ Config: cfg, OauthConfig: oauth2.Config{ ClientID: "client-id", @@ -65,6 +74,8 @@ func handler() *router.Handler { UpstreamHost: "", IdTokenVerifier: nil, } + handler.Init() + return &handler } func TestLoginURL(t *testing.T) { @@ -76,15 +87,21 @@ func TestLoginURL(t *testing.T) { } func TestHandler_Login(t *testing.T) { - r := router.New(handler()) - server := httptest.NewServer(r) + h := handler() + r := router.New(h) + server := httptest.NewServer(r) client := server.Client() client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } - req, err := client.Get(server.URL + "/oauth2/login") + idprouter := idportenRouter(idp) + idpserver := httptest.NewServer(idprouter) + + h.Config.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" + + req, err := client.Get(server.URL + "/oauth2/login") assert.NoError(t, err) defer req.Body.Close() @@ -92,14 +109,26 @@ func TestHandler_Login(t *testing.T) { u, err := url.Parse(location) assert.NoError(t, err) - assert.Equal(t, "localhost:1234", u.Host) + assert.Equal(t, idpserver.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host)) assert.Equal(t, "/authorize", u.Path) assert.Equal(t, cfg.SecurityLevel, u.Query().Get("acr_values")) + assert.Equal(t, cfg.Locale, u.Query().Get("ui_locales")) assert.Equal(t, cfg.ClientID, u.Query().Get("client_id")) assert.Equal(t, cfg.RedirectURI, u.Query().Get("redirect_uri")) assert.NotEmpty(t, u.Query().Get("state")) assert.NotEmpty(t, u.Query().Get("nonce")) assert.NotEmpty(t, u.Query().Get("code_challenge")) + + req, err = client.Get(u.String()) + assert.NoError(t, err) + defer req.Body.Close() + + location = req.Header.Get("location") + callbackURL, err := url.Parse(location) + assert.NoError(t, err) + + assert.Equal(t, u.Query().Get("state"), callbackURL.Query().Get("state")) + assert.NotEmpty(t, callbackURL.Query().Get("code")) } func TestHandler_Callback(t *testing.T) { @@ -107,13 +136,16 @@ func TestHandler_Callback(t *testing.T) { r := router.New(h) server := httptest.NewServer(r) - clients := map[string]string{ - h.Config.ClientID: server.URL + "/oauth2/logout/frontchannel", - } - idp := NewIDPorten(clients) idprouter := idportenRouter(idp) idpserver := httptest.NewServer(idprouter) - h.OauthConfig.Endpoint.TokenURL = idpserver.URL + "/authorize" + h.OauthConfig.Endpoint.TokenURL = idpserver.URL + "/token" + h.Config.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" + h.Config.RedirectURI = server.URL + "/oauth2/callback" + h.IdTokenVerifier = oidc.NewVerifier( + cfg.WellKnown.Issuer, + oidc.NewRemoteKeySet(context.Background(), idpserver.URL+"/jwks"), + &oidc.Config{ClientID: cfg.ClientID}, + ) jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -127,32 +159,35 @@ func TestHandler_Callback(t *testing.T) { // First, run /oauth2/login to set cookies req, err := client.Get(server.URL + "/oauth2/login") assert.NoError(t, err) - req.Body.Close() + defer req.Body.Close() - u, err := url.Parse(server.URL) + // Get authorization URL + location := req.Header.Get("location") + u, err := url.Parse(location) assert.NoError(t, err) - u.Path = "/oauth2/callback" - v := &url.Values{} - - mapping := map[string]string{ - router.NonceCookieName: "nonce", - router.StateCookieName: "state", - router.CodeVerifierCookieName: "code_verifier", - } - for _, cookie := range req.Cookies() { - ciphertext, err := base64.StdEncoding.DecodeString(cookie.Value) - if err != nil { - panic(err) - } - plaintext, err := h.Crypter.Decrypt(ciphertext) - if err != nil { - panic(err) - } - v.Set(mapping[cookie.Name], string(plaintext)) - } - u.RawQuery = v.Encode() - + // Follow redirect to authorize with idporten req, err = client.Get(u.String()) assert.NoError(t, err) + defer req.Body.Close() + + // Get callback URL after successful auth + location = req.Header.Get("location") + callbackURL, err := url.Parse(location) + assert.NoError(t, err) + + // Follow redirect to callback + req, err = client.Get(callbackURL.String()) + assert.NoError(t, err) + + cookies := client.Jar.Cookies(callbackURL) + var sessionCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == router.SessionCookieName { + sessionCookie = cookie + } + } + + assert.NotNil(t, sessionCookie) + }