mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
Access tokens are not necessarily JWTs, and we don't need to validate them since we only pass them on as opaque strings. This also means that we no longer log the JTI of access tokens. We also simplify the handling of OIDC callbacks.
197 lines
5.5 KiB
Go
197 lines
5.5 KiB
Go
package router_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/nais/wonderwall/pkg/mock"
|
|
"github.com/nais/wonderwall/pkg/openid"
|
|
"github.com/nais/wonderwall/pkg/router"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
)
|
|
|
|
func TestHandler_GetSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
tokens := makeTokens(idp.Provider)
|
|
rpHandler := idp.RelyingPartyHandler
|
|
|
|
t.Run("request without fallback session cookies", func(t *testing.T) {
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
_, err := rpHandler.GetSessionFallback(w, r)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("request with fallback session cookies", func(t *testing.T) {
|
|
r := makeRequestWithFallbackCookies(t, rpHandler, tokens)
|
|
w := httptest.NewRecorder()
|
|
sessionData, err := rpHandler.GetSessionFallback(w, r)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "sid", sessionData.ExternalSessionID)
|
|
assert.Equal(t, tokens.AccessToken, sessionData.AccessToken)
|
|
assert.Equal(t, tokens.IDToken.GetSerialized(), sessionData.IDToken)
|
|
assert.Equal(t, "id-token-jti", sessionData.IDTokenJwtID)
|
|
assert.Empty(t, sessionData.RefreshToken)
|
|
})
|
|
}
|
|
|
|
func TestHandler_SetSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
tokens := makeTokens(idp.Provider)
|
|
rpHandler := idp.RelyingPartyHandler
|
|
|
|
// request should set session cookies in response
|
|
writer := httptest.NewRecorder()
|
|
expiresIn := time.Minute
|
|
data := session.NewData("sid", tokens, nil)
|
|
err := rpHandler.SetSessionFallback(writer, nil, data, expiresIn)
|
|
assert.NoError(t, err)
|
|
|
|
cookies := writer.Result().Cookies()
|
|
|
|
for _, test := range []struct {
|
|
cookieName string
|
|
want string
|
|
}{
|
|
{
|
|
cookieName: "wonderwall-1",
|
|
want: "sid",
|
|
},
|
|
{
|
|
cookieName: "wonderwall-2",
|
|
want: tokens.IDToken.GetSerialized(),
|
|
},
|
|
{
|
|
cookieName: "wonderwall-3",
|
|
want: tokens.AccessToken,
|
|
},
|
|
} {
|
|
assertCookieExists(t, rpHandler, test.cookieName, test.want, cookies)
|
|
}
|
|
}
|
|
|
|
func TestHandler_DeleteSessionFallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpHandler := idp.RelyingPartyHandler
|
|
tokens := makeTokens(idp.Provider)
|
|
|
|
t.Run("expire cookies if they are set", func(t *testing.T) {
|
|
r := makeRequestWithFallbackCookies(t, rpHandler, tokens)
|
|
writer := httptest.NewRecorder()
|
|
rpHandler.DeleteSessionFallback(writer, r)
|
|
cookies := writer.Result().Cookies()
|
|
|
|
assert.NotEmpty(t, cookies)
|
|
assert.Len(t, cookies, 3)
|
|
|
|
assertCookieExpired(t, "wonderwall-1", cookies)
|
|
assertCookieExpired(t, "wonderwall-2", cookies)
|
|
assertCookieExpired(t, "wonderwall-3", cookies)
|
|
})
|
|
|
|
t.Run("skip expiring cookies if they are not set", func(t *testing.T) {
|
|
writer := httptest.NewRecorder()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rpHandler.DeleteSessionFallback(writer, r)
|
|
cookies := writer.Result().Cookies()
|
|
|
|
assert.Empty(t, cookies)
|
|
})
|
|
}
|
|
|
|
func makeRequestWithFallbackCookies(t *testing.T, h *router.Handler, tokens *openid.Tokens) *http.Request {
|
|
writer := httptest.NewRecorder()
|
|
expiresIn := time.Minute
|
|
data := session.NewData("sid", tokens, nil)
|
|
err := h.SetSessionFallback(writer, nil, data, expiresIn)
|
|
assert.NoError(t, err)
|
|
|
|
cookies := writer.Result().Cookies()
|
|
|
|
externalSessionIDCookie := getCookieFromJar("wonderwall-1", cookies)
|
|
assert.NotNil(t, externalSessionIDCookie)
|
|
idTokenCookie := getCookieFromJar("wonderwall-2", cookies)
|
|
assert.NotNil(t, idTokenCookie)
|
|
accessTokenCookie := getCookieFromJar("wonderwall-3", cookies)
|
|
assert.NotNil(t, accessTokenCookie)
|
|
|
|
// make request with fallback session cookies set
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.AddCookie(externalSessionIDCookie)
|
|
r.AddCookie(idTokenCookie)
|
|
r.AddCookie(accessTokenCookie)
|
|
|
|
return r
|
|
}
|
|
|
|
func assertCookieExpired(t *testing.T, cookieName string, cookies []*http.Cookie) {
|
|
expired := getCookieFromJar(cookieName, cookies)
|
|
assert.NotNil(t, expired)
|
|
assert.Less(t, expired.MaxAge, 0)
|
|
assert.True(t, expired.Expires.Before(time.Now()))
|
|
assert.Empty(t, expired.Value)
|
|
}
|
|
|
|
func assertCookieExists(t *testing.T, h *router.Handler, cookieName, expectedValue string, cookies []*http.Cookie) {
|
|
desiredCookie := getCookieFromJar(cookieName, cookies)
|
|
assert.NotNil(t, desiredCookie)
|
|
|
|
ciphertext, err := base64.StdEncoding.DecodeString(desiredCookie.Value)
|
|
assert.NoError(t, err)
|
|
|
|
plainbytes, err := h.Crypter.Decrypt(ciphertext)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedValue, string(plainbytes))
|
|
}
|
|
|
|
func makeTokens(provider mock.TestProvider) *openid.Tokens {
|
|
jwks := *provider.PrivateJwkSet()
|
|
jwksPublic, err := provider.GetPublicJwkSet(context.TODO())
|
|
if err != nil {
|
|
log.Fatalf("getting public jwk set: %+v", err)
|
|
}
|
|
|
|
signer, ok := jwks.Key(0)
|
|
if !ok {
|
|
log.Fatalf("getting signer")
|
|
}
|
|
|
|
idToken := jwtlib.New()
|
|
idToken.Set("jti", "id-token-jti")
|
|
|
|
signedIdToken, err := jwtlib.Sign(idToken, jwtlib.WithKey(jwa.RS256, signer))
|
|
if err != nil {
|
|
log.Fatalf("signing id_token: %+v", err)
|
|
}
|
|
|
|
parsedIdToken, err := jwtlib.Parse(signedIdToken, jwtlib.WithKeySet(*jwksPublic))
|
|
if err != nil {
|
|
log.Fatalf("parsing signed id_token: %+v", err)
|
|
}
|
|
|
|
accessToken := "some-access-token"
|
|
|
|
return &openid.Tokens{
|
|
IDToken: openid.NewIDToken(string(signedIdToken), parsedIdToken),
|
|
AccessToken: accessToken,
|
|
}
|
|
}
|