mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-07 08:57:07 +00:00
197 lines
5.5 KiB
Go
197 lines
5.5 KiB
Go
package handler_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/nais/wonderwall/pkg/handler"
|
|
"github.com/nais/wonderwall/pkg/mock"
|
|
"github.com/nais/wonderwall/pkg/openid"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
)
|
|
|
|
func TestHandler_GetSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
tokens := makeTokens(idp.Provider)
|
|
rpHandler := idp.RelyingPartyHandler
|
|
|
|
t.Run("request without fallback session cookies", func(t *testing.T) {
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
_, err := rpHandler.GetSessionFallback(w, r)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("request with fallback session cookies", func(t *testing.T) {
|
|
r := makeRequestWithFallbackCookies(t, rpHandler, tokens)
|
|
w := httptest.NewRecorder()
|
|
sessionData, err := rpHandler.GetSessionFallback(w, r)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "sid", sessionData.ExternalSessionID)
|
|
assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
|
|
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
|
|
assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
|
|
assert.Empty(t, sessionData.RefreshToken)
|
|
})
|
|
}
|
|
|
|
func TestHandler_SetSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
tokens := makeTokens(idp.Provider)
|
|
rpHandler := idp.RelyingPartyHandler
|
|
|
|
// request should set session cookies in response
|
|
writer := httptest.NewRecorder()
|
|
expiresIn := time.Minute
|
|
data := session.NewData("sid", tokens, nil)
|
|
err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn)
|
|
assert.NoError(t, err)
|
|
|
|
cookies := writer.Result().Cookies()
|
|
|
|
for _, test := range []struct {
|
|
cookieName string
|
|
want string
|
|
}{
|
|
{
|
|
cookieName: "wonderwall-1",
|
|
want: "sid",
|
|
},
|
|
{
|
|
cookieName: "wonderwall-2",
|
|
want: tokens.IDToken.GetSerialized(),
|
|
},
|
|
{
|
|
cookieName: "wonderwall-3",
|
|
want: tokens.AccessToken,
|
|
},
|
|
} {
|
|
assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies)
|
|
}
|
|
}
|
|
|
|
func TestHandler_DeleteSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpHandler := idp.RelyingPartyHandler
|
|
tokens := makeTokens(idp.Provider)
|
|
|
|
t.Run("expire cookies if they are set", func(t *testing.T) {
|
|
r := makeRequestWithFallbackCookies(t, rpHandler, tokens)
|
|
writer := httptest.NewRecorder()
|
|
rpHandler.DeleteSessionFallback(writer, r)
|
|
cookies := writer.Result().Cookies()
|
|
|
|
assert.NotEmpty(t, cookies)
|
|
assert.Len(t, cookies, 3)
|
|
|
|
assertCookieExpired(t, "wonderwall-1", cookies)
|
|
assertCookieExpired(t, "wonderwall-2", cookies)
|
|
assertCookieExpired(t, "wonderwall-3", cookies)
|
|
})
|
|
|
|
t.Run("skip expiring cookies if they are not set", func(t *testing.T) {
|
|
writer := httptest.NewRecorder()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rpHandler.DeleteSessionFallback(writer, r)
|
|
cookies := writer.Result().Cookies()
|
|
|
|
assert.Empty(t, cookies)
|
|
})
|
|
}
|
|
|
|
func makeRequestWithFallbackCookies(t *testing.T, h *handler.Handler, tokens *openid.Tokens) *http.Request {
|
|
writer := httptest.NewRecorder()
|
|
expiresIn := time.Minute
|
|
data := session.NewData("sid", tokens, nil)
|
|
err := h.SetSessionFallback(writer, nil, data, expiresIn)
|
|
assert.NoError(t, err)
|
|
|
|
cookies := writer.Result().Cookies()
|
|
|
|
externalSessionIDCookie := getCookieFromJar("wonderwall-1", cookies)
|
|
assert.NotNil(t, externalSessionIDCookie)
|
|
idTokenCookie := getCookieFromJar("wonderwall-2", cookies)
|
|
assert.NotNil(t, idTokenCookie)
|
|
accessTokenCookie := getCookieFromJar("wonderwall-3", cookies)
|
|
assert.NotNil(t, accessTokenCookie)
|
|
|
|
// make request with fallback session cookies set
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.AddCookie(externalSessionIDCookie)
|
|
r.AddCookie(idTokenCookie)
|
|
r.AddCookie(accessTokenCookie)
|
|
|
|
return r
|
|
}
|
|
|
|
func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie) {
|
|
expired := getCookieFromJar(cookieName, cookies)
|
|
assert.NotNil(t, expired)
|
|
assert.Less(t, expired.MaxAge, 0)
|
|
assert.True(t, expired.Expires.Before(time.Now()))
|
|
assert.Empty(t, expired.Value)
|
|
}
|
|
|
|
func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
|
|
desiredCookie := getCookieFromJar(cookieName, cookies)
|
|
assert.NotNil(t, desiredCookie)
|
|
|
|
ciphertext, err := base64.StdEncoding.DecodeString(desiredCookie.Value)
|
|
assert.NoError(t, err)
|
|
|
|
plainbytes, err := h.Crypter.Decrypt(ciphertext)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedValue, string(plainbytes))
|
|
}
|
|
|
|
func makeTokens(provider *mock.TestProvider) *openid.Tokens {
|
|
jwks := *provider.PrivateJwkSet()
|
|
jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
|
|
if err != nil {
|
|
log.Fatalf("getting public jwk set: %+v", err)
|
|
}
|
|
|
|
signer, ok := jwks.Key(0)
|
|
if !ok {
|
|
log.Fatalf("getting signer")
|
|
}
|
|
|
|
idToken := jwtlib.New()
|
|
idToken.Set("jti", "id-token-jti")
|
|
|
|
signedIdToken, err := jwtlib.Sign(idToken, jwtlib.WithKey(jwa.RS256, signer))
|
|
if err != nil {
|
|
log.Fatalf("signing id_token: %+v", err)
|
|
}
|
|
|
|
parsedIdToken, err := jwtlib.Parse(signedIdToken, jwtlib.WithKeySet(*jwksPublic))
|
|
if err != nil {
|
|
log.Fatalf("parsing signed id_token: %+v", err)
|
|
}
|
|
|
|
accessToken := "some-access-token"
|
|
|
|
return &openid.Tokens{
|
|
IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken),
|
|
AccessToken: accessToken,
|
|
}
|
|
}
|