refactor: clean up tests

Co-Authored-By: Youssef Bel Mekki <youssef.bel.mekki@nav.no>
This commit is contained in:
Trong Huu Nguyen
2022-01-25 15:58:16 +01:00
parent 24cae11ba2
commit b40dbffa19
4 changed files with 30 additions and 81 deletions

View File

@@ -1,8 +1,6 @@
package mock
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
@@ -19,7 +17,6 @@ type identityProviderHandler struct {
Codes map[string]authorizeRequest
Provider TestProvider
Sessions map[string]string
SessionStates map[string]string
}
func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler {
@@ -27,7 +24,6 @@ func newIdentityProviderHandler(provider TestProvider) *identityProviderHandler
Codes: make(map[string]authorizeRequest),
Provider: provider,
Sessions: make(map[string]string),
SessionStates: make(map[string]string),
}
}
@@ -44,7 +40,6 @@ type tokenResponse struct {
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
SessionState string `json:"session_state"`
}
func (ip *identityProviderHandler) signToken(token jwt.Token) (string, error) {
@@ -95,7 +90,8 @@ func (ip *identityProviderHandler) Authorize(w http.ResponseWriter, r *http.Requ
v.Set("code", code)
v.Set("state", state)
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
v.Set("session_state", ip.generateSessionState(state, fmt.Sprintf("%s://%s", u.Scheme, u.Host)))
sessionID := uuid.New().String()
v.Set("session_state", sessionID)
}
u.RawQuery = v.Encode()
@@ -172,9 +168,6 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
v := url.Values{}
v.Set("error", "Unauthenticated")
v.Set("error_description", "invalid client assertion")
if ip.Provider.GetOpenIDConfiguration().SessionStateRequired() {
v.Set("session_state", ip.SessionStates[clientID])
}
v.Encode()
w.Write([]byte(fmt.Sprintf(v.Encode()+"%+v", err)))
return
@@ -193,7 +186,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
return
}
sid := uuid.New().String()
sessionID := uuid.New().String()
ip.Sessions[sessionID] = clientID
idToken := jwt.New()
idToken.Set("sub", sub)
@@ -206,9 +200,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
idToken.Set("exp", time.Now().Unix()+expires)
// If the sid claim should be in token and in active session
if !ip.Provider.OpenIDConfiguration.SessionStateRequired() || !ip.Provider.OpenIDConfiguration.SidClaimRequired() {
idToken.Set("sid", sid)
ip.Sessions[sid] = clientID
if ip.Provider.OpenIDConfiguration.SidClaimRequired() {
idToken.Set("sid", sessionID)
}
signedIdToken, err := ip.signToken(idToken)
@@ -225,36 +218,8 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request)
ExpiresIn: expires,
}
if ip.Provider.OpenIDConfiguration.SessionStateRequired() {
sessionState := ip.SessionStates[clientID]
token.SessionState = sessionState
ip.Sessions[sessionState] = clientID
}
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(token)
}
func (ip *identityProviderHandler) generateSessionState(state, originUrl string) string {
// Here, the session_state is calculated in this particular way,
// but it is entirely up to the OP how to do it under the
// requirements defined in this specification.
clientId := ip.Provider.ClientConfiguration.GetClientID()
salt := "some-salt"
saltedString := fmt.Sprintf("%s %s %s %s", clientId, state, originUrl, salt)
session := NewSHA256([]byte(saltedString))
sessionState := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s.%s", session, NewSHA256([]byte(salt)))))
ip.SessionStates[clientId] = sessionState
return sessionState
}
func NewSHA256(data []byte) []byte {
hash := sha256.Sum256(data)
return hash[:]
}
func (ip *identityProviderHandler) GetClientID(sessionID string) string {
return ip.Sessions[sessionID]
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/go-chi/chi/v5"
)
func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *identityProviderHandler) {
func IdentityProviderServer() (*httptest.Server, TestProvider) {
provider := NewTestProvider()
handler := newIdentityProviderHandler(provider)
router := identityProviderRouter(handler)
@@ -18,14 +18,7 @@ func IdentityProviderServer(iframe bool) (*httptest.Server, TestProvider, *ident
provider.OpenIDConfiguration.TokenEndpoint = server.URL + "/token"
provider.OpenIDConfiguration.EndSessionEndpoint = server.URL + "/endsession"
if iframe {
provider.OpenIDConfiguration.CheckSessionIframe = server.URL + "/checksession"
} else {
provider.OpenIDConfiguration.FrontchannelLogoutSupported = true
provider.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true
}
return server, provider, handler
return server, provider
}
func identityProviderRouter(ip *identityProviderHandler) chi.Router {

View File

@@ -31,6 +31,17 @@ func (p TestProvider) PrivateJwkSet() *jwk.Set {
return &p.JwksPair.Private
}
func (p TestProvider) WithFrontChannelLogoutSupport() TestProvider {
p.OpenIDConfiguration.FrontchannelLogoutSupported = true
p.OpenIDConfiguration.FrontchannelLogoutSessionSupported = true
return p
}
func (p TestProvider) WithCheckSessionIFrameSupport(url string) TestProvider {
p.OpenIDConfiguration.CheckSessionIframe = url
return p
}
func NewTestProvider() TestProvider {
jwksPair, err := crypto.NewJwkSet()
if err != nil {

View File

@@ -7,7 +7,6 @@ import (
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -45,7 +44,7 @@ func newHandler(provider openid.Provider) *router.Handler {
}
func TestHandler_Login(t *testing.T) {
idpserver, idp, _ := mock.IdentityProviderServer(false)
idpserver, idp := mock.IdentityProviderServer()
h := newHandler(idp)
r := router.New(h)
@@ -101,7 +100,7 @@ func TestHandler_Login(t *testing.T) {
}
func TestHandler_Callback_and_Logout(t *testing.T) {
idpserver, idp, _ := mock.IdentityProviderServer(false)
idpserver, idp := mock.IdentityProviderServer()
h := newHandler(idp)
r := router.New(h)
@@ -196,7 +195,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
}
func TestHandler_FrontChannelLogout(t *testing.T) {
_, idp, idpHandler := mock.IdentityProviderServer(false)
_, idp := mock.IdentityProviderServer()
idp.WithFrontChannelLogoutSupport()
h := newHandler(idp)
r := router.New(h)
server := httptest.NewServer(r)
@@ -252,9 +253,6 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
sid, err := h.Crypter.Decrypt(ciphertext)
assert.NoError(t, err)
clientID := idpHandler.GetClientID(parseSessionID(sid))
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID)
frontchannelLogoutURL, err := url.Parse(server.URL)
assert.NoError(t, err)
@@ -271,8 +269,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
defer resp.Body.Close()
}
func TestHandler_CheckSessionIframe(t *testing.T) {
_, idp, idpHandler := mock.IdentityProviderServer(true)
func TestHandler_SessionStateRequired(t *testing.T) {
idpServer, idp := mock.IdentityProviderServer()
idp.WithCheckSessionIFrameSupport(idpServer.URL + "/checksession")
h := newHandler(idp)
r := router.New(h)
server := httptest.NewServer(r)
@@ -311,24 +310,9 @@ func TestHandler_CheckSessionIframe(t *testing.T) {
callbackURL, err := url.Parse(location)
assert.NoError(t, err)
// Follow redirect to callback
resp, err = client.Get(callbackURL.String())
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := client.Jar.Cookies(callbackURL)
sessionCookie := getCookieFromJar(router.SessionCookieName, cookies)
assert.NotNil(t, sessionCookie)
ciphertext, err := base64.StdEncoding.DecodeString(sessionCookie.Value)
assert.NoError(t, err)
sessionState, err := h.Crypter.Decrypt(ciphertext)
assert.NoError(t, err)
clientID := idpHandler.GetClientID(parseSessionID(sessionState))
assert.Equal(t, idp.GetClientConfiguration().GetClientID(), clientID)
params := callbackURL.Query()
sessionState := params.Get("session_state")
assert.NotEmpty(t, sessionState)
}
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
@@ -340,7 +324,3 @@ func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
return nil
}
func parseSessionID(sessionID []byte) string {
return strings.Split(string(sessionID), ":")[2]
}