refactor: separate login param generation

Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
Trong Huu Nguyen
2021-09-30 12:13:38 +02:00
parent dbc0a47a46
commit cf7ca9c5b8
3 changed files with 78 additions and 55 deletions

51
pkg/auth/login.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
)
type Parameters struct {
CodeVerifier string
CodeChallenge string
Nonce string
State string
}
func GenerateLoginParameters() (*Parameters, error) {
codeVerifier := make([]byte, 64)
nonce := make([]byte, 32)
state := make([]byte, 32)
var err error
_, err = io.ReadFull(rand.Reader, state)
if err != nil {
return nil, fmt.Errorf("failed to create state: %w", err)
}
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, fmt.Errorf("failed to create nonce: %w", err)
}
_, err = io.ReadFull(rand.Reader, codeVerifier)
if err != nil {
return nil, fmt.Errorf("failed to create code verifier: %w", err)
}
codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier))
hasher := sha256.New()
hasher.Write(codeVerifier)
codeVerifierHash := hasher.Sum(nil)
return &Parameters{
CodeVerifier: string(codeVerifier),
CodeChallenge: base64.RawURLEncoding.EncodeToString(codeVerifierHash),
Nonce: base64.RawURLEncoding.EncodeToString(nonce),
State: base64.RawURLEncoding.EncodeToString(state),
}, nil
}

View File

@@ -2,11 +2,9 @@ package router
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/auth"
"io"
"net/http"
"net/url"
@@ -80,73 +78,36 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler {
return h
}
type loginParams struct {
state string
codeVerifier string
url string
nonce string
}
func (h *Handler) LoginURL(r *http.Request) (*loginParams, error) {
codeVerifier := make([]byte, 64)
nonce := make([]byte, 32)
state := make([]byte, 32)
var err error
_, err = io.ReadFull(rand.Reader, state)
if err != nil {
return nil, fmt.Errorf("failed to create state: %w", err)
}
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, fmt.Errorf("failed to create nonce: %w", err)
}
_, err = io.ReadFull(rand.Reader, codeVerifier)
if err != nil {
return nil, fmt.Errorf("failed to create code verifier: %w", err)
}
codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier))
hasher := sha256.New()
hasher.Write(codeVerifier)
codeVerifierHash := hasher.Sum(nil)
func (h *Handler) LoginURL(r *http.Request, params *auth.Parameters) (string, error) {
u, err := url.Parse(h.Config.WellKnown.AuthorizationEndpoint)
if err != nil {
return nil, err
return "", err
}
v := u.Query()
v.Add("response_type", "code")
v.Add("client_id", h.Config.ClientID)
v.Add("redirect_uri", h.Config.RedirectURI)
v.Add("scope", token.ScopeOpenID)
v.Add("state", base64.RawURLEncoding.EncodeToString(state))
v.Add("nonce", base64.RawURLEncoding.EncodeToString(nonce))
v.Add("state", params.State)
v.Add("nonce", params.Nonce)
v.Add("response_mode", "query")
v.Add("code_challenge", base64.RawURLEncoding.EncodeToString(codeVerifierHash))
v.Add("code_challenge", params.CodeChallenge)
v.Add("code_challenge_method", "S256")
err = h.withSecurityLevel(r, v)
if err != nil {
return nil, fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
}
err = h.withLocale(r, v)
if err != nil {
return nil, fmt.Errorf("%w: %+v", InvalidLocaleError, err)
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
}
u.RawQuery = v.Encode()
return &loginParams{
state: base64.RawURLEncoding.EncodeToString(state),
nonce: base64.RawURLEncoding.EncodeToString(nonce),
codeVerifier: string(codeVerifier),
url: u.String(),
}, nil
return u.String(), nil
}
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
@@ -184,7 +145,14 @@ func (h *Handler) withLocale(r *http.Request, v url.Values) error {
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
params, err := h.LoginURL(r)
params, err := auth.GenerateLoginParameters()
if err != nil {
log.Errorf("generating login parameters: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
loginURL, err := h.LoginURL(r, params)
if err != nil {
log.Errorf("login URL: %+v", err)
@@ -202,9 +170,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
err = h.setLoginCookie(w, &LoginCookie{
State: params.state,
Nonce: params.nonce,
CodeVerifier: params.codeVerifier,
State: params.State,
Nonce: params.Nonce,
CodeVerifier: params.CodeVerifier,
Referer: CanonicalRedirectURL(r),
})
if err != nil {
@@ -213,7 +181,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, params.url, http.StatusTemporaryRedirect)
http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
}
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {

View File

@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/auth"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
@@ -134,8 +135,11 @@ func TestLoginURL(t *testing.T) {
req, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
params, err := auth.GenerateLoginParameters()
assert.NoError(t, err)
handler := handler(cfg)
_, err = handler.LoginURL(req)
_, err = handler.LoginURL(req, params)
if test.error != nil {
assert.True(t, errors.Is(err, test.error))