diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 1636a8c..eb237c9 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -12,6 +12,7 @@ import ( "github.com/nais/wonderwall/pkg/logging" "github.com/nais/wonderwall/pkg/router" log "github.com/sirupsen/logrus" + "golang.org/x/oauth2" ) var maskedConfig = []string{ @@ -55,8 +56,19 @@ func run() error { return err } + oauthConfig := oauth2.Config{ + ClientID: cfg.IDPorten.ClientID, + Endpoint: oauth2.Endpoint{ + AuthURL: cfg.IDPorten.WellKnown.AuthorizationEndpoint, + TokenURL: cfg.IDPorten.WellKnown.TokenEndpoint, + }, + RedirectURL: cfg.IDPorten.RedirectURI, + Scopes: scopes, + } + handler := &router.Handler{ Config: cfg.IDPorten, + OauthConfig: oauthConfig, RelyingParty: relyingParty, } diff --git a/pkg/router/router.go b/pkg/router/router.go index 105dc56..d21fe21 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "gopkg.in/square/go-jose.v2" "io" "net/http" "net/url" @@ -16,6 +15,8 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/go-chi/chi" "github.com/nais/wonderwall/pkg/config" + "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" ) const ( @@ -27,18 +28,19 @@ const ( type Handler struct { Config config.IDPorten + OauthConfig oauth2.Config RelyingParty rp.RelyingParty } type loginParams struct { cookies []*http.Cookie - state []byte - codeVerifier []byte + state string + codeVerifier string url string } func (h *Handler) LoginURL() (*loginParams, error) { - codeVerifier := make([]byte, 32) + codeVerifier := make([]byte, 64) nonce := make([]byte, 32) state := make([]byte, 32) @@ -59,7 +61,9 @@ func (h *Handler) LoginURL() (*loginParams, error) { return nil, fmt.Errorf("failed to create code verifier: %w", err) } + codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier)) hasher := sha256.New() + hasher.Write(codeVerifier) codeVerifierHash := hasher.Sum(nil) u, err := url.Parse(h.Config.WellKnown.AuthorizationEndpoint) @@ -81,8 +85,8 @@ func (h *Handler) LoginURL() (*loginParams, error) { u.RawQuery = v.Encode() return &loginParams{ - state: state, - codeVerifier: codeVerifier, + state: base64.RawURLEncoding.EncodeToString(state), + codeVerifier: base64.RawURLEncoding.EncodeToString(codeVerifier), url: u.String(), }, nil } @@ -96,14 +100,14 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: "state", - Value: string(params.state), + Value: params.state, Expires: time.Now().Add(10 * time.Minute), Secure: true, SameSite: http.SameSiteLaxMode, }) http.SetCookie(w, &http.Cookie{ Name: "code_verifier", - Value: string(params.codeVerifier), + Value: params.codeVerifier, Expires: time.Now().Add(10 * time.Minute), Secure: true, SameSite: http.SameSiteLaxMode, @@ -112,24 +116,84 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, params.url, http.StatusTemporaryRedirect) } -func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { +func (h *Handler) SignedJWTProfileAssertion(expiration time.Duration) (string, error) { key := &jose.JSONWebKey{} err := json.Unmarshal([]byte(h.Config.ClientJWK), key) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + return "", err + } + signingKey := jose.SigningKey{ + Algorithm: jose.RS256, + Key: key, + } + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", err + } + + iat := time.Now() + exp := iat.Add(expiration) + jwtRequest := oidc.JWTTokenRequest{ + Issuer: h.Config.ClientID, + Subject: h.Config.ClientID, + Audience: []string{h.Config.WellKnown.Issuer}, + ExpiresAt: oidc.Time(exp), + IssuedAt: oidc.Time(iat), + } + + payload, err := json.Marshal(jwtRequest) + if err != nil { + return "", err + } + result, err := signer.Sign(payload) + if err != nil { + return "", err + } + return result.CompactSerialize() +} + +func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { + state, err := r.Cookie("state") + if err != nil { + w.WriteHeader(http.StatusUnauthorized) return } - marshalToken := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) { - data, err := json.Marshal(tokens) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Write(data) + params := r.URL.Query() + if params.Get("error") != "" { + w.WriteHeader(http.StatusUnauthorized) + return } - rp.CodeExchangeHandler(marshalToken, h.RelyingParty)(w, r) + if params.Get("state") != state.Value { + w.WriteHeader(http.StatusUnauthorized) + return + } + + codeVerifier, err := r.Cookie("code_verifier") + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + assertion, err := h.SignedJWTProfileAssertion(time.Minute * 10) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", codeVerifier.Value), + oauth2.SetAuthURLParam("client_assertion", assertion), + oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), + } + + tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Write([]byte(tokens.AccessToken)) } func New(handler *Handler) chi.Router {