mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
refactor(mock): use oauth error response for all idp errors
This commit is contained in:
@@ -204,7 +204,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
authorizeRequest, err = ip.parseAuthorizationRequest(query)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(err.Error()))
|
||||
oauthError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -215,14 +215,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
clientId := query.Get("client_id")
|
||||
if len(clientId) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing client_id"))
|
||||
oauthError(w, fmt.Errorf("missing client_id"))
|
||||
return
|
||||
}
|
||||
|
||||
requestUri := query.Get("request_uri")
|
||||
if len(requestUri) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing request_uri"))
|
||||
oauthError(w, fmt.Errorf("missing request_uri"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -230,14 +230,14 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
code, ok = ip.PushedAuthorizationRequestCodes[requestUri]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(fmt.Sprintf("no matching request_uri for %q", requestUri)))
|
||||
oauthError(w, fmt.Errorf("no matching request_uri for %q", requestUri))
|
||||
return
|
||||
}
|
||||
|
||||
authorizeRequest, ok = ip.Codes[code]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(fmt.Sprintf("no matching code for %q", code)))
|
||||
oauthError(w, fmt.Errorf("no matching code for %q", code))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -245,7 +245,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
u, err := url.Parse(authorizeRequest.RedirectUri)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse redirect uri"))
|
||||
oauthError(w, fmt.Errorf("couldn't parse redirect uri: %w", err))
|
||||
return
|
||||
}
|
||||
v := url.Values{}
|
||||
@@ -361,7 +361,7 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
oauthError(w, fmt.Errorf("malformed payload"))
|
||||
oauthError(w, fmt.Errorf("malformed payload: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -398,19 +398,11 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri
|
||||
})
|
||||
}
|
||||
|
||||
func oauthError(w http.ResponseWriter, err error) {
|
||||
w.Header().Set("content-type", "application/json")
|
||||
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
|
||||
Error: "invalid_request",
|
||||
ErrorDescription: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("malformed payload?"))
|
||||
oauthError(w, fmt.Errorf("malformed payload"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -424,7 +416,7 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
default:
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("unsupported grant_type: " + grantType))
|
||||
oauthError(w, fmt.Errorf("unsupported grant_type: %q", grantType))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -433,46 +425,46 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
|
||||
code := r.PostForm.Get("code")
|
||||
if len(code) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing code"))
|
||||
oauthError(w, fmt.Errorf("missing code"))
|
||||
return
|
||||
}
|
||||
|
||||
auth, ok := ip.Codes[code]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("no matching code"))
|
||||
oauthError(w, fmt.Errorf("no matching code"))
|
||||
return
|
||||
}
|
||||
|
||||
err := ip.validateClientAuthentication(w, r, auth.ClientID)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
oauthError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
redirect := r.PostForm.Get("redirect_uri")
|
||||
if len(redirect) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing redirect_uri"))
|
||||
oauthError(w, fmt.Errorf("missing redirect_uri"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(auth.RedirectUri) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("redirect_uri was not set in auth code request"))
|
||||
oauthError(w, fmt.Errorf("redirect_uri was not set in auth code request"))
|
||||
return
|
||||
}
|
||||
|
||||
if auth.RedirectUri != redirect {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("redirect_uri does not match redirect_uri used to acquire code"))
|
||||
oauthError(w, fmt.Errorf("redirect_uri does not match redirect_uri used to acquire code"))
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier := r.PostForm.Get("code_verifier")
|
||||
if len(codeVerifier) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing code_verifier"))
|
||||
oauthError(w, fmt.Errorf("missing code_verifier"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -480,7 +472,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
|
||||
|
||||
if expectedCodeChallenge != auth.CodeChallenge {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("code_verifier is invalid"))
|
||||
oauthError(w, fmt.Errorf("code_verifier is invalid"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -499,7 +491,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
|
||||
signedAccessToken, err := ip.signToken(accessToken)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not sign access token: " + err.Error()))
|
||||
oauthError(w, fmt.Errorf("could not sign access token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -522,7 +514,7 @@ func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http
|
||||
signedIdToken, err := ip.signToken(idToken)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not sign access token: " + err.Error()))
|
||||
oauthError(w, fmt.Errorf("could not sign id token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -552,20 +544,20 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h
|
||||
refreshToken := r.PostForm.Get("refresh_token")
|
||||
if len(refreshToken) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing refresh_token"))
|
||||
oauthError(w, fmt.Errorf("missing refresh_token"))
|
||||
return
|
||||
}
|
||||
|
||||
data, ok := ip.RefreshTokens[refreshToken]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("no matching refresh_token"))
|
||||
oauthError(w, fmt.Errorf("no matching refresh_token"))
|
||||
return
|
||||
}
|
||||
|
||||
err := ip.validateClientAuthentication(w, r, data.ClientID)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
oauthError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -582,7 +574,7 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h
|
||||
signedAccessToken, err := ip.signToken(accessToken)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not sign access token: " + err.Error()))
|
||||
oauthError(w, fmt.Errorf("could not sign access token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -665,20 +657,20 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req
|
||||
|
||||
if postLogoutRedirectURI == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required 'post_logout_redirect_uri' parameter"))
|
||||
oauthError(w, fmt.Errorf("missing required 'post_logout_redirect_uri' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required 'state' parameter"))
|
||||
oauthError(w, fmt.Errorf("missing required 'state' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse post_logout_redirect_uri"))
|
||||
oauthError(w, fmt.Errorf("couldn't parse post_logout_redirect_uri: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -706,3 +698,11 @@ func (in *relyingPartyServer) SetHandler(handler http.Handler) {
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func oauthError(w http.ResponseWriter, err error) {
|
||||
w.Header().Set("content-type", "application/json")
|
||||
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
|
||||
Error: "invalid_request",
|
||||
ErrorDescription: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user