mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-19 14:56:52 +00:00
refactor: clean up tests
Co-Authored-By: Youssef Bel Mekki <youssef.bel.mekki@nav.no>
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -19,7 +17,6 @@ type identityProviderHandler struct {
|
||||
Codes map[string]authorizeRequest
|
||||
Provider TestProvider
|
||||
Sessions map[string]string
|
||||
SessionStates map[string]string
|
||||
}
|
||||
|
||||
func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler {
|
||||
@@ -27,7 +24,6 @@ func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler
|
||||
Codes: make(map[string]authorizeRequest),
|
||||
Provider: provider,
|
||||
Sessions: make(map[string]string),
|
||||
SessionStates: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +40,6 @@ type tokenResponse struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
SessionState string `json:"session_state"`
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
@@ -95,7 +90,8 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
|
||||
v.Set("code", code)
|
||||
v.Set("state", state)
|
||||
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
|
||||
v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host)))
|
||||
sessionID := uuid.New().String()
|
||||
v.Set("session_state", sessionID)
|
||||
}
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
@@ -172,9 +168,6 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
v := url.Values{}
|
||||
v.Set("error", "Unauthenticated")
|
||||
v.Set("error_description", "invalid client assertion")
|
||||
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
|
||||
v.Set("session_state", ip.SessionStates[clientID])
|
||||
}
|
||||
v.Encode()
|
||||
w.Write([]byte(fmt.Sprintf(v.Encode()+"%+v", err)))
|
||||
return
|
||||
@@ -193,7 +186,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
sid := uuid.New().String()
|
||||
sessionID := uuid.New().String()
|
||||
ip.Sessions[sessionID] = clientID
|
||||
|
||||
idToken := jwt.New()
|
||||
idToken.Set("sub", sub)
|
||||
@@ -206,9 +200,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
idToken.Set("exp", time.Now().Unix()+expires)
|
||||
|
||||
// If the sid claim should be in token and in active session
|
||||
if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
|
||||
idToken.Set("sid", sid)
|
||||
ip.Sessions[sid] = clientID
|
||||
if ip.Provider.OpenIDConfiguration.SidClaimRequired() {
|
||||
idToken.Set("sid", sessionID)
|
||||
}
|
||||
|
||||
signedIdToken, err := ip.signToken(idToken)
|
||||
@@ -225,36 +218,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
|
||||
ExpiresIn: expires,
|
||||
}
|
||||
|
||||
if ip.Provider.OpenIDConfiguration.SessionStateRequired() {
|
||||
sessionState := ip.SessionStates[clientID]
|
||||
token.SessionState = sessionState
|
||||
ip.Sessions[sessionState] = clientID
|
||||
}
|
||||
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) generateSessionState(state, originUrl string) string {
|
||||
// Here, the session_state is calculated in this particular way,
|
||||
// but it is entirely up to the OP how to do it under the
|
||||
// requirements defined in this specification.
|
||||
clientId := ip.Provider.ClientConfiguration.GetClientID()
|
||||
salt := "some-salt"
|
||||
saltedString := fmt.Sprintf("%s %s %s %s", clientId, state, originUrl, salt)
|
||||
session := NewSHA256([]byte(saltedString))
|
||||
sessionState := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s.%s", session, NewSHA256([]byte(salt)))))
|
||||
ip.SessionStates[clientId] = sessionState
|
||||
return sessionState
|
||||
|
||||
}
|
||||
|
||||
func NewSHA256(data []byte) []byte {
|
||||
hash := sha256.Sum256(data)
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func (ip *identityProviderHandler) GetClientID(sessionID string) string {
|
||||
return ip.Sessions[sessionID]
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *identityProviderHandler) {
|
||||
func IdentityProviderServer() (*httptest.Server, TestProvider) {
|
||||
provider := NewTestProvider()
|
||||
handler := newIdentityProviderHandler(provider)
|
||||
router := identityProviderRouter(handler)
|
||||
@@ -18,14 +18,7 @@ func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *ident
|
||||
provider.OpenIDConfiguration.TokenEndpoint = server.URL + "/token"
|
||||
provider.OpenIDConfiguration.EndSessionEndpoint = server.URL + "/endsession"
|
||||
|
||||
if iframe {
|
||||
provider.OpenIDConfiguration.CheckSessionIframe = server.URL + "/checksession"
|
||||
} else {
|
||||
provider.OpenIDConfiguration.FrontchannelLogoutSupported = true
|
||||
provider.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true
|
||||
}
|
||||
|
||||
return server, provider, handler
|
||||
return server, provider
|
||||
}
|
||||
|
||||
func identityProviderRouter(ip *identityProviderHandler) chi.Router {
|
||||
|
||||
@@ -31,6 +31,17 @@ func (p TestProvider) PrivateJwkSet() *jwk.Set {
|
||||
return &p.JwksPair.Private
|
||||
}
|
||||
|
||||
func (p TestProvider) WithFrontChannelLogoutSupport() TestProvider {
|
||||
p.OpenIDConfiguration.FrontchannelLogoutSupported = true
|
||||
p.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true
|
||||
return p
|
||||
}
|
||||
|
||||
func (p TestProvider) WithCheckSessionIFrameSupport(url string) TestProvider {
|
||||
p.OpenIDConfiguration.CheckSessionIframe = url
|
||||
return p
|
||||
}
|
||||
|
||||
func NewTestProvider() TestProvider {
|
||||
jwksPair, err := crypto.NewJwkSet()
|
||||
if err != nil {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -45,7 +44,7 @@ func newHandler(provider openid.Provider) *router.Handler {
|
||||
}
|
||||
|
||||
func TestHandler_Login(t *testing.T) {
|
||||
idpserver, idp, _ := mock.IdentityProviderServer(false)
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
h := newHandler(idp)
|
||||
r := router.New(h)
|
||||
|
||||
@@ -101,7 +100,7 @@ func TestHandler_Login(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
idpserver, idp, _ := mock.IdentityProviderServer(false)
|
||||
idpserver, idp := mock.IdentityProviderServer()
|
||||
|
||||
h := newHandler(idp)
|
||||
r := router.New(h)
|
||||
@@ -196,7 +195,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
_, idp, idpHandler := mock.IdentityProviderServer(false)
|
||||
_, idp := mock.IdentityProviderServer()
|
||||
idp.WithFrontChannelLogoutSupport()
|
||||
|
||||
h := newHandler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
@@ -252,9 +253,6 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
sid, err := h.Crypter.Decrypt(ciphertext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientID := idpHandler.GetClientID(parseSessionID(sid))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID)
|
||||
|
||||
frontchannelLogoutURL, err := url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -271,8 +269,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestHandler_CheckSessionIframe(t *testing.T) {
|
||||
_, idp, idpHandler := mock.IdentityProviderServer(true)
|
||||
func TestHandler_SessionStateRequired(t *testing.T) {
|
||||
idpServer, idp := mock.IdentityProviderServer()
|
||||
idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession")
|
||||
h := newHandler(idp)
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
@@ -311,24 +310,9 @@ func TestHandler_CheckSessionIframe(t *testing.T) {
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to callback
|
||||
resp, err = client.Get(callbackURL.String())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies := client.Jar.Cookies(callbackURL)
|
||||
sessionCookie := getCookieFromJar(router.SessionCookieName, cookies)
|
||||
|
||||
assert.NotNil(t, sessionCookie)
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(sessionCookie.Value)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sessionState, err := h.Crypter.Decrypt(ciphertext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientID := idpHandler.GetClientID(parseSessionID(sessionState))
|
||||
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID)
|
||||
params := callbackURL.Query()
|
||||
sessionState := params.Get("session_state")
|
||||
assert.NotEmpty(t, sessionState)
|
||||
}
|
||||
|
||||
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
|
||||
@@ -340,7 +324,3 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSessionID(sessionID []byte) string {
|
||||
return strings.Split(string(sessionID), ":")[2]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user