From cf7ca9c5b87d86ba8aa888d54a60260159ea1b76 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 30 Sep 2021 12:13:38 +0200 Subject: [PATCH] refactor: separate login param generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Sindre Rødseth Hansen --- pkg/auth/login.go | 51 ++++++++++++++++++++++++++ pkg/router/router.go | 76 ++++++++++++--------------------------- pkg/router/router_test.go | 6 +++- 3 files changed, 78 insertions(+), 55 deletions(-) create mode 100644 pkg/auth/login.go diff --git a/pkg/auth/login.go b/pkg/auth/login.go new file mode 100644 index 0000000..d785a40 --- /dev/null +++ b/pkg/auth/login.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" +) + +type Parameters struct { + CodeVerifier string + CodeChallenge string + Nonce string + State string +} + +func GenerateLoginParameters() (*Parameters, error) { + codeVerifier := make([]byte, 64) + nonce := make([]byte, 32) + state := make([]byte, 32) + + var err error + + _, err = io.ReadFull(rand.Reader, state) + if err != nil { + return nil, fmt.Errorf("failed to create state: %w", err) + } + + _, err = io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, fmt.Errorf("failed to create nonce: %w", err) + } + + _, err = io.ReadFull(rand.Reader, codeVerifier) + if err != nil { + return nil, fmt.Errorf("failed to create code verifier: %w", err) + } + + codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier)) + hasher := sha256.New() + hasher.Write(codeVerifier) + codeVerifierHash := hasher.Sum(nil) + + return &Parameters{ + CodeVerifier: string(codeVerifier), + CodeChallenge: base64.RawURLEncoding.EncodeToString(codeVerifierHash), + Nonce: base64.RawURLEncoding.EncodeToString(nonce), + State: base64.RawURLEncoding.EncodeToString(state), + }, nil +} diff --git a/pkg/router/router.go b/pkg/router/router.go index a6ec58d..5f98856 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -2,11 +2,9 @@ package router import ( "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" "errors" "fmt" + "github.com/nais/wonderwall/pkg/auth" "io" "net/http" "net/url" @@ -80,73 +78,36 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler { return h } -type loginParams struct { - state string - codeVerifier string - url string - nonce string -} - -func (h *Handler) LoginURL(r *http.Request) (*loginParams, error) { - codeVerifier := make([]byte, 64) - nonce := make([]byte, 32) - state := make([]byte, 32) - - var err error - - _, err = io.ReadFull(rand.Reader, state) - if err != nil { - return nil, fmt.Errorf("failed to create state: %w", err) - } - - _, err = io.ReadFull(rand.Reader, nonce) - if err != nil { - return nil, fmt.Errorf("failed to create nonce: %w", err) - } - - _, err = io.ReadFull(rand.Reader, codeVerifier) - if err != nil { - return nil, fmt.Errorf("failed to create code verifier: %w", err) - } - - codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier)) - hasher := sha256.New() - hasher.Write(codeVerifier) - codeVerifierHash := hasher.Sum(nil) - +func (h *Handler) LoginURL(r *http.Request, params *auth.Parameters) (string, error) { u, err := url.Parse(h.Config.WellKnown.AuthorizationEndpoint) if err != nil { - return nil, err + return "", err } + v := u.Query() v.Add("response_type", "code") v.Add("client_id", h.Config.ClientID) v.Add("redirect_uri", h.Config.RedirectURI) v.Add("scope", token.ScopeOpenID) - v.Add("state", base64.RawURLEncoding.EncodeToString(state)) - v.Add("nonce", base64.RawURLEncoding.EncodeToString(nonce)) + v.Add("state", params.State) + v.Add("nonce", params.Nonce) v.Add("response_mode", "query") - v.Add("code_challenge", base64.RawURLEncoding.EncodeToString(codeVerifierHash)) + v.Add("code_challenge", params.CodeChallenge) v.Add("code_challenge_method", "S256") err = h.withSecurityLevel(r, v) if err != nil { - return nil, fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err) + return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err) } err = h.withLocale(r, v) if err != nil { - return nil, fmt.Errorf("%w: %+v", InvalidLocaleError, err) + return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err) } u.RawQuery = v.Encode() - return &loginParams{ - state: base64.RawURLEncoding.EncodeToString(state), - nonce: base64.RawURLEncoding.EncodeToString(nonce), - codeVerifier: string(codeVerifier), - url: u.String(), - }, nil + return u.String(), nil } func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error { @@ -184,7 +145,14 @@ func (h *Handler) withLocale(r *http.Request, v url.Values) error { } func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { - params, err := h.LoginURL(r) + params, err := auth.GenerateLoginParameters() + if err != nil { + log.Errorf("generating login parameters: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + loginURL, err := h.LoginURL(r, params) if err != nil { log.Errorf("login URL: %+v", err) @@ -202,9 +170,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { } err = h.setLoginCookie(w, &LoginCookie{ - State: params.state, - Nonce: params.nonce, - CodeVerifier: params.codeVerifier, + State: params.State, + Nonce: params.Nonce, + CodeVerifier: params.CodeVerifier, Referer: CanonicalRedirectURL(r), }) if err != nil { @@ -213,7 +181,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { return } - http.Redirect(w, r, params.url, http.StatusTemporaryRedirect) + http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect) } func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 97b5902..ab291a3 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/nais/wonderwall/pkg/auth" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -134,8 +135,11 @@ func TestLoginURL(t *testing.T) { req, err := http.NewRequest("GET", test.url, nil) assert.NoError(t, err) + params, err := auth.GenerateLoginParameters() + assert.NoError(t, err) + handler := handler(cfg) - _, err = handler.LoginURL(req) + _, err = handler.LoginURL(req, params) if test.error != nil { assert.True(t, errors.Is(err, test.error))