diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index ec708fa..f8dd02f 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -66,15 +66,6 @@ func NewClient(cfg openidconfig.Config, jwksProvider JwksProvider) *Client { } } -func (c *Client) Login(r *http.Request) (*Login, error) { - login, err := NewLogin(c, r) - if err != nil { - return nil, fmt.Errorf("login: %w", err) - } - - return login, nil -} - func (c *Client) Logout(r *http.Request) (*Logout, error) { logout, err := NewLogout(c, r) if err != nil { diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 3a61517..4d1415a 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -46,45 +46,40 @@ var ( type Login struct { AuthCodeURL string - Acr string - Locale string - Prompt string - *openid.LoginCookie + authorizationRequest + Cookie *openid.LoginCookie } type authorizationRequest struct { - acr string - callbackURL string - codeVerifier string - locale string - nonce string - prompt string - state string + Acr string + CallbackURL string + CodeVerifier string + Locale string + Nonce string + Prompt string + State string } -// TODO: remove indirection layer -func NewLogin(c *Client, r *http.Request) (*Login, error) { +func (c *Client) Login(r *http.Request) (*Login, error) { request, err := c.newAuthorizationRequest(r) if err != nil { - return nil, err + return nil, fmt.Errorf("login: %w", err) } authCodeURL, err := c.authCodeURL(r.Context(), request) if err != nil { - return nil, fmt.Errorf("generating auth code url: %w", err) + return nil, fmt.Errorf("login: generating auth code url: %w", err) } return &Login{ - AuthCodeURL: authCodeURL, - Acr: request.acr, - Locale: request.locale, - Prompt: request.prompt, - LoginCookie: &openid.LoginCookie{ - Acr: request.acr, - CodeVerifier: request.codeVerifier, - State: request.state, - Nonce: request.nonce, - RedirectURI: request.callbackURL, + AuthCodeURL: authCodeURL, + authorizationRequest: *request, + Cookie: &openid.LoginCookie{ + Acr: request.Acr, + CodeVerifier: request.CodeVerifier, + State: request.State, + Nonce: request.Nonce, + RedirectURI: request.CallbackURL, }, }, nil } @@ -123,13 +118,13 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest codeVerifier := oauth2.GenerateVerifier() return &authorizationRequest{ - acr: acr, - callbackURL: callbackURL, - codeVerifier: codeVerifier, - locale: locale, - nonce: nonce, - prompt: prompt, - state: state, + Acr: acr, + CallbackURL: callbackURL, + CodeVerifier: codeVerifier, + Locale: locale, + Nonce: nonce, + Prompt: prompt, + State: state, }, nil } @@ -138,57 +133,57 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" { opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("nonce", request.nonce), + oauth2.SetAuthURLParam("nonce", request.Nonce), oauth2.SetAuthURLParam("response_mode", "query"), - oauth2.S256ChallengeOption(request.codeVerifier), - openid.RedirectURIOption(request.callbackURL), + oauth2.S256ChallengeOption(request.CodeVerifier), + openid.RedirectURIOption(request.CallbackURL), } if resource := c.cfg.Client().ResourceIndicator(); resource != "" { opts = append(opts, oauth2.SetAuthURLParam("resource", resource)) } - if len(request.acr) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.acr)) + if len(request.Acr) > 0 { + opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.Acr)) } - if len(request.locale) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.locale)) + if len(request.Locale) > 0 { + opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.Locale)) } - if len(request.prompt) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.prompt)) + if len(request.Prompt) > 0 { + opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.Prompt)) opts = append(opts, oauth2.SetAuthURLParam(MaxAgeURLParameter, "0")) } - authCodeURL = c.oauth2Config.AuthCodeURL(request.state, opts...) + authCodeURL = c.oauth2Config.AuthCodeURL(request.State, opts...) } else { params := map[string]string{ "client_id": c.oauth2Config.ClientID, - "code_challenge": oauth2.S256ChallengeFromVerifier(request.codeVerifier), + "code_challenge": oauth2.S256ChallengeFromVerifier(request.CodeVerifier), "code_challenge_method": "S256", - "nonce": request.nonce, - "redirect_uri": request.callbackURL, + "nonce": request.Nonce, + "redirect_uri": request.CallbackURL, "response_mode": "query", "response_type": "code", "scope": stringslib.Join(c.oauth2Config.Scopes, " "), - "state": request.state, + "state": request.State, } if resource := c.cfg.Client().ResourceIndicator(); resource != "" { params["resource"] = resource } - if len(request.acr) > 0 { - params[LoginParameterMapping[SecurityLevelURLParameter]] = request.acr + if len(request.Acr) > 0 { + params[LoginParameterMapping[SecurityLevelURLParameter]] = request.Acr } - if len(request.locale) > 0 { - params[LoginParameterMapping[LocaleURLParameter]] = request.locale + if len(request.Locale) > 0 { + params[LoginParameterMapping[LocaleURLParameter]] = request.Locale } - if len(request.prompt) > 0 { - params[PromptURLParameter] = request.prompt + if len(request.Prompt) > 0 { + params[PromptURLParameter] = request.Prompt params[MaxAgeURLParameter] = "0" } @@ -252,9 +247,9 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) } func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error { - l.LoginCookie.Referer = canonicalRedirect + l.Cookie.Referer = canonicalRedirect - loginCookieJson, err := json.Marshal(l.LoginCookie) + loginCookieJson, err := json.Marshal(l.Cookie) if err != nil { return fmt.Errorf("marshalling login cookie: %w", err) }