refactor(mock): use oauth error response for all idp errors

This commit is contained in:
Trong Huu Nguyen
2025-01-23 09:02:19 +01:00
parent ade44f0950
commit 837323d728

View File

@@ -204,7 +204,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
authorizeRequest, err = ip.parseAuthorizationRequest(query)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}
@@ -215,14 +215,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
clientId := query.Get("client_id")
if len(clientId) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing client_id"))
oauthError(w, fmt.Errorf("missing client_id"))
return
}
requestUri := query.Get("request_uri")
if len(requestUri) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing request_uri"))
oauthError(w, fmt.Errorf("missing request_uri"))
return
}
@@ -230,14 +230,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
code, ok = ip.PushedAuthorizationRequestCodes[requestUri]
if !ok {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("no matching request_uri for %q", requestUri)))
oauthError(w, fmt.Errorf("no matching request_uri for %q", requestUri))
return
}
authorizeRequest, ok = ip.Codes[code]
if !ok {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(fmt.Sprintf("no matching code for %q", code)))
oauthError(w, fmt.Errorf("no matching code for %q", code))
return
}
}
@@ -245,7 +245,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
u, err := url.Parse(authorizeRequest.RedirectUri)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("couldn't parse redirect uri"))
oauthError(w, fmt.Errorf("couldn't parse redirect uri: %w", err))
return
}
v := url.Values{}
@@ -361,7 +361,7 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
oauthError(w, fmt.Errorf("malformed payload"))
oauthError(w, fmt.Errorf("malformed payload: %w", err))
return
}
@@ -398,19 +398,11 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri
})
}
func oauthError(w http.ResponseWriter, err error) {
w.Header().Set("content-type", "application/json")
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
Error: "invalid_request",
ErrorDescription: err.Error(),
})
}
func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("malformed payload?"))
oauthError(w, fmt.Errorf("malformed payload"))
return
}
@@ -424,7 +416,7 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
return
default:
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("unsupported grant_type: " + grantType))
oauthError(w, fmt.Errorf("unsupported grant_type: %q", grantType))
return
}
}
@@ -433,46 +425,46 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
code := r.PostForm.Get("code")
if len(code) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing code"))
oauthError(w, fmt.Errorf("missing code"))
return
}
auth, ok := ip.Codes[code]
if !ok {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("no matching code"))
oauthError(w, fmt.Errorf("no matching code"))
return
}
err := ip.validateClientAuthentication(w, r, auth.ClientID)
if err != nil {
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}
redirect := r.PostForm.Get("redirect_uri")
if len(redirect) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing redirect_uri"))
oauthError(w, fmt.Errorf("missing redirect_uri"))
return
}
if len(auth.RedirectUri) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("redirect_uri was not set in auth code request"))
oauthError(w, fmt.Errorf("redirect_uri was not set in auth code request"))
return
}
if auth.RedirectUri != redirect {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("redirect_uri does not match redirect_uri used to acquire code"))
oauthError(w, fmt.Errorf("redirect_uri does not match redirect_uri used to acquire code"))
return
}
codeVerifier := r.PostForm.Get("code_verifier")
if len(codeVerifier) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing code_verifier"))
oauthError(w, fmt.Errorf("missing code_verifier"))
return
}
@@ -480,7 +472,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
if expectedCodeChallenge != auth.CodeChallenge {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("code_verifier is invalid"))
oauthError(w, fmt.Errorf("code_verifier is invalid"))
return
}
@@ -499,7 +491,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
signedAccessToken, err := ip.signToken(accessToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("could not sign access token: " + err.Error()))
oauthError(w, fmt.Errorf("could not sign access token: %w", err))
return
}
@@ -522,7 +514,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
signedIdToken, err := ip.signToken(idToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("could not sign access token: " + err.Error()))
oauthError(w, fmt.Errorf("could not sign id token: %w", err))
return
}
@@ -552,20 +544,20 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h
refreshToken := r.PostForm.Get("refresh_token")
if len(refreshToken) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing refresh_token"))
oauthError(w, fmt.Errorf("missing refresh_token"))
return
}
data, ok := ip.RefreshTokens[refreshToken]
if !ok {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("no matching refresh_token"))
oauthError(w, fmt.Errorf("no matching refresh_token"))
return
}
err := ip.validateClientAuthentication(w, r, data.ClientID)
if err != nil {
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}
@@ -582,7 +574,7 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h
signedAccessToken, err := ip.signToken(accessToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("could not sign access token: " + err.Error()))
oauthError(w, fmt.Errorf("could not sign access token: %w", err))
return
}
@@ -665,20 +657,20 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req
if postLogoutRedirectURI == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required 'post_logout_redirect_uri' parameter"))
oauthError(w, fmt.Errorf("missing required 'post_logout_redirect_uri' parameter"))
return
}
if state == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required 'state' parameter"))
oauthError(w, fmt.Errorf("missing required 'state' parameter"))
return
}
u, err := url.Parse(postLogoutRedirectURI)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("couldn't parse post_logout_redirect_uri"))
oauthError(w, fmt.Errorf("couldn't parse post_logout_redirect_uri: %w", err))
return
}
@@ -706,3 +698,11 @@ func (in *relyingPartyServer) SetHandler(handler http.Handler) {
Handler: handler,
}
}
func oauthError(w http.ResponseWriter, err error) {
w.Header().Set("content-type", "application/json")
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
Error: "invalid_request",
ErrorDescription: err.Error(),
})
}