mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-08 09:27:12 +00:00
refactor: clean up tests and mock setup
This commit is contained in:
88
pkg/mock/client_configuration.go
Normal file
88
pkg/mock/client_configuration.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/scopes"
|
||||
)
|
||||
|
||||
type TestClientConfiguration struct {
|
||||
ClientID string
|
||||
ClientJWK jwk.Key
|
||||
RedirectURI string
|
||||
PostLogoutRedirectURI string
|
||||
Scopes scopes.Scopes
|
||||
ACRValues openid.OptionalConfiguration
|
||||
UILocales openid.OptionalConfiguration
|
||||
WellKnownURL string
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetRedirectURI() string {
|
||||
return c.RedirectURI
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetClientID() string {
|
||||
return c.ClientID
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetClientJWK() jwk.Key {
|
||||
return c.ClientJWK
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetPostLogoutRedirectURI() string {
|
||||
return c.PostLogoutRedirectURI
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetScopes() scopes.Scopes {
|
||||
return c.Scopes
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetACRValues() openid.OptionalConfiguration {
|
||||
return c.ACRValues
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetUILocales() openid.OptionalConfiguration {
|
||||
return c.UILocales
|
||||
}
|
||||
|
||||
func (c TestClientConfiguration) GetWellKnownURL() string {
|
||||
return c.WellKnownURL
|
||||
}
|
||||
|
||||
func clientConfiguration() TestClientConfiguration {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
key, err := jwk.New(privateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
key.Set(jwk.AlgorithmKey, jwa.RS256)
|
||||
key.Set(jwk.KeyTypeKey, jwa.RSA)
|
||||
key.Set(jwk.KeyIDKey, uuid.New().String())
|
||||
|
||||
return TestClientConfiguration{
|
||||
ClientID: "client_id",
|
||||
ClientJWK: key,
|
||||
RedirectURI: "http://localhost/callback",
|
||||
WellKnownURL: "",
|
||||
UILocales: openid.OptionalConfiguration{
|
||||
Enabled: true,
|
||||
Value: "nb",
|
||||
},
|
||||
ACRValues: openid.OptionalConfiguration{
|
||||
Enabled: true,
|
||||
Value: "Level4",
|
||||
},
|
||||
PostLogoutRedirectURI: "",
|
||||
Scopes: scopes.Defaults(),
|
||||
}
|
||||
}
|
||||
@@ -1,72 +1,40 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
)
|
||||
|
||||
type IDPorten struct {
|
||||
Clients map[string]string
|
||||
Config config.IDPorten
|
||||
Codes map[string]AuthRequest
|
||||
Keys jwk.Set
|
||||
type identityProviderHandler struct {
|
||||
Codes map[string]authorizeRequest
|
||||
Provider TestProvider
|
||||
Sessions map[string]string
|
||||
}
|
||||
|
||||
type AuthRequest struct {
|
||||
func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler {
|
||||
return &identityProviderHandler{
|
||||
Codes: make(map[string]authorizeRequest),
|
||||
Provider: provider,
|
||||
Sessions: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
type authorizeRequest struct {
|
||||
AcrLevel string
|
||||
CodeChallenge string
|
||||
Locale string
|
||||
Nonce string
|
||||
}
|
||||
|
||||
func NewIDPorten(clients map[string]string, config config.IDPorten) *IDPorten {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
key, err := jwk.New(privateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = jwk.AssignKeyID(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = key.Set(jwk.AlgorithmKey, jwa.RS256)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
keys := jwk.NewSet()
|
||||
keys.Add(key)
|
||||
|
||||
return &IDPorten{
|
||||
Clients: clients,
|
||||
Codes: make(map[string]AuthRequest),
|
||||
Config: config,
|
||||
Keys: keys,
|
||||
Sessions: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
type TokenJSON struct {
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
@@ -74,8 +42,9 @@ type TokenJSON struct {
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func (ip *IDPorten) signToken(token jwt.Token) (string, error) {
|
||||
signer, ok := ip.Keys.Get(0)
|
||||
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
privateJwkSet := *ip.Provider.PrivateJwkSet()
|
||||
signer, ok := privateJwkSet.Get(0)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("could not get signer")
|
||||
}
|
||||
@@ -88,7 +57,49 @@ func (ip *IDPorten) signToken(token jwt.Token) (string, error) {
|
||||
return string(signedToken), nil
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
redirect := query.Get("redirect_uri")
|
||||
acrLevel := query.Get("acr_values")
|
||||
codeChallenge := query.Get("code_challenge")
|
||||
locale := query.Get("ui_locales")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required fields"))
|
||||
return
|
||||
}
|
||||
|
||||
code := uuid.New().String()
|
||||
ip.Codes[code] = authorizeRequest{
|
||||
AcrLevel: acrLevel,
|
||||
CodeChallenge: codeChallenge,
|
||||
Locale: locale,
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse redirect uri"))
|
||||
return
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("code", code)
|
||||
v.Set("state", state)
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) Jwks(w http.ResponseWriter, _ *http.Request) {
|
||||
json.NewEncoder(w).Encode(ip.Provider.GetPublicJwkSet())
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -116,9 +127,48 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
sub := uuid.New().String()
|
||||
sid := uuid.New().String()
|
||||
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
if len(clientID) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing client_id"))
|
||||
return
|
||||
}
|
||||
|
||||
clientAssertion := r.PostForm.Get("client_assertion")
|
||||
if len(clientID) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing client_assertion"))
|
||||
return
|
||||
}
|
||||
|
||||
clientJwk := ip.Provider.GetClientConfiguration().GetClientJWK()
|
||||
clientJwkSet := jwk.NewSet()
|
||||
clientJwkSet.Add(clientJwk)
|
||||
publicClientJwkSet, err := jwk.PublicSetOf(clientJwkSet)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("failed to create public client jwk set"))
|
||||
return
|
||||
}
|
||||
|
||||
opts := []jwt.ParseOption{
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(publicClientJwkSet),
|
||||
jwt.WithIssuer(ip.Provider.GetClientConfiguration().GetClientID()),
|
||||
jwt.WithSubject(ip.Provider.GetClientConfiguration().GetClientID()),
|
||||
jwt.WithClaimValue("scope", ip.Provider.GetClientConfiguration().GetScopes().String()),
|
||||
jwt.WithAudience(ip.Provider.GetOpenIDConfiguration().Issuer),
|
||||
}
|
||||
_, err = jwt.Parse([]byte(clientAssertion), opts...)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(fmt.Sprintf("invalid client assertion: %+v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("sub", sub)
|
||||
accessToken.Set("iss", ip.Config.WellKnown.Issuer)
|
||||
accessToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer)
|
||||
accessToken.Set("acr", auth.AcrLevel)
|
||||
accessToken.Set("iat", time.Now().Unix())
|
||||
accessToken.Set("exp", time.Now().Unix()+expires)
|
||||
@@ -131,8 +181,8 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
idToken := jwt.New()
|
||||
idToken.Set("sub", sub)
|
||||
idToken.Set("iss", ip.Config.WellKnown.Issuer)
|
||||
idToken.Set("aud", ip.Config.ClientID)
|
||||
idToken.Set("iss", ip.Provider.GetOpenIDConfiguration().Issuer)
|
||||
idToken.Set("aud", clientID)
|
||||
idToken.Set("locale", auth.Locale)
|
||||
idToken.Set("nonce", auth.Nonce)
|
||||
idToken.Set("acr", auth.AcrLevel)
|
||||
@@ -147,8 +197,8 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ip.Sessions[sid] = ip.Config.ClientID
|
||||
token := &TokenJSON{
|
||||
ip.Sessions[sid] = clientID
|
||||
token := &tokenResponse{
|
||||
AccessToken: signedAccessToken,
|
||||
TokenType: "Bearer",
|
||||
IDToken: signedIdToken,
|
||||
@@ -159,59 +209,3 @@ func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
redirect := query.Get("redirect_uri")
|
||||
acrLevel := query.Get("acr_values")
|
||||
codeChallenge := query.Get("code_challenge")
|
||||
locale := query.Get("ui_locales")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required fields"))
|
||||
return
|
||||
}
|
||||
|
||||
code := uuid.New().String()
|
||||
ip.Codes[code] = AuthRequest{
|
||||
AcrLevel: acrLevel,
|
||||
CodeChallenge: codeChallenge,
|
||||
Locale: locale,
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse redirect uri"))
|
||||
return
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("code", code)
|
||||
v.Set("state", state)
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Jwks(w http.ResponseWriter, r *http.Request) {
|
||||
publicSet, err := jwk.PublicSetOf(ip.Keys)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not create public set: " + err.Error()))
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(publicSet)
|
||||
}
|
||||
|
||||
func IDPortenRouter(ip *IDPorten) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/authorize", ip.Authorize)
|
||||
r.Post("/token", ip.Token)
|
||||
r.Get("/jwks", ip.Jwks)
|
||||
return r
|
||||
}
|
||||
50
pkg/mock/provider.go
Normal file
50
pkg/mock/provider.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/jwks"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
|
||||
type TestProvider struct {
|
||||
ClientConfiguration *TestClientConfiguration
|
||||
OpenIDConfiguration *openid.Configuration
|
||||
JwksPair *jwks.Pair
|
||||
}
|
||||
|
||||
func (p TestProvider) GetClientConfiguration() openid.ClientConfiguration {
|
||||
return p.ClientConfiguration
|
||||
}
|
||||
|
||||
func (p TestProvider) GetOpenIDConfiguration() *openid.Configuration {
|
||||
return p.OpenIDConfiguration
|
||||
}
|
||||
|
||||
func (p TestProvider) GetPublicJwkSet() *jwk.Set {
|
||||
return &p.JwksPair.Public
|
||||
}
|
||||
|
||||
func (p TestProvider) PrivateJwkSet() *jwk.Set {
|
||||
return &p.JwksPair.Private
|
||||
}
|
||||
|
||||
func NewTestProvider() TestProvider {
|
||||
jwksPair, err := jwks.NewJwksPair()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
clientCfg := clientConfiguration()
|
||||
provider := TestProvider{
|
||||
ClientConfiguration: &clientCfg,
|
||||
OpenIDConfiguration: &openid.Configuration{
|
||||
ACRValuesSupported: openid.Supported{"Level3", "Level4"},
|
||||
UILocalesSupported: openid.Supported{"nb", "nb", "en", "se"},
|
||||
},
|
||||
JwksPair: jwksPair,
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
11
pkg/mock/router.go
Normal file
11
pkg/mock/router.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package mock
|
||||
|
||||
import "github.com/go-chi/chi/v5"
|
||||
|
||||
func identityProviderRouter(ip *identityProviderHandler) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/authorize", ip.Authorize)
|
||||
r.Post("/token", ip.Token)
|
||||
r.Get("/jwks", ip.Jwks)
|
||||
return r
|
||||
}
|
||||
20
pkg/mock/server.go
Normal file
20
pkg/mock/server.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
)
|
||||
|
||||
func IdentityProviderServer() (*httptest.Server, TestProvider) {
|
||||
provider := NewTestProvider()
|
||||
handler := newIdentityProviderHandler(provider)
|
||||
router := identityProviderRouter(handler)
|
||||
server := httptest.NewServer(router)
|
||||
|
||||
provider.OpenIDConfiguration.Issuer = server.URL
|
||||
provider.OpenIDConfiguration.JwksURI = server.URL + "/jwks"
|
||||
provider.OpenIDConfiguration.AuthorizationEndpoint = server.URL + "/authorize"
|
||||
provider.OpenIDConfiguration.TokenEndpoint = server.URL + "/token"
|
||||
provider.OpenIDConfiguration.EndSessionEndpoint = server.URL + "/endsession"
|
||||
|
||||
return server, provider
|
||||
}
|
||||
@@ -47,14 +47,14 @@ func TestLoginURL(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.url, func(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
req, err := http.NewRequest("GET", test.url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params, err := openid.GenerateLoginParameters()
|
||||
assert.NoError(t, err)
|
||||
|
||||
handler := handler(cfg)
|
||||
provider := mock.NewTestProvider()
|
||||
handler := handler(provider)
|
||||
_, err = handler.LoginURL(req, params)
|
||||
|
||||
if test.error != nil {
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
@@ -16,85 +10,31 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/provider"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
const clientID = "clientid"
|
||||
|
||||
var encryptionKey = []byte(`G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`) // 256 bits AES
|
||||
|
||||
var clients = map[string]string{
|
||||
clientID: "http://localhost/oauth2/logout/frontchannel",
|
||||
var cfg = config.Config{
|
||||
EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits AES
|
||||
Ingress: "/",
|
||||
OpenID: config.OpenID{
|
||||
Provider: "test",
|
||||
},
|
||||
SessionMaxLifetime: time.Hour,
|
||||
}
|
||||
|
||||
func defaultConfig() config.Config {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
key, err := jwk.New(privateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
key.Set(jwk.AlgorithmKey, jwa.RS256)
|
||||
key.Set(jwk.KeyTypeKey, jwa.RSA)
|
||||
key.Set(jwk.KeyIDKey, uuid.New().String())
|
||||
|
||||
clientJwk, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return config.Config{IDPorten: config.IDPorten{
|
||||
ClientID: clientID,
|
||||
ClientJWK: string(clientJwk),
|
||||
RedirectURI: "http://localhost/callback",
|
||||
WellKnownURL: "",
|
||||
WellKnown: config.IDPortenWellKnown{
|
||||
Issuer: "issuer",
|
||||
AuthorizationEndpoint: "http://localhost:1234/authorize",
|
||||
ACRValuesSupported: config.Supported{"Level3", "Level4"},
|
||||
UILocalesSupported: config.Supported{"nb", "nb", "en", "se"},
|
||||
},
|
||||
Locale: config.IDPortenLocale{
|
||||
Enabled: true,
|
||||
Value: "nb",
|
||||
},
|
||||
SecurityLevel: config.IDPortenSecurityLevel{
|
||||
Enabled: true,
|
||||
Value: "Level4",
|
||||
},
|
||||
PostLogoutRedirectURI: "",
|
||||
SessionMaxLifetime: time.Hour,
|
||||
}}
|
||||
}
|
||||
|
||||
func handler(cfg config.Config) *router.Handler {
|
||||
var jwkSet jwk.Set
|
||||
var err error
|
||||
|
||||
if len(cfg.IDPorten.WellKnown.JwksURI) == 0 {
|
||||
jwk.NewSet()
|
||||
} else {
|
||||
jwkSet, err = jwk.Fetch(context.Background(), cfg.IDPorten.WellKnown.JwksURI)
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
crypter := cryptutil.New(encryptionKey)
|
||||
func handler(provider provider.Provider) *router.Handler {
|
||||
crypter := cryptutil.New([]byte(cfg.EncryptionKey))
|
||||
sessionStore := session.NewMemory()
|
||||
|
||||
handler, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, jwkSet, sessionStore, "")
|
||||
handler, err := router.NewHandler(cfg, crypter, zerolog.Logger{}, provider, sessionStore)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -102,9 +42,8 @@ func handler(cfg config.Config) *router.Handler {
|
||||
}
|
||||
|
||||
func TestHandler_Login(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
|
||||
h := handler(cfg)
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
h := handler(idp)
|
||||
r := router.New(h)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
@@ -117,41 +56,38 @@ func TestHandler_Login(t *testing.T) {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
idprouter := mock.IDPortenRouter(mock.NewIDPorten(clients, cfg.IDPorten))
|
||||
idpserver := httptest.NewServer(idprouter)
|
||||
|
||||
h.Config.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize"
|
||||
|
||||
loginURL, err := url.Parse(server.URL + "/oauth2/login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
req, err := client.Get(loginURL.String())
|
||||
resp, err := client.Get(loginURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
cookies := client.Jar.Cookies(loginURL)
|
||||
loginCookie := getCookieFromJar(h.GetLoginCookieName(), cookies)
|
||||
assert.NotNil(t, loginCookie)
|
||||
|
||||
location := req.Header.Get("location")
|
||||
location := resp.Header.Get("location")
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, idpserver.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host))
|
||||
assert.Equal(t, "/authorize", u.Path)
|
||||
assert.Equal(t, cfg.IDPorten.SecurityLevel.Value, u.Query().Get("acr_values"))
|
||||
assert.Equal(t, cfg.IDPorten.Locale.Value, u.Query().Get("ui_locales"))
|
||||
assert.Equal(t, cfg.IDPorten.ClientID, u.Query().Get("client_id"))
|
||||
assert.Equal(t, cfg.IDPorten.RedirectURI, u.Query().Get("redirect_uri"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetACRValues().Value, u.Query().Get("acr_values"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetUILocales().Value, u.Query().Get("ui_locales"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), u.Query().Get("client_id"))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetRedirectURI(), u.Query().Get("redirect_uri"))
|
||||
assert.NotEmpty(t, u.Query().Get("state"))
|
||||
assert.NotEmpty(t, u.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, u.Query().Get("code_challenge"))
|
||||
|
||||
req, err = client.Get(u.String())
|
||||
resp, err = client.Get(u.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
location = req.Header.Get("location")
|
||||
location = resp.Header.Get("location")
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -160,21 +96,14 @@ func TestHandler_Login(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
|
||||
idprouter := mock.IDPortenRouter(mock.NewIDPorten(clients, cfg.IDPorten))
|
||||
idpserver := httptest.NewServer(idprouter)
|
||||
cfg.IDPorten.WellKnown.JwksURI = idpserver.URL + "/jwks"
|
||||
cfg.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize"
|
||||
cfg.IDPorten.WellKnown.TokenEndpoint = idpserver.URL + "/token"
|
||||
cfg.IDPorten.WellKnown.EndSessionEndpoint = idpserver.URL + "/endsession"
|
||||
|
||||
h := handler(cfg)
|
||||
h := handler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback"
|
||||
h.Config.IDPorten.PostLogoutRedirectURI = server.URL
|
||||
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -187,9 +116,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
|
||||
// First, run /oauth2/login to set cookies
|
||||
loginURL, err := url.Parse(server.URL + "/oauth2/login")
|
||||
req, err := client.Get(loginURL.String())
|
||||
resp, err := client.Get(loginURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
cookies := client.Jar.Cookies(loginURL)
|
||||
sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies)
|
||||
@@ -199,23 +129,25 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
assert.NotNil(t, loginCookie)
|
||||
|
||||
// Get authorization URL
|
||||
location := req.Header.Get("location")
|
||||
location := resp.Header.Get("location")
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to authorize with idporten
|
||||
req, err = client.Get(u.String())
|
||||
// Follow redirect to authorize with identity provider
|
||||
resp, err = client.Get(u.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get callback URL after successful auth
|
||||
location = req.Header.Get("location")
|
||||
location = resp.Header.Get("location")
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to callback
|
||||
req, err = client.Get(callbackURL.String())
|
||||
resp, err = client.Get(callbackURL.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies = client.Jar.Cookies(callbackURL)
|
||||
sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies)
|
||||
@@ -228,9 +160,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
logoutURL, err := url.Parse(server.URL + "/oauth2/logout")
|
||||
assert.NoError(t, err)
|
||||
|
||||
req, err = client.Get(logoutURL.String())
|
||||
resp, err = client.Get(logoutURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
cookies = client.Jar.Cookies(logoutURL)
|
||||
sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies)
|
||||
@@ -238,7 +171,7 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
assert.Nil(t, sessionCookie)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
location = req.Header.Get("location")
|
||||
location = resp.Header.Get("location")
|
||||
endsessionURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -249,27 +182,18 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{h.Config.IDPorten.PostLogoutRedirectURI})
|
||||
assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetPostLogoutRedirectURI()})
|
||||
assert.NotEmpty(t, endsessionParams["id_token_hint"])
|
||||
}
|
||||
|
||||
func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
|
||||
idp := mock.NewIDPorten(clients, cfg.IDPorten)
|
||||
idprouter := mock.IDPortenRouter(idp)
|
||||
idpserver := httptest.NewServer(idprouter)
|
||||
|
||||
cfg.IDPorten.WellKnown.JwksURI = idpserver.URL + "/jwks"
|
||||
cfg.IDPorten.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize"
|
||||
cfg.IDPorten.WellKnown.TokenEndpoint = idpserver.URL + "/token"
|
||||
|
||||
h := handler(cfg)
|
||||
_, idp := mock.IdentityProviderServer()
|
||||
h := handler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback"
|
||||
h.Config.IDPorten.PostLogoutRedirectURI = server.URL
|
||||
idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback"
|
||||
idp.ClientConfiguration.PostLogoutRedirectURI = server.URL
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -281,28 +205,31 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
}
|
||||
|
||||
// First, run /oauth2/login to set cookies
|
||||
req, err := client.Get(server.URL + "/oauth2/login")
|
||||
resp, err := client.Get(server.URL + "/oauth2/login")
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get authorization URL
|
||||
location := req.Header.Get("location")
|
||||
location := resp.Header.Get("location")
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to authorize with idporten
|
||||
req, err = client.Get(u.String())
|
||||
resp, err = client.Get(u.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get callback URL after successful auth
|
||||
location = req.Header.Get("location")
|
||||
location = resp.Header.Get("location")
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to callback
|
||||
req, err = client.Get(callbackURL.String())
|
||||
resp, err = client.Get(callbackURL.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies := client.Jar.Cookies(callbackURL)
|
||||
sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies)
|
||||
@@ -323,14 +250,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
|
||||
values := url.Values{}
|
||||
values.Add("sid", string(sid))
|
||||
values.Add("iss", h.Config.IDPorten.WellKnown.Issuer)
|
||||
values.Add("iss", idp.GetOpenIDConfiguration().Issuer)
|
||||
frontchannelLogoutURL.RawQuery = values.Encode()
|
||||
|
||||
req, err = client.Get(frontchannelLogoutURL.String())
|
||||
resp, err = client.Get(frontchannelLogoutURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, req.StatusCode)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
|
||||
|
||||
@@ -2,18 +2,20 @@ package router_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
h := handler(cfg)
|
||||
h := handler(mock.NewTestProvider())
|
||||
|
||||
t.Run("request without fallback session cookies", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
@@ -53,8 +55,7 @@ func TestHandler_GetSessionFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
h := handler(cfg)
|
||||
h := handler(mock.NewTestProvider())
|
||||
|
||||
// request should set session cookies in response
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -87,8 +88,7 @@ func TestHandler_SetSessionFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_DeleteSessionFallback(t *testing.T) {
|
||||
cfg := defaultConfig()
|
||||
h := handler(cfg)
|
||||
h := handler(mock.NewTestProvider())
|
||||
|
||||
writer := httptest.NewRecorder()
|
||||
h.DeleteSessionFallback(writer)
|
||||
|
||||
Reference in New Issue
Block a user