mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-08 17:37:01 +00:00
refactor: separate login param generation
Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
51
pkg/auth/login.go
Normal file
51
pkg/auth/login.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Parameters struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
Nonce string
|
||||
State string
|
||||
}
|
||||
|
||||
func GenerateLoginParameters() (*Parameters, error) {
|
||||
codeVerifier := make([]byte, 64)
|
||||
nonce := make([]byte, 32)
|
||||
state := make([]byte, 32)
|
||||
|
||||
var err error
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create state: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create nonce: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, codeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier))
|
||||
hasher := sha256.New()
|
||||
hasher.Write(codeVerifier)
|
||||
codeVerifierHash := hasher.Sum(nil)
|
||||
|
||||
return &Parameters{
|
||||
CodeVerifier: string(codeVerifier),
|
||||
CodeChallenge: base64.RawURLEncoding.EncodeToString(codeVerifierHash),
|
||||
Nonce: base64.RawURLEncoding.EncodeToString(nonce),
|
||||
State: base64.RawURLEncoding.EncodeToString(state),
|
||||
}, nil
|
||||
}
|
||||
@@ -2,11 +2,9 @@ package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/auth"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -80,73 +78,36 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler {
|
||||
return h
|
||||
}
|
||||
|
||||
type loginParams struct {
|
||||
state string
|
||||
codeVerifier string
|
||||
url string
|
||||
nonce string
|
||||
}
|
||||
|
||||
func (h *Handler) LoginURL(r *http.Request) (*loginParams, error) {
|
||||
codeVerifier := make([]byte, 64)
|
||||
nonce := make([]byte, 32)
|
||||
state := make([]byte, 32)
|
||||
|
||||
var err error
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create state: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create nonce: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(rand.Reader, codeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier))
|
||||
hasher := sha256.New()
|
||||
hasher.Write(codeVerifier)
|
||||
codeVerifierHash := hasher.Sum(nil)
|
||||
|
||||
func (h *Handler) LoginURL(r *http.Request, params *auth.Parameters) (string, error) {
|
||||
u, err := url.Parse(h.Config.WellKnown.AuthorizationEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
v := u.Query()
|
||||
v.Add("response_type", "code")
|
||||
v.Add("client_id", h.Config.ClientID)
|
||||
v.Add("redirect_uri", h.Config.RedirectURI)
|
||||
v.Add("scope", token.ScopeOpenID)
|
||||
v.Add("state", base64.RawURLEncoding.EncodeToString(state))
|
||||
v.Add("nonce", base64.RawURLEncoding.EncodeToString(nonce))
|
||||
v.Add("state", params.State)
|
||||
v.Add("nonce", params.Nonce)
|
||||
v.Add("response_mode", "query")
|
||||
v.Add("code_challenge", base64.RawURLEncoding.EncodeToString(codeVerifierHash))
|
||||
v.Add("code_challenge", params.CodeChallenge)
|
||||
v.Add("code_challenge_method", "S256")
|
||||
|
||||
err = h.withSecurityLevel(r, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
|
||||
return "", fmt.Errorf("%w: %+v", InvalidSecurityLevelError, err)
|
||||
}
|
||||
|
||||
err = h.withLocale(r, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %+v", InvalidLocaleError, err)
|
||||
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
|
||||
}
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
return &loginParams{
|
||||
state: base64.RawURLEncoding.EncodeToString(state),
|
||||
nonce: base64.RawURLEncoding.EncodeToString(nonce),
|
||||
codeVerifier: string(codeVerifier),
|
||||
url: u.String(),
|
||||
}, nil
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
|
||||
@@ -184,7 +145,14 @@ func (h *Handler) withLocale(r *http.Request, v url.Values) error {
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := h.LoginURL(r)
|
||||
params, err := auth.GenerateLoginParameters()
|
||||
if err != nil {
|
||||
log.Errorf("generating login parameters: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginURL, err := h.LoginURL(r, params)
|
||||
if err != nil {
|
||||
log.Errorf("login URL: %+v", err)
|
||||
|
||||
@@ -202,9 +170,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
err = h.setLoginCookie(w, &LoginCookie{
|
||||
State: params.state,
|
||||
Nonce: params.nonce,
|
||||
CodeVerifier: params.codeVerifier,
|
||||
State: params.State,
|
||||
Nonce: params.Nonce,
|
||||
CodeVerifier: params.CodeVerifier,
|
||||
Referer: CanonicalRedirectURL(r),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -213,7 +181,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, params.url, http.StatusTemporaryRedirect)
|
||||
http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/auth"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
@@ -134,8 +135,11 @@ func TestLoginURL(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params, err := auth.GenerateLoginParameters()
|
||||
assert.NoError(t, err)
|
||||
|
||||
handler := handler(cfg)
|
||||
_, err = handler.LoginURL(req)
|
||||
_, err = handler.LoginURL(req, params)
|
||||
|
||||
if test.error != nil {
|
||||
assert.True(t, errors.Is(err, test.error))
|
||||
|
||||
Reference in New Issue
Block a user