diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 7a5a892..5a8760a 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -143,6 +143,7 @@ type AuthorizeRequest struct { CodeChallenge string Locale string Nonce string + RedirectUri string SessionID string } @@ -263,6 +264,7 @@ func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ CodeChallenge: codeChallenge, Locale: locale, Nonce: nonce, + RedirectUri: redirect, SessionID: sessionID, } @@ -332,6 +334,25 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } + redirect := r.PostForm.Get("redirect_uri") + if len(redirect) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing redirect_uri")) + return + } + + if len(auth.RedirectUri) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("redirect_uri was not set in auth code request")) + return + } + + if auth.RedirectUri != redirect { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("redirect_uri does not match redirect_uri used to acquire code")) + return + } + clientJwk := ip.Config.Client().ClientJWK() clientJwkSet := jwk.NewSet() clientJwkSet.AddKey(clientJwk) diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 9faed8e..213b9f7 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -63,12 +63,12 @@ func NewLogin(c Client, r *http.Request, loginstatus loginstatus.Loginstatus) (L return nil, fmt.Errorf("generating auth code url: %w", err) } - redirect := urlpkg.CanonicalRedirect(r) - cookie := params.cookie(redirect) + referer := urlpkg.CanonicalRedirect(r) + cookie := params.cookie(referer, callbackURL) return &login{ authCodeURL: url, - canonicalRedirect: redirect, + canonicalRedirect: referer, cookie: cookie, params: params, }, nil @@ -170,12 +170,13 @@ func (in *loginParameters) authCodeURL(r *http.Request, callbackURL string, logi return authCodeUrl, nil } -func (in *loginParameters) cookie(redirect string) *openid.LoginCookie { +func (in *loginParameters) cookie(referer, redirectURI string) *openid.LoginCookie { return &openid.LoginCookie{ State: in.State, Nonce: in.Nonce, CodeVerifier: in.CodeVerifier, - Referer: redirect, + Referer: referer, + RedirectURI: redirectURI, } } diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index 084999a..b9c31a6 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -80,6 +80,7 @@ func (in *loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro oauth2.SetAuthURLParam(openid.CodeVerifier, in.cookie.CodeVerifier), oauth2.SetAuthURLParam(openid.ClientAssertion, clientAssertion), oauth2.SetAuthURLParam(openid.ClientAssertionType, ClientAssertionTypeJwtBearer), + oauth2.SetAuthURLParam(openid.RedirectURI, in.cookie.RedirectURI), } code := in.requestParams.Get(openid.Code) diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index bab3264..2425235 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -2,13 +2,12 @@ package client_test import ( "context" - "net/http" - "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" + urlpkg "github.com/nais/wonderwall/pkg/handler/url" "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" @@ -90,6 +89,16 @@ func TestLoginCallback_RedeemTokens(t *testing.T) { assert.Nil(t, tokens) }) + t.Run("redirect_uri mismatch", func(t *testing.T) { + idp, lc := newLoginCallback(t, url) + defer idp.Close() + idp.ProviderHandler.Codes["some-code"].RedirectUri = "http://not-wonderwall/oauth2/callback" + + tokens, err := lc.RedeemTokens(context.Background()) + assert.Error(t, err) + assert.Nil(t, tokens) + }) + t.Run("unexpected audience", func(t *testing.T) { idp, lc := newLoginCallback(t, url) defer idp.Close() @@ -102,26 +111,31 @@ func TestLoginCallback_RedeemTokens(t *testing.T) { } func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, client.LoginCallback) { - cookie := &openid.LoginCookie{ - State: "some-state", - Nonce: "some-nonce", - CodeVerifier: "some-verifier", - } - - req := httptest.NewRequest(http.MethodGet, url, nil) - idp := mock.NewIdentityProvider(mock.Config()) + idp.SetIngresses(mock.Ingress) + req := idp.GetRequest(url) cfg := idp.OpenIDConfig + redirect, err := urlpkg.CallbackURL(req) + assert.NoError(t, err) + idp.ProviderHandler.Codes = map[string]*mock.AuthorizeRequest{ "some-code": { ClientID: idp.OpenIDConfig.Client().ClientID(), CodeChallenge: client.CodeChallenge("some-verifier"), Nonce: "some-nonce", + RedirectUri: redirect, }, } + cookie := &openid.LoginCookie{ + State: "some-state", + Nonce: "some-nonce", + CodeVerifier: "some-verifier", + RedirectURI: redirect, + } + loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie) assert.NoError(t, err) diff --git a/pkg/openid/cookies.go b/pkg/openid/cookies.go index 71167e2..939ee2c 100644 --- a/pkg/openid/cookies.go +++ b/pkg/openid/cookies.go @@ -5,4 +5,5 @@ type LoginCookie struct { Nonce string `json:"nonce"` CodeVerifier string `json:"code_verifier"` Referer string `json:"referer"` + RedirectURI string `json:"redirect_uri"` }