Files
wonderwall/pkg/handler/session_fallback_test.go
2022-07-19 19:24:28 +02:00

197 lines
5.5 KiB
Go

package handler_test
import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/session"
)
// TestHandler_GetSessionFallback exercises GetSessionFallback on the mock
// identity provider's relying party handler: once for a request that carries
// no cookies, and once for a request that carries the cookies produced by
// makeRequestWithFallbackCookies.
func TestHandler_GetSessionFallback(t *testing.T) {
	cfg := mock.Config()
	idp := mock.NewIdentityProvider(cfg)
	defer idp.Close()

	tokens := makeTokens(idp.Provider)
	rpHandler := idp.RelyingPartyHandler

	t.Run("request without fallback session cookies", func(t *testing.T) {
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		w := httptest.NewRecorder()

		_, err := rpHandler.GetSessionFallback(w, r)
		assert.Error(t, err)
	})

	t.Run("request with fallback session cookies", func(t *testing.T) {
		r := makeRequestWithFallbackCookies(t, rpHandler, tokens)
		w := httptest.NewRecorder()

		sessionData, err := rpHandler.GetSessionFallback(w, r)
		// Stop on failure: the checks below read fields of sessionData, which
		// would panic rather than fail cleanly if the handler returned an
		// error together with a nil result.
		if !assert.NoError(t, err) || !assert.NotNil(t, sessionData) {
			return
		}

		assert.Equal(t, "sid", sessionData.ExternalSessionID)
		assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
		assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
		assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
		assert.Empty(t, sessionData.RefreshToken)
	})
}
// TestHandler_SetSessionFallback calls SetSessionFallback with freshly built
// session data and checks, via assertCookieExists, that the recorded response
// carries the three "wonderwall-N" cookies with the expected values.
func TestHandler_SetSessionFallback(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	tokens := makeTokens(idp.Provider)
	rpHandler := idp.RelyingPartyHandler

	// Setting the fallback session should put the session cookies on the response.
	recorder := httptest.NewRecorder()
	sessionData := session.NewData("sid", tokens, nil)
	assert.NoError(t, rpHandler.SetSessionFallback(recorder, nil, sessionData, time.Minute))

	responseCookies := recorder.Result().Cookies()

	expectations := []struct {
		name  string
		value string
	}{
		{name: "wonderwall-1", value: "sid"},
		{name: "wonderwall-2", value: tokens.IDToken.GetSerialized()},
		{name: "wonderwall-3", value: tokens.AccessToken},
	}

	for i := range expectations {
		assertCookieExists(t, rpHandler, expectations[i].name, expectations[i].value, responseCookies)
	}
}
// TestHandler_DeleteSessionFallback checks DeleteSessionFallback in two
// situations: when the request carries the fallback cookies (the response must
// then contain three expired cookies) and when it carries none (the response
// must then contain no cookies).
func TestHandler_DeleteSessionFallback(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	rpHandler := idp.RelyingPartyHandler
	tokens := makeTokens(idp.Provider)

	t.Run("expire cookies if they are set", func(t *testing.T) {
		req := makeRequestWithFallbackCookies(t, rpHandler, tokens)
		recorder := httptest.NewRecorder()

		rpHandler.DeleteSessionFallback(recorder, req)

		got := recorder.Result().Cookies()
		assert.NotEmpty(t, got)
		assert.Len(t, got, 3)

		for _, name := range []string{"wonderwall-1", "wonderwall-2", "wonderwall-3"} {
			assertCookieExpired(t, name, got)
		}
	})

	t.Run("skip expiring cookies if they are not set", func(t *testing.T) {
		recorder := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)

		rpHandler.DeleteSessionFallback(recorder, req)

		assert.Empty(t, recorder.Result().Cookies())
	})
}
// makeRequestWithFallbackCookies returns a GET "/" request carrying the
// "wonderwall-1", "wonderwall-2" and "wonderwall-3" cookies that
// h.SetSessionFallback writes for a session built from tokens with external
// session ID "sid".
//
// The calling test is aborted if the cookies cannot be produced; continuing
// would hand a nil cookie to (*http.Request).AddCookie, which panics.
func makeRequestWithFallbackCookies(t *testing.T, h *handler.Handler, tokens *openid.Tokens) *http.Request {
	t.Helper()

	writer := httptest.NewRecorder()
	expiresIn := time.Minute
	data := session.NewData("sid", tokens, nil)

	err := h.SetSessionFallback(writer, nil, data, expiresIn)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	cookies := writer.Result().Cookies()

	// make request with fallback session cookies set
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	for _, name := range []string{"wonderwall-1", "wonderwall-2", "wonderwall-3"} {
		cookie := getCookieFromJar(name, cookies)
		if !assert.NotNil(t, cookie, "fallback cookie %q was not set", name) {
			t.FailNow()
		}
		r.AddCookie(cookie)
	}

	return r
}
// assertCookieExpired asserts that the cookie named cookieName is present in
// cookies and is marked as expired: negative MaxAge, an Expires timestamp in
// the past, and an empty value.
func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie) {
	t.Helper()

	expired := getCookieFromJar(cookieName, cookies)
	// Bail out when the cookie is missing; the field checks below would
	// otherwise dereference a nil cookie and panic instead of failing cleanly.
	if !assert.NotNil(t, expired, "cookie %q not found", cookieName) {
		return
	}

	assert.Less(t, expired.MaxAge, 0)
	assert.True(t, expired.Expires.Before(time.Now()))
	assert.Empty(t, expired.Value)
}
// assertCookieExists asserts that the cookie named cookieName is present in
// cookies and that its value, after standard base64 decoding and decryption
// with h.Crypter, equals expectedValue.
//
// Each step returns early on failure so that a missing cookie or a bad payload
// is reported once instead of causing a nil dereference or follow-on errors.
func assertCookieExists(t *testing.T, h *handler.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
	t.Helper()

	desiredCookie := getCookieFromJar(cookieName, cookies)
	if !assert.NotNil(t, desiredCookie, "cookie %q not found", cookieName) {
		return
	}

	ciphertext, err := base64.StdEncoding.DecodeString(desiredCookie.Value)
	if !assert.NoError(t, err) {
		return
	}

	plainbytes, err := h.Crypter.Decrypt(ciphertext)
	if !assert.NoError(t, err) {
		return
	}

	assert.Equal(t, expectedValue, string(plainbytes))
}
// makeTokens builds the openid.Tokens fixture used by the tests in this file.
//
// The ID token carries the claim jti="id-token-jti", is signed with RS256 using
// the first key of the provider's private JWK set, and is then parsed back
// against the provider's public JWK set. The access token is the fixed string
// "some-access-token".
//
// Any failure along the way is reported through log.Fatalf.
func makeTokens(provider *mock.TestProvider) *openid.Tokens {
	jwks := *provider.PrivateJwkSet()
	jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
	if err != nil {
		log.Fatalf("getting public jwk set: %+v", err)
	}

	signer, ok := jwks.Key(0)
	if !ok {
		log.Fatalf("getting signer")
	}

	idToken := jwtlib.New()
	// Set returns an error; check it like every other step so a token without
	// the expected jti is never signed silently.
	if err := idToken.Set("jti", "id-token-jti"); err != nil {
		log.Fatalf("setting jti claim: %+v", err)
	}

	signedIDToken, err := jwtlib.Sign(idToken, jwtlib.WithKey(jwa.RS256, signer))
	if err != nil {
		log.Fatalf("signing id_token: %+v", err)
	}

	parsedIDToken, err := jwtlib.Parse(signedIDToken, jwtlib.WithKeySet(*jwksPublic))
	if err != nil {
		log.Fatalf("parsing signed id_token: %+v", err)
	}

	accessToken := "some-access-token"

	return &openid.Tokens{
		IDToken:     openid.NewIDToken(string(signedIDToken), parsedIDToken),
		AccessToken: accessToken,
	}
}