diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index f8dd02f..7b6f0db 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -94,37 +94,16 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid return nil, err } - requestBody := strings.NewReader(params.URLValues(map[string]string{ + payload := params.URLValues(map[string]string{ "grant_type": "refresh_token", "refresh_token": refreshToken, "client_id": c.cfg.Client().ClientID(), - }).Encode()) + }).Encode() - r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), requestBody) + endpoint := c.cfg.Provider().TokenEndpoint() + body, err := c.oauthPostRequest(ctx, endpoint, payload) if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := c.httpClient.Do(r) - if err != nil { - return nil, fmt.Errorf("performing request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("reading server response: %w", err) - } - - if resp.StatusCode >= 400 && resp.StatusCode < 500 { - var errorResponse openid.TokenErrorResponse - if err := json.Unmarshal(body, &errorResponse); err != nil { - return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err) - } - return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription) - } else if resp.StatusCode >= 500 { - return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body) + return nil, err } var tokenResponse openid.TokenResponse @@ -183,3 +162,34 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { return string(encoded), nil } + +func (c *Client) oauthPostRequest(ctx context.Context, endpoint, payload string) ([]byte, error) { + r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(r) + if err != nil { + return nil, fmt.Errorf("performing request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading server response: %w", err) + } + + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + var errorResponse openid.TokenErrorResponse + if err := json.Unmarshal(body, &errorResponse); err != nil { + return nil, fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err) + } + return nil, fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription) + } else if resp.StatusCode >= 500 { + return nil, fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, payload) + } + + return body, nil +} diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 4d1415a..7a317b6 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" urllib "net/url" "slices" @@ -45,9 +44,9 @@ var ( ) type Login struct { - AuthCodeURL string authorizationRequest - Cookie *openid.LoginCookie + AuthCodeURL string + Cookie openid.LoginCookie } type authorizationRequest struct { @@ -60,6 +59,16 @@ type authorizationRequest struct { State string } +func (a authorizationRequest) ToCookie() openid.LoginCookie { + return openid.LoginCookie{ + Acr: a.Acr, + CodeVerifier: a.CodeVerifier, + Nonce: a.Nonce, + State: a.State, + RedirectURI: a.CallbackURL, + } +} + func (c *Client) Login(r *http.Request) (*Login, error) { request, err := c.newAuthorizationRequest(r) if err != nil { @@ -73,52 +82,48 @@ func (c *Client) Login(r *http.Request) (*Login, error) { return &Login{ AuthCodeURL: authCodeURL, - authorizationRequest: *request, - Cookie: &openid.LoginCookie{ - Acr: request.Acr, - CodeVerifier: request.CodeVerifier, - State: request.State, - Nonce: request.Nonce, - RedirectURI: request.CallbackURL, - }, + authorizationRequest: request, + Cookie: request.ToCookie(), }, nil } -func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest, error) { +func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, error) { + var req authorizationRequest + callbackURL, err := url.LoginCallback(r) if err != nil { - return nil, fmt.Errorf("generating callback url: %w", err) + return req, fmt.Errorf("generating callback url: %w", err) } - acr, err := getAcrParam(c, r) + acrParam, err := getAcrParam(c, r) if err != nil { - return nil, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err) + return req, fmt.Errorf("%w: %w", ErrInvalidSecurityLevel, err) } locale, err := getLocaleParam(c, r) if err != nil { - return nil, fmt.Errorf("%w: %w", ErrInvalidLocale, err) + return req, fmt.Errorf("%w: %w", ErrInvalidLocale, err) } prompt, err := getPromptParam(r) if err != nil { - return nil, fmt.Errorf("%w: %w", ErrInvalidPrompt, err) + return req, fmt.Errorf("%w: %w", ErrInvalidPrompt, err) } nonce, err := strings.GenerateBase64(32) if err != nil { - return nil, fmt.Errorf("creating nonce: %w", err) + return req, fmt.Errorf("creating nonce: %w", err) } state, err := strings.GenerateBase64(32) if err != nil { - return nil, fmt.Errorf("creating state: %w", err) + return req, fmt.Errorf("creating state: %w", err) } codeVerifier := oauth2.GenerateVerifier() - return &authorizationRequest{ - Acr: acr, + return authorizationRequest{ + Acr: acrParam, CallbackURL: callbackURL, CodeVerifier: codeVerifier, Locale: locale, @@ -128,7 +133,7 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (*authorizationRequest }, nil } -func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) (string, error) { +func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) (string, error) { var authCodeURL string if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" { @@ -192,35 +197,11 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) return "", fmt.Errorf("generating client authentication parameters: %w", err) } - urlValues := authParams.URLValues(params) - - requestBody := stringslib.NewReader(urlValues.Encode()) - - r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody) + payload := authParams.URLValues(params).Encode() + endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint() + body, err := c.oauthPostRequest(ctx, endpoint, payload) if err != nil { - return "", fmt.Errorf("creating request: %w", err) - } - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := c.httpClient.Do(r) - if err != nil { - return "", fmt.Errorf("performing request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("reading server response: %w", err) - } - - if resp.StatusCode >= 400 && resp.StatusCode < 500 { - var errorResponse openid.TokenErrorResponse - if err := json.Unmarshal(body, &errorResponse); err != nil { - return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err) - } - return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription) - } else if resp.StatusCode >= 500 { - return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body) + return "", err } var pushedAuthorizationResponse openid.PushedAuthorizationResponse