test(openid/client): add missing tests

This commit is contained in:
Trong Huu Nguyen
2022-07-11 15:32:41 +02:00
parent b937c64dd6
commit c321cff4eb
6 changed files with 419 additions and 17 deletions

View File

@@ -29,6 +29,7 @@ type IdentityProvider struct {
Cfg *config.Config
OpenIDConfig Configuration
Provider TestProvider
ProviderHandler *IdentityProviderHandler
ProviderServer *httptest.Server
RelyingPartyHandler *router.Handler
RelyingPartyServer *httptest.Server
@@ -87,11 +88,12 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider {
RelyingPartyServer: rpServer,
OpenIDConfig: openidConfig,
Provider: provider,
ProviderHandler: handler,
ProviderServer: server,
}
}
func identityProviderRouter(ip *identityProviderHandler) chi.Router {
func identityProviderRouter(ip *IdentityProviderHandler) chi.Router {
r := chi.NewRouter()
r.Get("/authorize", ip.Authorize)
r.Post("/token", ip.Token)
@@ -100,23 +102,23 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router {
return r
}
type identityProviderHandler struct {
Codes map[string]authorizeRequest
type IdentityProviderHandler struct {
Codes map[string]AuthorizeRequest
Config openidconfig.Config
Provider TestProvider
Sessions map[string]string
}
func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *identityProviderHandler {
return &identityProviderHandler{
Codes: make(map[string]authorizeRequest),
func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *IdentityProviderHandler {
return &IdentityProviderHandler{
Codes: make(map[string]AuthorizeRequest),
Config: cfg,
Provider: provider,
Sessions: make(map[string]string),
}
}
type authorizeRequest struct {
type AuthorizeRequest struct {
AcrLevel string
CodeChallenge string
Locale string
@@ -131,7 +133,7 @@ type tokenResponse struct {
IDToken string `json:"id_token"`
}
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
func (ip *IdentityProviderHandler) signToken(token jwt.Token) (string, error) {
privateJwkSet := *ip.Provider.PrivateJwkSet()
signer, ok := privateJwkSet.Key(0)
if !ok {
@@ -146,7 +148,7 @@ func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
return string(signedToken), nil
}
func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
state := query.Get("state")
redirect := query.Get("redirect_uri")
@@ -162,7 +164,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
}
code := uuid.New().String()
ip.Codes[code] = authorizeRequest{
ip.Codes[code] = AuthorizeRequest{
AcrLevel: acrLevel,
CodeChallenge: codeChallenge,
Locale: locale,
@@ -188,12 +190,12 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
}
func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) {
func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) {
jwks, _ := ip.Provider.GetPublicJwkSet(r.Context())
json.NewEncoder(w).Encode(jwks)
}
func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
@@ -302,10 +304,11 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
}
token := &tokenResponse{
AccessToken: signedAccessToken,
TokenType: "Bearer",
IDToken: signedIdToken,
ExpiresIn: expires,
AccessToken: signedAccessToken,
TokenType: "Bearer",
IDToken: signedIdToken,
RefreshToken: code + "some-refresh-token",
ExpiresIn: expires,
}
w.Header().Set("content-type", "application/json")
@@ -313,7 +316,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
json.NewEncoder(w).Encode(token)
}
func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
state := query.Get("state")
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")

View File

@@ -48,3 +48,11 @@ func TestMakeAssertion(t *testing.T) {
assert.True(t, assertion.Expiration().After(time.Now()))
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
}
func newTestClientWithConfig(config mock.Configuration) client.Client {
return client.NewClient(config)
}
func newTestClient() client.Client {
return newTestClientWithConfig(mock.NewTestConfiguration(mock.Config()))
}

View File

@@ -0,0 +1,170 @@
package client_test
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
)
func TestLoginCallback_StateMismatchError(t *testing.T) {
cookie := &openid.LoginCookie{
State: "some-state",
}
t.Run("invalid state", func(t *testing.T) {
url := "http://wonderwall/oauth2/callback?state=some-other-state"
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
err := lc.StateMismatchError()
assert.Error(t, err)
})
t.Run("missing state", func(t *testing.T) {
url := "http://wonderwall/oauth2/callback"
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
err := lc.StateMismatchError()
assert.Error(t, err)
})
}
func TestLoginCallback_IdentityProviderError(t *testing.T) {
cookie := &openid.LoginCookie{
State: "some-state",
}
url := "http://wonderwall/oauth2/callback?error=invalid_client&error_description=client%20authenticaion%20failed"
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
err := lc.IdentityProviderError()
assert.Error(t, err)
}
func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
t.Run("valid code", func(t *testing.T) {
cookie := &openid.LoginCookie{}
url := "http://wonderwall/oauth2/callback?code=some-code"
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{
"some-code": {},
}
tokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, tokens)
assert.NotEmpty(t, tokens.AccessToken)
assert.NotEmpty(t, tokens.RefreshToken)
assert.NotEmpty(t, tokens.Extra("id_token"))
assert.NotEmpty(t, tokens.TokenType)
assert.NotEmpty(t, tokens.Expiry)
assert.Equal(t, "Bearer", tokens.TokenType)
assert.True(t, time.Now().Before(tokens.Expiry))
assert.True(t, tokens.Expiry.Before(time.Now().Add(time.Hour)))
})
t.Run("invalid code", func(t *testing.T) {
cookie := &openid.LoginCookie{}
url := "http://wonderwall/oauth2/callback?code=some-code"
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{
"some-other-code": {},
"another-code": {},
}
tokens, err := lc.ExchangeAuthCode(context.Background())
assert.Error(t, err)
assert.Nil(t, tokens)
})
}
func TestLoginCallback_ProcessTokens(t *testing.T) {
cookie := &openid.LoginCookie{
State: "some-state",
Nonce: "some-nonce",
}
url := "http://wonderwall/oauth2/callback?code=some-code"
t.Run("happy path", func(t *testing.T) {
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
Nonce: "some-nonce",
}
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
assert.NoError(t, err)
assert.NotNil(t, tokens)
})
t.Run("nonce mismatch", func(t *testing.T) {
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
Nonce: "some-other-nonce",
}
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
assert.Error(t, err)
assert.Nil(t, tokens)
})
t.Run("unexpected audience", func(t *testing.T) {
idp, lc := newLoginCallback(t, url, cookie)
defer idp.Close()
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
Nonce: "some-nonce",
}
idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id"
rawTokens, err := lc.ExchangeAuthCode(context.Background())
assert.NoError(t, err)
assert.NotNil(t, rawTokens)
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
assert.Error(t, err)
assert.Nil(t, tokens)
})
}
func newLoginCallback(t *testing.T, url string, cookie *openid.LoginCookie) (mock.IdentityProvider, client.LoginCallback) {
req, err := http.NewRequest("GET", url, nil)
assert.NoError(t, err)
idp := mock.NewIdentityProvider(mock.Config())
cfg := idp.OpenIDConfig
cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI
cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI
cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint
loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie)
assert.NoError(t, err)
return idp, loginCallback
}

