Files
wonderwall/pkg/mock/handler.go
2021-10-16 10:50:22 +02:00

212 lines
5.4 KiB
Go

package mock
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)
type identityProviderHandler struct {
Codes map[string]authorizeRequest
Provider TestProvider
Sessions map[string]string
}
func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler {
return &identityProviderHandler{
Codes: make(map[string]authorizeRequest),
Provider: provider,
Sessions: make(map[string]string),
}
}
type authorizeRequest struct {
AcrLevel string
CodeChallenge string
Locale string
Nonce string
}
type tokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
privateJwkSet := *ip.Provider.PrivateJwkSet()
signer, ok := privateJwkSet.Get(0)
if !ok {
return "", fmt.Errorf("could not get signer")
}
signedToken, err := jwt.Sign(token, jwa.RS256, signer)
if err != nil {
return "", err
}
return string(signedToken), nil
}
func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
state := query.Get("state")
redirect := query.Get("redirect_uri")
acrLevel := query.Get("acr_values")
codeChallenge := query.Get("code_challenge")
locale := query.Get("ui_locales")
nonce := query.Get("nonce")
if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required fields"))
return
}
code := uuid.New().String()
ip.Codes[code] = authorizeRequest{
AcrLevel: acrLevel,
CodeChallenge: codeChallenge,
Locale: locale,
Nonce: nonce,
}
u, err := url.Parse(redirect)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("couldn't parse redirect uri"))
return
}
v := url.Values{}
v.Set("code", code)
v.Set("state", state)
u.RawQuery = v.Encode()
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
}
func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, _ *http.Request) {
json.NewEncoder(w).Encode(ip.Provider.GetPublicJwkSet())
}
func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("malformed payload?"))
return
}
code := r.PostForm.Get("code")
if len(code) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing code"))
return
}
auth, ok := ip.Codes[code]
if !ok {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("no matching code"))
return
}
expires := int64(1200)
sub := uuid.New().String()
sid := uuid.New().String()
clientID := r.PostForm.Get("client_id")
if len(clientID) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing client_id"))
return
}
clientAssertion := r.PostForm.Get("client_assertion")
if len(clientID) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing client_assertion"))
return
}
clientJwk := ip.Provider.GetClientConfiguration().GetClientJWK()
clientJwkSet := jwk.NewSet()
clientJwkSet.Add(clientJwk)
publicClientJwkSet, err := jwk.PublicSetOf(clientJwkSet)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("failed to create public client jwk set"))
return
}
opts := []jwt.ParseOption{
jwt.WithValidate(true),
jwt.WithKeySet(publicClientJwkSet),
jwt.WithIssuer(ip.Provider.GetClientConfiguration().GetClientID()),
jwt.WithSubject(ip.Provider.GetClientConfiguration().GetClientID()),
jwt.WithClaimValue("scope", ip.Provider.GetClientConfiguration().GetScopes().String()),
jwt.WithAudience(ip.Provider.GetOpenIDConfiguration().Issuer),
}
_, err = jwt.Parse([]byte(clientAssertion), opts...)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(fmt.Sprintf("invalid client assertion: %+v", err)))
return
}
accessToken := jwt.New()
accessToken.Set("sub", sub)
accessToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer)
accessToken.Set("acr", auth.AcrLevel)
accessToken.Set("iat", time.Now().Unix())
accessToken.Set("exp", time.Now().Unix()+expires)
signedAccessToken, err := ip.signToken(accessToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("could not sign access token: " + err.Error()))
return
}
idToken := jwt.New()
idToken.Set("sub", sub)
idToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer)
idToken.Set("aud", clientID)
idToken.Set("locale", auth.Locale)
idToken.Set("nonce", auth.Nonce)
idToken.Set("acr", auth.AcrLevel)
idToken.Set("iat", time.Now().Unix())
idToken.Set("exp", time.Now().Unix()+expires)
idToken.Set("sid", sid)
signedIdToken, err := ip.signToken(idToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("could not sign access token: " + err.Error()))
return
}
ip.Sessions[sid] = clientID
token := &tokenResponse{
AccessToken: signedAccessToken,
TokenType: "Bearer",
IDToken: signedIdToken,
ExpiresIn: expires,
}
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(token)
}