refactor: remove indirection layer for login client

Co-authored-by: tronghn <trong.huu.nguyen@nav.no>
This commit is contained in:
Sindre Rødseth Hansen
2025-01-22 13:52:23 +01:00
committed by Trong Huu Nguyen
parent c442000be4
commit ade44f0950
2 changed files with 49 additions and 63 deletions

View File

@@ -66,15 +66,6 @@ func NewClient(cfg openidconfig.Config, jwksProvider JwksProvider) *Client {
}
}
func (c *Client) Login(r *http.Request) (*Login, error) {
login, err := NewLogin(c, r)
if err != nil {
return nil, fmt.Errorf("login: %w", err)
}
return login, nil
}
func (c *Client) Logout(r *http.Request) (*Logout, error) {
logout, err := NewLogout(c, r)
if err != nil {

View File

@@ -46,45 +46,40 @@ var (
type Login struct {
AuthCodeURL string
Acr string
Locale string
Prompt string
*openid.LoginCookie
authorizationRequest
Cookie *openid.LoginCookie
}
type authorizationRequest struct {
acr string
callbackURL string
codeVerifier string
locale string
nonce string
prompt string
state string
Acr string
CallbackURL string
CodeVerifier string
Locale string
Nonce string
Prompt string
State string
}
// TODO: remove indirection layer
func NewLogin(c *Client, r *http.Request) (*Login, error) {
func (c *Client) Login(r *http.Request) (*Login, error) {
request, err := c.newAuthorizationRequest(r)
if err != nil {
return nil, err
return nil, fmt.Errorf("login: %w", err)
}
authCodeURL, err := c.authCodeURL(r.Context(), request)
if err != nil {
return nil, fmt.Errorf("generating auth code url: %w", err)
return nil, fmt.Errorf("login: generating auth code url: %w", err)
}
return &Login{
AuthCodeURL: authCodeURL,
Acr: request.acr,
Locale: request.locale,
Prompt: request.prompt,
LoginCookie: &openid.LoginCookie{
Acr: request.acr,
CodeVerifier: request.codeVerifier,
State: request.state,
Nonce: request.nonce,
RedirectURI: request.callbackURL,
AuthCodeURL: authCodeURL,
authorizationRequest: *request,
Cookie: &openid.LoginCookie{
Acr: request.Acr,
CodeVerifier: request.CodeVerifier,
State: request.State,
Nonce: request.Nonce,
RedirectURI: request.CallbackURL,
},
}, nil
}
@@ -123,13 +118,13 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest
codeVerifier := oauth2.GenerateVerifier()
return &authorizationRequest{
acr: acr,
callbackURL: callbackURL,
codeVerifier: codeVerifier,
locale: locale,
nonce: nonce,
prompt: prompt,
state: state,
Acr: acr,
CallbackURL: callbackURL,
CodeVerifier: codeVerifier,
Locale: locale,
Nonce: nonce,
Prompt: prompt,
State: state,
}, nil
}
@@ -138,57 +133,57 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)
if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" {
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("nonce", request.nonce),
oauth2.SetAuthURLParam("nonce", request.Nonce),
oauth2.SetAuthURLParam("response_mode", "query"),
oauth2.S256ChallengeOption(request.codeVerifier),
openid.RedirectURIOption(request.callbackURL),
oauth2.S256ChallengeOption(request.CodeVerifier),
openid.RedirectURIOption(request.CallbackURL),
}
if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
opts = append(opts, oauth2.SetAuthURLParam("resource", resource))
}
if len(request.acr) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.acr))
if len(request.Acr) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.Acr))
}
if len(request.locale) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.locale))
if len(request.Locale) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.Locale))
}
if len(request.prompt) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.prompt))
if len(request.Prompt) > 0 {
opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.Prompt))
opts = append(opts, oauth2.SetAuthURLParam(MaxAgeURLParameter, "0"))
}
authCodeURL = c.oauth2Config.AuthCodeURL(request.state, opts...)
authCodeURL = c.oauth2Config.AuthCodeURL(request.State, opts...)
} else {
params := map[string]string{
"client_id": c.oauth2Config.ClientID,
"code_challenge": oauth2.S256ChallengeFromVerifier(request.codeVerifier),
"code_challenge": oauth2.S256ChallengeFromVerifier(request.CodeVerifier),
"code_challenge_method": "S256",
"nonce": request.nonce,
"redirect_uri": request.callbackURL,
"nonce": request.Nonce,
"redirect_uri": request.CallbackURL,
"response_mode": "query",
"response_type": "code",
"scope": stringslib.Join(c.oauth2Config.Scopes, " "),
"state": request.state,
"state": request.State,
}
if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
params["resource"] = resource
}
if len(request.acr) > 0 {
params[LoginParameterMapping[SecurityLevelURLParameter]] = request.acr
if len(request.Acr) > 0 {
params[LoginParameterMapping[SecurityLevelURLParameter]] = request.Acr
}
if len(request.locale) > 0 {
params[LoginParameterMapping[LocaleURLParameter]] = request.locale
if len(request.Locale) > 0 {
params[LoginParameterMapping[LocaleURLParameter]] = request.Locale
}
if len(request.prompt) > 0 {
params[PromptURLParameter] = request.prompt
if len(request.Prompt) > 0 {
params[PromptURLParameter] = request.Prompt
params[MaxAgeURLParameter] = "0"
}
@@ -252,9 +247,9 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)
}
func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
l.LoginCookie.Referer = canonicalRedirect
l.Cookie.Referer = canonicalRedirect
loginCookieJson, err := json.Marshal(l.LoginCookie)
loginCookieJson, err := json.Marshal(l.Cookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}