mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-09 01:47:03 +00:00
test(openid/client): add missing tests
This commit is contained in:
@@ -29,6 +29,7 @@ type IdentityProvider struct {
|
||||
Cfg *config.Config
|
||||
OpenIDConfig Configuration
|
||||
Provider TestProvider
|
||||
ProviderHandler *IdentityProviderHandler
|
||||
ProviderServer *httptest.Server
|
||||
RelyingPartyHandler *router.Handler
|
||||
RelyingPartyServer *httptest.Server
|
||||
@@ -87,11 +88,12 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider {
|
||||
RelyingPartyServer: rpServer,
|
||||
OpenIDConfig: openidConfig,
|
||||
Provider: provider,
|
||||
ProviderHandler: handler,
|
||||
ProviderServer: server,
|
||||
}
|
||||
}
|
||||
|
||||
func identityProviderRouter(ip *identityProviderHandler) chi.Router {
|
||||
func identityProviderRouter(ip *IdentityProviderHandler) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/authorize", ip.Authorize)
|
||||
r.Post("/token", ip.Token)
|
||||
@@ -100,23 +102,23 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router {
|
||||
return r
|
||||
}
|
||||
|
||||
type identityProviderHandler struct {
|
||||
Codes map[string]authorizeRequest
|
||||
type IdentityProviderHandler struct {
|
||||
Codes map[string]AuthorizeRequest
|
||||
Config openidconfig.Config
|
||||
Provider TestProvider
|
||||
Sessions map[string]string
|
||||
}
|
||||
|
||||
func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *identityProviderHandler {
|
||||
return &identityProviderHandler{
|
||||
Codes: make(map[string]authorizeRequest),
|
||||
func newIdentityProviderHandler(provider TestProvider, cfg openidconfig.Config) *IdentityProviderHandler {
|
||||
return &IdentityProviderHandler{
|
||||
Codes: make(map[string]AuthorizeRequest),
|
||||
Config: cfg,
|
||||
Provider: provider,
|
||||
Sessions: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
type authorizeRequest struct {
|
||||
type AuthorizeRequest struct {
|
||||
AcrLevel string
|
||||
CodeChallenge string
|
||||
Locale string
|
||||
@@ -131,7 +133,7 @@ type tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
func (ip *IdentityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
privateJwkSet := *ip.Provider.PrivateJwkSet()
|
||||
signer, ok := privateJwkSet.Key(0)
|
||||
if !ok {
|
||||
@@ -146,7 +148,7 @@ func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
return string(signedToken), nil
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IdentityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
redirect := query.Get("redirect_uri")
|
||||
@@ -162,7 +164,7 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
code := uuid.New().String()
|
||||
ip.Codes[code] = authorizeRequest{
|
||||
ip.Codes[code] = AuthorizeRequest{
|
||||
AcrLevel: acrLevel,
|
||||
CodeChallenge: codeChallenge,
|
||||
Locale: locale,
|
||||
@@ -188,12 +190,12 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) {
|
||||
jwks, _ := ip.Provider.GetPublicJwkSet(r.Context())
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -302,10 +304,11 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
token := &tokenResponse{
|
||||
AccessToken: signedAccessToken,
|
||||
TokenType: "Bearer",
|
||||
IDToken: signedIdToken,
|
||||
ExpiresIn: expires,
|
||||
AccessToken: signedAccessToken,
|
||||
TokenType: "Bearer",
|
||||
IDToken: signedIdToken,
|
||||
RefreshToken: code + "some-refresh-token",
|
||||
ExpiresIn: expires,
|
||||
}
|
||||
|
||||
w.Header().Set("content-type", "application/json")
|
||||
@@ -313,7 +316,7 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
|
||||
|
||||
@@ -48,3 +48,11 @@ func TestMakeAssertion(t *testing.T) {
|
||||
assert.True(t, assertion.Expiration().After(time.Now()))
|
||||
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
|
||||
}
|
||||
|
||||
func newTestClientWithConfig(config mock.Configuration) client.Client {
|
||||
return client.NewClient(config)
|
||||
}
|
||||
|
||||
func newTestClient() client.Client {
|
||||
return newTestClientWithConfig(mock.NewTestConfiguration(mock.Config()))
|
||||
}
|
||||
|
||||
170
pkg/openid/client/login_callback_test.go
Normal file
170
pkg/openid/client/login_callback_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
func TestLoginCallback_StateMismatchError(t *testing.T) {
|
||||
cookie := &openid.LoginCookie{
|
||||
State: "some-state",
|
||||
}
|
||||
|
||||
t.Run("invalid state", func(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback?state=some-other-state"
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
|
||||
err := lc.StateMismatchError()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing state", func(t *testing.T) {
|
||||
url := "http://wonderwall/oauth2/callback"
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
|
||||
err := lc.StateMismatchError()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginCallback_IdentityProviderError(t *testing.T) {
|
||||
cookie := &openid.LoginCookie{
|
||||
State: "some-state",
|
||||
}
|
||||
|
||||
url := "http://wonderwall/oauth2/callback?error=invalid_client&error_description=client%20authenticaion%20failed"
|
||||
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
|
||||
err := lc.IdentityProviderError()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLoginCallback_ExchangeAuthCode(t *testing.T) {
|
||||
t.Run("valid code", func(t *testing.T) {
|
||||
cookie := &openid.LoginCookie{}
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{
|
||||
"some-code": {},
|
||||
}
|
||||
|
||||
tokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokens)
|
||||
|
||||
assert.NotEmpty(t, tokens.AccessToken)
|
||||
assert.NotEmpty(t, tokens.RefreshToken)
|
||||
assert.NotEmpty(t, tokens.Extra("id_token"))
|
||||
assert.NotEmpty(t, tokens.TokenType)
|
||||
assert.NotEmpty(t, tokens.Expiry)
|
||||
|
||||
assert.Equal(t, "Bearer", tokens.TokenType)
|
||||
|
||||
assert.True(t, time.Now().Before(tokens.Expiry))
|
||||
assert.True(t, tokens.Expiry.Before(time.Now().Add(time.Hour)))
|
||||
})
|
||||
|
||||
t.Run("invalid code", func(t *testing.T) {
|
||||
cookie := &openid.LoginCookie{}
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes = map[string]mock.AuthorizeRequest{
|
||||
"some-other-code": {},
|
||||
"another-code": {},
|
||||
}
|
||||
|
||||
tokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginCallback_ProcessTokens(t *testing.T) {
|
||||
cookie := &openid.LoginCookie{
|
||||
State: "some-state",
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
url := "http://wonderwall/oauth2/callback?code=some-code"
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokens)
|
||||
})
|
||||
|
||||
t.Run("nonce mismatch", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
|
||||
Nonce: "some-other-nonce",
|
||||
}
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
|
||||
t.Run("unexpected audience", func(t *testing.T) {
|
||||
idp, lc := newLoginCallback(t, url, cookie)
|
||||
defer idp.Close()
|
||||
idp.ProviderHandler.Codes["some-code"] = mock.AuthorizeRequest{
|
||||
Nonce: "some-nonce",
|
||||
}
|
||||
idp.OpenIDConfig.ClientConfig.ClientID = "new-client-id"
|
||||
|
||||
rawTokens, err := lc.ExchangeAuthCode(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawTokens)
|
||||
|
||||
tokens, err := lc.ProcessTokens(context.Background(), rawTokens)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokens)
|
||||
})
|
||||
}
|
||||
|
||||
func newLoginCallback(t *testing.T, url string, cookie *openid.LoginCookie) (mock.IdentityProvider, client.LoginCallback) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
idp := mock.NewIdentityProvider(mock.Config())
|
||||
|
||||
cfg := idp.OpenIDConfig
|
||||
cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI
|
||||
cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI
|
||||
cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint
|
||||
|
||||
loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return idp, loginCallback
|
||||
}
|
||||
81
pkg/openid/client/logout_callback_test.go
Normal file
81
pkg/openid/client/logout_callback_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
func TestLogoutCallback_ValidateRequest(t *testing.T) {
|
||||
t.Run("nil cookie", func(t *testing.T) {
|
||||
_, err := newLogoutCallback(t, "http://localhost/oauth2/logout/callback?state=some-state", nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
url string
|
||||
cookie *openid.LogoutCookie
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
url: "http://localhost/oauth2/logout/callback?state=some-state",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-state",
|
||||
RedirectTo: "http://some-url",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty redirect",
|
||||
url: "http://localhost/oauth2/logout/callback?state=some-state",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-state",
|
||||
RedirectTo: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty state",
|
||||
url: "http://localhost/oauth2/logout/callback",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-state",
|
||||
RedirectTo: "http://some-url",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "state mismatch",
|
||||
url: "http://localhost/oauth2/logout/callback?state=some-other-state",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-state",
|
||||
RedirectTo: "http://some-url",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
lc, err := newLogoutCallback(t, test.url, test.cookie)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = lc.ValidateRequest()
|
||||
if test.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newLogoutCallback(t *testing.T, url string, cookie *openid.LogoutCookie) (client.LogoutCallback, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return newTestClient().LogoutCallback(req, cookie)
|
||||
}
|
||||
35
pkg/openid/client/logout_frontchannel_test.go
Normal file
35
pkg/openid/client/logout_frontchannel_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
func TestLogoutFrontchannel_Sid(t *testing.T) {
|
||||
t.Run("missing sid parameter in request", func(t *testing.T) {
|
||||
url := "http://localhost/oauth2/logout/frontchannel"
|
||||
lf := newLogoutFrontchannel(t, url)
|
||||
|
||||
assert.Empty(t, lf.Sid())
|
||||
assert.True(t, lf.MissingSidParameter())
|
||||
})
|
||||
|
||||
t.Run("has sid parameter in request", func(t *testing.T) {
|
||||
url := "http://localhost/oauth2/logout/frontchannel?sid=some-session-id"
|
||||
lf := newLogoutFrontchannel(t, url)
|
||||
|
||||
assert.Equal(t, "some-session-id", lf.Sid())
|
||||
assert.False(t, lf.MissingSidParameter())
|
||||
})
|
||||
}
|
||||
|
||||
func newLogoutFrontchannel(t *testing.T, url string) client.LogoutFrontchannel {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return newTestClient().LogoutFrontchannel(req)
|
||||
}
|
||||
105
pkg/openid/client/logout_test.go
Normal file
105
pkg/openid/client/logout_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
const (
|
||||
LogoutCallbackURI = "http://wonderwall/oauth2/logout/callback"
|
||||
PostLogoutRedirectURI = "http://some-other-url"
|
||||
EndSessionEndpoint = "http://provider/endsession"
|
||||
)
|
||||
|
||||
func TestLogout_CanonicalRedirect(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
canonicalRedirect := logout.CanonicalRedirect()
|
||||
|
||||
assert.Equal(t, PostLogoutRedirectURI, canonicalRedirect)
|
||||
}
|
||||
|
||||
func TestLogout_Cookie(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
cookie := logout.Cookie()
|
||||
|
||||
assert.NotNil(t, cookie)
|
||||
assert.NotEmpty(t, cookie.State)
|
||||
assert.NotEmpty(t, cookie.RedirectTo)
|
||||
}
|
||||
|
||||
func TestLogout_SingleLogoutURL(t *testing.T) {
|
||||
t.Run("with id_token", func(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
cookie := logout.Cookie()
|
||||
|
||||
state := cookie.State
|
||||
idToken := "some-id-token"
|
||||
|
||||
raw := logout.SingleLogoutURL(idToken)
|
||||
assert.NotEmpty(t, raw)
|
||||
|
||||
logoutUrl, err := url.Parse(raw)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := logoutUrl.Query()
|
||||
assert.Len(t, query, 3)
|
||||
|
||||
assert.Contains(t, query, "id_token_hint")
|
||||
assert.Equal(t, idToken, query.Get("id_token_hint"))
|
||||
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Equal(t, state, query.Get("state"))
|
||||
|
||||
assert.Contains(t, query, "post_logout_redirect_uri")
|
||||
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
|
||||
|
||||
logoutUrl.RawQuery = ""
|
||||
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
|
||||
})
|
||||
|
||||
t.Run("without id_token", func(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
cookie := logout.Cookie()
|
||||
|
||||
state := cookie.State
|
||||
idToken := ""
|
||||
|
||||
raw := logout.SingleLogoutURL(idToken)
|
||||
assert.NotEmpty(t, raw)
|
||||
|
||||
logoutUrl, err := url.Parse(raw)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := logoutUrl.Query()
|
||||
assert.Len(t, query, 2)
|
||||
|
||||
assert.NotContains(t, query, "id_token_hint")
|
||||
assert.Equal(t, idToken, query.Get("id_token_hint"))
|
||||
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Equal(t, state, query.Get("state"))
|
||||
|
||||
assert.Contains(t, query, "post_logout_redirect_uri")
|
||||
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
|
||||
|
||||
logoutUrl.RawQuery = ""
|
||||
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
|
||||
})
|
||||
}
|
||||
|
||||
func newLogout(t *testing.T) client.Logout {
|
||||
cfg := mock.NewTestConfiguration(mock.Config())
|
||||
cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI
|
||||
cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI
|
||||
cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint
|
||||
|
||||
logout, err := newTestClientWithConfig(cfg).Logout()
|
||||
assert.NoError(t, err)
|
||||
|
||||
return logout
|
||||
}
|
||||
Reference in New Issue
Block a user