feat: implement PAR for relying party

Fixes #235

Co-authored-by: tronghn <trong.huu.nguyen@nav.no>
This commit is contained in:
Sindre Rødseth Hansen
2025-01-22 13:00:30 +01:00
committed by Trong Huu Nguyen
parent 6be5a1ebe5
commit c442000be4
4 changed files with 133 additions and 14 deletions

View File

@@ -10,6 +10,8 @@ import (
"strings"
"time"
"github.com/nais/wonderwall/pkg/openid"
"github.com/alicebob/miniredis/v2"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -352,33 +354,33 @@ func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request)
func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWriter, r *http.Request) {
if ip.Config.Provider().PushedAuthorizationRequestEndpoint() == "" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("PAR endpoint not supported"))
oauthError(w, fmt.Errorf("PAR endpoint not supported"))
return
}
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("malformed payload?"))
oauthError(w, fmt.Errorf("malformed payload"))
return
}
if r.PostForm.Get("request_uri") != "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("request_uri should not be provided to PAR endpoint"))
oauthError(w, fmt.Errorf("request_uri should not be provided to PAR endpoint"))
return
}
authorizeRequest, err := ip.parseAuthorizationRequest(r.PostForm)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}
err = ip.validateClientAuthentication(w, r, r.PostForm.Get("client_id"))
if err != nil {
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}
@@ -390,7 +392,18 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"request_uri": requestUri, "expires_in": "60"})
json.NewEncoder(w).Encode(openid.PushedAuthorizationResponse{
RequestUri: requestUri,
ExpiresIn: 60,
})
}
func oauthError(w http.ResponseWriter, err error) {
w.Header().Set("content-type", "application/json")
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
Error: "invalid_request",
ErrorDescription: err.Error(),
})
}
func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {

View File

@@ -1,12 +1,16 @@
package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
urllib "net/url"
"slices"
stringslib "strings"
"golang.org/x/oauth2"
@@ -159,15 +163,89 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)
authCodeURL = c.oauth2Config.AuthCodeURL(request.state, opts...)
} else {
// TODO: implement PAR
// generate PAR request
// set all request parameters from authorizationRequest
// set client authentication parameters
params := map[string]string{
"client_id": c.oauth2Config.ClientID,
"code_challenge": oauth2.S256ChallengeFromVerifier(request.codeVerifier),
"code_challenge_method": "S256",
"nonce": request.nonce,
"redirect_uri": request.callbackURL,
"response_mode": "query",
"response_type": "code",
"scope": stringslib.Join(c.oauth2Config.Scopes, " "),
"state": request.state,
}
// perform POST to PAR endpoint
// extract request_uri from response
// generate auth code URL with request_uri and client_id
// set authCodeURL
if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
params["resource"] = resource
}
if len(request.acr) > 0 {
params[LoginParameterMapping[SecurityLevelURLParameter]] = request.acr
}
if len(request.locale) > 0 {
params[LoginParameterMapping[LocaleURLParameter]] = request.locale
}
if len(request.prompt) > 0 {
params[PromptURLParameter] = request.prompt
params[MaxAgeURLParameter] = "0"
}
authParams, err := c.AuthParams()
if err != nil {
return "", fmt.Errorf("generating client authentication parameters: %w", err)
}
urlValues := authParams.URLValues(params)
requestBody := stringslib.NewReader(urlValues.Encode())
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(r)
if err != nil {
return "", fmt.Errorf("performing request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading server response: %w", err)
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
var errorResponse openid.TokenErrorResponse
if err := json.Unmarshal(body, &errorResponse); err != nil {
return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
}
return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
} else if resp.StatusCode >= 500 {
return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
}
var pushedAuthorizationResponse openid.PushedAuthorizationResponse
if err := json.Unmarshal(body, &pushedAuthorizationResponse); err != nil {
return "", fmt.Errorf("unmarshalling token response: %w", err)
}
v := urllib.Values{
"client_id": {c.oauth2Config.ClientID},
"request_uri": {pushedAuthorizationResponse.RequestUri},
}
var buf bytes.Buffer
buf.WriteString(c.oauth2Config.Endpoint.AuthURL)
if stringslib.Contains(c.oauth2Config.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else {
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
authCodeURL = buf.String()
}
return authCodeURL, nil

View File

@@ -14,6 +14,28 @@ import (
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogin_PushAuthorizationURL(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.SetPushedAuthorizationRequestEndpoint(idp.ProviderServer.URL + "/par")
defer idp.Close()
req := idp.GetRequest(mock.Ingress + "/oauth2/login")
result, err := idp.RelyingPartyHandler.Client.Login(req)
require.NoError(t, err)
parsed, err := url.Parse(result.AuthCodeURL)
assert.NoError(t, err)
query := parsed.Query()
assert.Contains(t, query, "request_uri")
assert.Contains(t, query, "client_id")
assert.NotEmpty(t, query["request_uri"])
assert.Contains(t, query["request_uri"][0], "urn:ietf:params:oauth:request_uri")
assert.ElementsMatch(t, query["client_id"], []string{idp.OpenIDConfig.Client().ClientID()})
}
func TestLogin_URL(t *testing.T) {
type loginURLTest struct {
name string

View File

@@ -15,6 +15,12 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// PushedAuthorizationResponse is the struct representing the HTTP response from authorization servers as defined in RFC 9126, section 2.2.
type PushedAuthorizationResponse struct {
RequestUri string `json:"request_uri"`
ExpiresIn int64 `json:"expires_in"`
}
// TokenErrorResponse is the struct representing the HTTP error response returned from authorization servers as defined in RFC 6749, section 5.2.
type TokenErrorResponse struct {
Error string `json:"error"`