From 837323d72861eabb16a2497a05c8520ccfe8f784 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 23 Jan 2025 09:02:19 +0100 Subject: [PATCH] refactor(mock): use oauth error response for all idp errors --- pkg/mock/openid.go | 68 +++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 39d04ba..d1ba60d 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -204,7 +204,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ authorizeRequest, err = ip.parseAuthorizationRequest(query) if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err.Error())) + oauthError(w, err) return } @@ -215,14 +215,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ clientId := query.Get("client_id") if len(clientId) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing client_id")) + oauthError(w, fmt.Errorf("missing client_id")) return } requestUri := query.Get("request_uri") if len(requestUri) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing request_uri")) + oauthError(w, fmt.Errorf("missing request_uri")) return } @@ -230,14 +230,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ code, ok = ip.PushedAuthorizationRequestCodes[requestUri] if !ok { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(fmt.Sprintf("no matching request_uri for %q", requestUri))) + oauthError(w, fmt.Errorf("no matching request_uri for %q", requestUri)) return } authorizeRequest, ok = ip.Codes[code] if !ok { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(fmt.Sprintf("no matching code for %q", code))) + oauthError(w, fmt.Errorf("no matching code for %q", code)) return } } @@ -245,7 +245,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ u, err := url.Parse(authorizeRequest.RedirectUri) if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("couldn't parse redirect uri")) + oauthError(w, fmt.Errorf("couldn't parse redirect uri: %w", err)) return } v := url.Values{} @@ -361,7 +361,7 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) - oauthError(w, fmt.Errorf("malformed payload")) + oauthError(w, fmt.Errorf("malformed payload: %w", err)) return } @@ -398,19 +398,11 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri }) } -func oauthError(w http.ResponseWriter, err error) { - w.Header().Set("content-type", "application/json") - json.NewEncoder(w).Encode(openid.TokenErrorResponse{ - Error: "invalid_request", - ErrorDescription: err.Error(), - }) -} - func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("malformed payload?")) + oauthError(w, fmt.Errorf("malformed payload")) return } @@ -424,7 +416,7 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return default: w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("unsupported grant_type: " + grantType)) + oauthError(w, fmt.Errorf("unsupported grant_type: %q", grantType)) return } } @@ -433,46 +425,46 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http code := r.PostForm.Get("code") if len(code) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing code")) + oauthError(w, fmt.Errorf("missing code")) return } auth, ok := ip.Codes[code] if !ok { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("no matching code")) + oauthError(w, fmt.Errorf("no matching code")) return } err := ip.validateClientAuthentication(w, r, auth.ClientID) if err != nil { - w.Write([]byte(err.Error())) + oauthError(w, err) return } redirect := r.PostForm.Get("redirect_uri") if len(redirect) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing redirect_uri")) + oauthError(w, fmt.Errorf("missing redirect_uri")) return } if len(auth.RedirectUri) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("redirect_uri was not set in auth code request")) + oauthError(w, fmt.Errorf("redirect_uri was not set in auth code request")) return } if auth.RedirectUri != redirect { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("redirect_uri does not match redirect_uri used to acquire code")) + oauthError(w, fmt.Errorf("redirect_uri does not match redirect_uri used to acquire code")) return } codeVerifier := r.PostForm.Get("code_verifier") if len(codeVerifier) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing code_verifier")) + oauthError(w, fmt.Errorf("missing code_verifier")) return } @@ -480,7 +472,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http if expectedCodeChallenge != auth.CodeChallenge { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("code_verifier is invalid")) + oauthError(w, fmt.Errorf("code_verifier is invalid")) return } @@ -499,7 +491,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http signedAccessToken, err := ip.signToken(accessToken) if err != nil { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("could not sign access token: " + err.Error())) + oauthError(w, fmt.Errorf("could not sign access token: %w", err)) return } @@ -522,7 +514,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http signedIdToken, err := ip.signToken(idToken) if err != nil { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("could not sign access token: " + err.Error())) + oauthError(w, fmt.Errorf("could not sign id token: %w", err)) return } @@ -552,20 +544,20 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h refreshToken := r.PostForm.Get("refresh_token") if len(refreshToken) == 0 { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing refresh_token")) + oauthError(w, fmt.Errorf("missing refresh_token")) return } data, ok := ip.RefreshTokens[refreshToken] if !ok { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("no matching refresh_token")) + oauthError(w, fmt.Errorf("no matching refresh_token")) return } err := ip.validateClientAuthentication(w, r, data.ClientID) if err != nil { - w.Write([]byte(err.Error())) + oauthError(w, err) return } @@ -582,7 +574,7 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h signedAccessToken, err := ip.signToken(accessToken) if err != nil { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("could not sign access token: " + err.Error())) + oauthError(w, fmt.Errorf("could not sign access token: %w", err)) return } @@ -665,20 +657,20 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req if postLogoutRedirectURI == "" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing required 'post_logout_redirect_uri' parameter")) + oauthError(w, fmt.Errorf("missing required 'post_logout_redirect_uri' parameter")) return } if state == "" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing required 'state' parameter")) + oauthError(w, fmt.Errorf("missing required 'state' parameter")) return } u, err := url.Parse(postLogoutRedirectURI) if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("couldn't parse post_logout_redirect_uri")) + oauthError(w, fmt.Errorf("couldn't parse post_logout_redirect_uri: %w", err)) return } @@ -706,3 +698,11 @@ func (in *relyingPartyServer) SetHandler(handler http.Handler) { Handler: handler, } } + +func oauthError(w http.ResponseWriter, err error) { + w.Header().Set("content-type", "application/json") + json.NewEncoder(w).Encode(openid.TokenErrorResponse{ + Error: "invalid_request", + ErrorDescription: err.Error(), + }) +}