refactor(openid/client): extract oauth request method

Co-authored-by: sindrerh2 <sindre.rodseth.hansen@nav.no>
This commit is contained in:
Trong Huu Nguyen
2025-01-23 10:17:13 +01:00
parent ab418c456c
commit 642457b950
2 changed files with 67 additions and 76 deletions

View File

@@ -94,37 +94,16 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
return nil, err
}
requestBody := strings.NewReader(params.URLValues(map[string]string{
payload := params.URLValues(map[string]string{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": c.cfg.Client().ClientID(),
}).Encode())
}).Encode()
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), requestBody)
endpoint := c.cfg.Provider().TokenEndpoint()
body, err := c.oauthPostRequest(ctx, endpoint, payload)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(r)
if err != nil {
return nil, fmt.Errorf("performing request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading server response: %w", err)
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
var errorResponse openid.TokenErrorResponse
if err := json.Unmarshal(body, &errorResponse); err != nil {
return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
}
return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
} else if resp.StatusCode >= 500 {
return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
return nil, err
}
var tokenResponse openid.TokenResponse
@@ -183,3 +162,34 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
return string(encoded), nil
}
func (c *Client) oauthPostRequest(ctx context.Context, endpoint, payload string) ([]byte, error) {
r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload))
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(r)
if err != nil {
return nil, fmt.Errorf("performing request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading server response: %w", err)
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
var errorResponse openid.TokenErrorResponse
if err := json.Unmarshal(body, &errorResponse); err != nil {
return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
}
return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
} else if resp.StatusCode >= 500 {
return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, payload)
}
return body, nil
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
urllib "net/url"
"slices"
@@ -45,9 +44,9 @@ var (
)
type Login struct {
AuthCodeURL string
authorizationRequest
Cookie *openid.LoginCookie
AuthCodeURL string
Cookie openid.LoginCookie
}
type authorizationRequest struct {
@@ -60,6 +59,16 @@ type authorizationRequest struct {
State string
}
func (a authorizationRequest) ToCookie() openid.LoginCookie {
return openid.LoginCookie{
Acr: a.Acr,
CodeVerifier: a.CodeVerifier,
Nonce: a.Nonce,
State: a.State,
RedirectURI: a.CallbackURL,
}
}
func (c *Client) Login(r *http.Request) (*Login, error) {
request, err := c.newAuthorizationRequest(r)
if err != nil {
@@ -73,52 +82,48 @@ func (c *Client) Login(r *http.Request) (*Login, error) {
return &Login{
AuthCodeURL: authCodeURL,
authorizationRequest: *request,
Cookie: &openid.LoginCookie{
Acr: request.Acr,
CodeVerifier: request.CodeVerifier,
State: request.State,
Nonce: request.Nonce,
RedirectURI: request.CallbackURL,
},
authorizationRequest: request,
Cookie: request.ToCookie(),
}, nil
}
func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest, error) {
func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, error) {
var req authorizationRequest
callbackURL, err := url.LoginCallback(r)
if err != nil {
return nil, fmt.Errorf("generating callback url: %w", err)
return req, fmt.Errorf("generating callback url: %w", err)
}
acr, err := getAcrParam(c, r)
acrParam, err := getAcrParam(c, r)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err)
return req, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err)
}
locale, err := getLocaleParam(c, r)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrInvalidLocale, err)
return req, fmt.Errorf("%w: %w", ErrInvalidLocale, err)
}
prompt, err := getPromptParam(r)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrInvalidPrompt, err)
return req, fmt.Errorf("%w: %w", ErrInvalidPrompt, err)
}
nonce, err := strings.GenerateBase64(32)
if err != nil {
return nil, fmt.Errorf("creating nonce: %w", err)
return req, fmt.Errorf("creating nonce: %w", err)
}
state, err := strings.GenerateBase64(32)
if err != nil {
return nil, fmt.Errorf("creating state: %w", err)
return req, fmt.Errorf("creating state: %w", err)
}
codeVerifier := oauth2.GenerateVerifier()
return &authorizationRequest{
Acr: acr,
return authorizationRequest{
Acr: acrParam,
CallbackURL: callbackURL,
CodeVerifier: codeVerifier,
Locale: locale,
@@ -128,7 +133,7 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest
}, nil
}
func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) (string, error) {
func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) (string, error) {
var authCodeURL string
if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" {
@@ -192,35 +197,11 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)
return "", fmt.Errorf("generating client authentication parameters: %w", err)
}
urlValues := authParams.URLValues(params)
requestBody := stringslib.NewReader(urlValues.Encode())
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody)
payload := authParams.URLValues(params).Encode()
endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint()
body, err := c.oauthPostRequest(ctx, endpoint, payload)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(r)
if err != nil {
return "", fmt.Errorf("performing request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading server response: %w", err)
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
var errorResponse openid.TokenErrorResponse
if err := json.Unmarshal(body, &errorResponse); err != nil {
return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
}
return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
} else if resp.StatusCode >= 500 {
return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
return "", err
}
var pushedAuthorizationResponse openid.PushedAuthorizationResponse