View File

@@ -0,0 +1,81 @@
package client_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
)
func TestLogoutCallback_ValidateRequest(t *testing.T) {
t.Run("nil cookie", func(t *testing.T) {
_, err := newLogoutCallback(t, "http://localhost/oauth2/logout/callback?state=some-state", nil)
assert.Error(t, err)
})
for _, test := range []struct {
name string
url string
cookie *openid.LogoutCookie
wantErr bool
}{
{
name: "valid request",
url: "http://localhost/oauth2/logout/callback?state=some-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: false,
},
{
name: "empty redirect",
url: "http://localhost/oauth2/logout/callback?state=some-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "",
},
wantErr: true,
},
{
name: "empty state",
url: "http://localhost/oauth2/logout/callback",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: true,
},
{
name: "state mismatch",
url: "http://localhost/oauth2/logout/callback?state=some-other-state",
cookie: &openid.LogoutCookie{
State: "some-state",
RedirectTo: "http://some-url",
},
wantErr: true,
},
} {
t.Run(test.name, func(t *testing.T) {
lc, err := newLogoutCallback(t, test.url, test.cookie)
assert.NoError(t, err)
err = lc.ValidateRequest()
if test.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func newLogoutCallback(t *testing.T, url string, cookie *openid.LogoutCookie) (client.LogoutCallback, error) {
req, err := http.NewRequest("GET", url, nil)
assert.NoError(t, err)
return newTestClient().LogoutCallback(req, cookie)
}

View File

@@ -0,0 +1,35 @@
package client_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/openid/client"
)
func TestLogoutFrontchannel_Sid(t *testing.T) {
t.Run("missing sid parameter in request", func(t *testing.T) {
url := "http://localhost/oauth2/logout/frontchannel"
lf := newLogoutFrontchannel(t, url)
assert.Empty(t, lf.Sid())
assert.True(t, lf.MissingSidParameter())
})
t.Run("has sid parameter in request", func(t *testing.T) {
url := "http://localhost/oauth2/logout/frontchannel?sid=some-session-id"
lf := newLogoutFrontchannel(t, url)
assert.Equal(t, "some-session-id", lf.Sid())
assert.False(t, lf.MissingSidParameter())
})
}
func newLogoutFrontchannel(t *testing.T, url string) client.LogoutFrontchannel {
req, err := http.NewRequest("GET", url, nil)
assert.NoError(t, err)
return newTestClient().LogoutFrontchannel(req)
}

View File

@@ -0,0 +1,105 @@
package client_test
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid/client"
)
const (
LogoutCallbackURI = "http://wonderwall/oauth2/logout/callback"
PostLogoutRedirectURI = "http://some-other-url"
EndSessionEndpoint = "http://provider/endsession"
)
func TestLogout_CanonicalRedirect(t *testing.T) {
logout := newLogout(t)
canonicalRedirect := logout.CanonicalRedirect()
assert.Equal(t, PostLogoutRedirectURI, canonicalRedirect)
}
func TestLogout_Cookie(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
assert.NotNil(t, cookie)
assert.NotEmpty(t, cookie.State)
assert.NotEmpty(t, cookie.RedirectTo)
}
func TestLogout_SingleLogoutURL(t *testing.T) {
t.Run("with id_token", func(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
state := cookie.State
idToken := "some-id-token"
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 3)
assert.Contains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
t.Run("without id_token", func(t *testing.T) {
logout := newLogout(t)
cookie := logout.Cookie()
state := cookie.State
idToken := ""
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 2)
assert.NotContains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
}
func newLogout(t *testing.T) client.Logout {
cfg := mock.NewTestConfiguration(mock.Config())
cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI
cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI
cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint
logout, err := newTestClientWithConfig(cfg).Logout()
assert.NoError(t, err)
return logout
}