mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-10 10:27:02 +00:00
refactor(openid/client): extract oauth request method
Co-authored-by: sindrerh2 <sindre.rodseth.hansen@nav.no>
This commit is contained in:
@@ -94,37 +94,16 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestBody := strings.NewReader(params.URLValues(map[string]string{
|
||||
payload := params.URLValues(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": c.cfg.Client().ClientID(),
|
||||
}).Encode())
|
||||
}).Encode()
|
||||
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), requestBody)
|
||||
endpoint := c.cfg.Provider().TokenEndpoint()
|
||||
body, err := c.oauthPostRequest(ctx, endpoint, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("performing request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading server response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
|
||||
var errorResponse openid.TokenErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||
return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
|
||||
} else if resp.StatusCode >= 500 {
|
||||
return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenResponse openid.TokenResponse
|
||||
@@ -183,3 +162,34 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
|
||||
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
func (c *Client) oauthPostRequest(ctx context.Context, endpoint, payload string) ([]byte, error) {
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("performing request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading server response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
|
||||
var errorResponse openid.TokenErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||
return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
|
||||
} else if resp.StatusCode >= 500 {
|
||||
return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, payload)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
urllib "net/url"
|
||||
"slices"
|
||||
@@ -45,9 +44,9 @@ var (
|
||||
)
|
||||
|
||||
type Login struct {
|
||||
AuthCodeURL string
|
||||
authorizationRequest
|
||||
Cookie *openid.LoginCookie
|
||||
AuthCodeURL string
|
||||
Cookie openid.LoginCookie
|
||||
}
|
||||
|
||||
type authorizationRequest struct {
|
||||
@@ -60,6 +59,16 @@ type authorizationRequest struct {
|
||||
State string
|
||||
}
|
||||
|
||||
func (a authorizationRequest) ToCookie() openid.LoginCookie {
|
||||
return openid.LoginCookie{
|
||||
Acr: a.Acr,
|
||||
CodeVerifier: a.CodeVerifier,
|
||||
Nonce: a.Nonce,
|
||||
State: a.State,
|
||||
RedirectURI: a.CallbackURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Login(r *http.Request) (*Login, error) {
|
||||
request, err := c.newAuthorizationRequest(r)
|
||||
if err != nil {
|
||||
@@ -73,52 +82,48 @@ func (c *Client) Login(r *http.Request) (*Login, error) {
|
||||
|
||||
return &Login{
|
||||
AuthCodeURL: authCodeURL,
|
||||
authorizationRequest: *request,
|
||||
Cookie: &openid.LoginCookie{
|
||||
Acr: request.Acr,
|
||||
CodeVerifier: request.CodeVerifier,
|
||||
State: request.State,
|
||||
Nonce: request.Nonce,
|
||||
RedirectURI: request.CallbackURL,
|
||||
},
|
||||
authorizationRequest: request,
|
||||
Cookie: request.ToCookie(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest, error) {
|
||||
func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, error) {
|
||||
var req authorizationRequest
|
||||
|
||||
callbackURL, err := url.LoginCallback(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating callback url: %w", err)
|
||||
return req, fmt.Errorf("generating callback url: %w", err)
|
||||
}
|
||||
|
||||
acr, err := getAcrParam(c, r)
|
||||
acrParam, err := getAcrParam(c, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err)
|
||||
return req, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err)
|
||||
}
|
||||
|
||||
locale, err := getLocaleParam(c, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrInvalidLocale, err)
|
||||
return req, fmt.Errorf("%w: %w", ErrInvalidLocale, err)
|
||||
}
|
||||
|
||||
prompt, err := getPromptParam(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrInvalidPrompt, err)
|
||||
return req, fmt.Errorf("%w: %w", ErrInvalidPrompt, err)
|
||||
}
|
||||
|
||||
nonce, err := strings.GenerateBase64(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating nonce: %w", err)
|
||||
return req, fmt.Errorf("creating nonce: %w", err)
|
||||
}
|
||||
|
||||
state, err := strings.GenerateBase64(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating state: %w", err)
|
||||
return req, fmt.Errorf("creating state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
|
||||
return &authorizationRequest{
|
||||
Acr: acr,
|
||||
return authorizationRequest{
|
||||
Acr: acrParam,
|
||||
CallbackURL: callbackURL,
|
||||
CodeVerifier: codeVerifier,
|
||||
Locale: locale,
|
||||
@@ -128,7 +133,7 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) (string, error) {
|
||||
func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) (string, error) {
|
||||
var authCodeURL string
|
||||
|
||||
if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" {
|
||||
@@ -192,35 +197,11 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)
|
||||
return "", fmt.Errorf("generating client authentication parameters: %w", err)
|
||||
}
|
||||
|
||||
urlValues := authParams.URLValues(params)
|
||||
|
||||
requestBody := stringslib.NewReader(urlValues.Encode())
|
||||
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody)
|
||||
payload := authParams.URLValues(params).Encode()
|
||||
endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint()
|
||||
body, err := c.oauthPostRequest(ctx, endpoint, payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("performing request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading server response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
|
||||
var errorResponse openid.TokenErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||
return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
|
||||
}
|
||||
return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
|
||||
} else if resp.StatusCode >= 500 {
|
||||
return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var pushedAuthorizationResponse openid.PushedAuthorizationResponse
|
||||
|
||||
Reference in New Issue
Block a user