mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-07 00:46:56 +00:00
write callback test
Co-Authored-By: Trong Huu Nguyen <trong.huu.nguyen@nav.no>
This commit is contained in:
1
go.mod
1
go.mod
@@ -5,6 +5,7 @@ go 1.16
|
||||
require (
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/go-chi/chi v1.5.4
|
||||
github.com/google/uuid v1.1.2
|
||||
github.com/lestrrat-go/jwx v1.2.5
|
||||
github.com/nais/liberator v0.0.0-20210809103005-edb0141d646d
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
|
||||
|
||||
1
go.sum
1
go.sum
@@ -254,6 +254,7 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
|
||||
@@ -4,10 +4,16 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
type IDPorten struct {
|
||||
@@ -50,26 +56,132 @@ type TokenJSON struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int32 `json:"expires_in"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IDPorten) signToken(token jwt.Token) (string, error) {
|
||||
signer, ok := ip.Keys.Get(0)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("could not get signer")
|
||||
}
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwa.RS256, signer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(signedToken), nil
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("malformed payload?"))
|
||||
return
|
||||
}
|
||||
|
||||
code := r.PostForm.Get("code")
|
||||
|
||||
if len(code) == 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing code"))
|
||||
return
|
||||
}
|
||||
|
||||
auth, ok := ip.Codes[code]
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("no matching code"))
|
||||
return
|
||||
}
|
||||
|
||||
expires := int64(1200)
|
||||
|
||||
sub := uuid.New().String()
|
||||
sid := uuid.New().String()
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("sub", sub)
|
||||
accessToken.Set("iss", cfg.WellKnown.Issuer)
|
||||
accessToken.Set("acr", auth.AcrLevel)
|
||||
accessToken.Set("iat", time.Now().Unix())
|
||||
accessToken.Set("exp", time.Now().Unix()+expires)
|
||||
signedAccessToken, err := ip.signToken(accessToken)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not sign access token: " + err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
idToken := jwt.New()
|
||||
idToken.Set("sub", sub)
|
||||
idToken.Set("iss", cfg.WellKnown.Issuer)
|
||||
idToken.Set("aud", cfg.ClientID)
|
||||
idToken.Set("locale", auth.Locale)
|
||||
idToken.Set("nonce", auth.Nonce)
|
||||
idToken.Set("acr", auth.AcrLevel)
|
||||
idToken.Set("iat", time.Now().Unix())
|
||||
idToken.Set("exp", time.Now().Unix()+expires)
|
||||
idToken.Set("sid", sid)
|
||||
|
||||
signedIdToken, err := ip.signToken(idToken)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not sign access token: " + err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
ip.Sessions[sid] = cfg.ClientID
|
||||
// fixme: generate valid access token and id token; sign them with the correct key
|
||||
token := &TokenJSON{
|
||||
AccessToken: "access-token",
|
||||
TokenType: "token-type",
|
||||
RefreshToken: "refresh-token",
|
||||
IDToken: "id-token",
|
||||
ExpiresIn: 1200,
|
||||
AccessToken: string(signedAccessToken),
|
||||
TokenType: "Bearer",
|
||||
IDToken: string(signedIdToken),
|
||||
ExpiresIn: expires,
|
||||
}
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(token)
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Token(w http.ResponseWriter, r *http.Request) {
|
||||
func (ip *IDPorten) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
state := query.Get("state")
|
||||
redirect := query.Get("redirect_uri")
|
||||
acrLevel := query.Get("acr_values")
|
||||
codeChallenge := query.Get("code_challenge")
|
||||
locale := query.Get("ui_locales")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if state == "" || redirect == "" || acrLevel == "" || codeChallenge == "" || locale == "" || nonce == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required fields"))
|
||||
return
|
||||
}
|
||||
|
||||
code := uuid.New().String()
|
||||
ip.Codes[code] = AuthRequest{
|
||||
AcrLevel: acrLevel,
|
||||
CodeChallenge: codeChallenge,
|
||||
Locale: locale,
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
u, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("couldn't parse redirect uri"))
|
||||
return
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("code", code)
|
||||
v.Set("state", state)
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (ip *IDPorten) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -77,7 +189,13 @@ func (ip *IDPorten) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (ip *IDPorten) Jwks(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
publicSet, err := jwk.PublicSetOf(ip.Keys)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("could not create public set: " + err.Error()))
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(publicSet)
|
||||
}
|
||||
|
||||
func (ip *IDPorten) WellKnown(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -86,8 +204,8 @@ func (ip *IDPorten) WellKnown(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func idportenRouter(ip *IDPorten) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/authorize", ip.Authorize)
|
||||
r.Get("/token", ip.Token)
|
||||
r.Get("/authorize", ip.Authorize)
|
||||
r.Post("/token", ip.Token)
|
||||
r.Get("/endsession", ip.EndSession)
|
||||
r.Get("/jwks", ip.Jwks)
|
||||
r.Get("/.well-known/openid-configuration", ip.WellKnown)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,10 +20,12 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
)
|
||||
|
||||
const clientID = "clientid"
|
||||
|
||||
var encryptionKey = []byte(`G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`) // 256 bits AES
|
||||
|
||||
var cfg = config.IDPorten{
|
||||
ClientID: "clientid",
|
||||
ClientID: clientID,
|
||||
ClientJWK: `
|
||||
{
|
||||
"kty": "RSA",
|
||||
@@ -37,7 +41,7 @@ var cfg = config.IDPorten{
|
||||
"x5t": "9rJ_0ziKoGNjSS_l11hn0yQxEqg"
|
||||
}
|
||||
`,
|
||||
RedirectURI: "http://localhost/redirect",
|
||||
RedirectURI: "http://localhost/callback",
|
||||
WellKnownURL: "",
|
||||
WellKnown: config.IDPortenWellKnown{
|
||||
Issuer: "issuer",
|
||||
@@ -48,8 +52,13 @@ var cfg = config.IDPorten{
|
||||
PostLogoutRedirectURI: "",
|
||||
}
|
||||
|
||||
var clients = map[string]string{
|
||||
clientID: "http://localhost/oauth2/logout/frontchannel",
|
||||
}
|
||||
var idp = NewIDPorten(clients)
|
||||
|
||||
func handler() *router.Handler {
|
||||
return &router.Handler{
|
||||
handler := router.Handler{
|
||||
Config: cfg,
|
||||
OauthConfig: oauth2.Config{
|
||||
ClientID: "client-id",
|
||||
@@ -65,6 +74,8 @@ func handler() *router.Handler {
|
||||
UpstreamHost: "",
|
||||
IdTokenVerifier: nil,
|
||||
}
|
||||
handler.Init()
|
||||
return &handler
|
||||
}
|
||||
|
||||
func TestLoginURL(t *testing.T) {
|
||||
@@ -76,15 +87,21 @@ func TestLoginURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_Login(t *testing.T) {
|
||||
r := router.New(handler())
|
||||
server := httptest.NewServer(r)
|
||||
h := handler()
|
||||
r := router.New(h)
|
||||
|
||||
server := httptest.NewServer(r)
|
||||
client := server.Client()
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
req, err := client.Get(server.URL + "/oauth2/login")
|
||||
|
||||
idprouter := idportenRouter(idp)
|
||||
idpserver := httptest.NewServer(idprouter)
|
||||
|
||||
h.Config.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize"
|
||||
|
||||
req, err := client.Get(server.URL + "/oauth2/login")
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
|
||||
@@ -92,14 +109,26 @@ func TestHandler_Login(t *testing.T) {
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "localhost:1234", u.Host)
|
||||
assert.Equal(t, idpserver.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host))
|
||||
assert.Equal(t, "/authorize", u.Path)
|
||||
assert.Equal(t, cfg.SecurityLevel, u.Query().Get("acr_values"))
|
||||
assert.Equal(t, cfg.Locale, u.Query().Get("ui_locales"))
|
||||
assert.Equal(t, cfg.ClientID, u.Query().Get("client_id"))
|
||||
assert.Equal(t, cfg.RedirectURI, u.Query().Get("redirect_uri"))
|
||||
assert.NotEmpty(t, u.Query().Get("state"))
|
||||
assert.NotEmpty(t, u.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, u.Query().Get("code_challenge"))
|
||||
|
||||
req, err = client.Get(u.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
|
||||
location = req.Header.Get("location")
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, u.Query().Get("state"), callbackURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, callbackURL.Query().Get("code"))
|
||||
}
|
||||
|
||||
func TestHandler_Callback(t *testing.T) {
|
||||
@@ -107,13 +136,16 @@ func TestHandler_Callback(t *testing.T) {
|
||||
r := router.New(h)
|
||||
server := httptest.NewServer(r)
|
||||
|
||||
clients := map[string]string{
|
||||
h.Config.ClientID: server.URL + "/oauth2/logout/frontchannel",
|
||||
}
|
||||
idp := NewIDPorten(clients)
|
||||
idprouter := idportenRouter(idp)
|
||||
idpserver := httptest.NewServer(idprouter)
|
||||
h.OauthConfig.Endpoint.TokenURL = idpserver.URL + "/authorize"
|
||||
h.OauthConfig.Endpoint.TokenURL = idpserver.URL + "/token"
|
||||
h.Config.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize"
|
||||
h.Config.RedirectURI = server.URL + "/oauth2/callback"
|
||||
h.IdTokenVerifier = oidc.NewVerifier(
|
||||
cfg.WellKnown.Issuer,
|
||||
oidc.NewRemoteKeySet(context.Background(), idpserver.URL+"/jwks"),
|
||||
&oidc.Config{ClientID: cfg.ClientID},
|
||||
)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
assert.NoError(t, err)
|
||||
@@ -127,32 +159,35 @@ func TestHandler_Callback(t *testing.T) {
|
||||
// First, run /oauth2/login to set cookies
|
||||
req, err := client.Get(server.URL + "/oauth2/login")
|
||||
assert.NoError(t, err)
|
||||
req.Body.Close()
|
||||
defer req.Body.Close()
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
// Get authorization URL
|
||||
location := req.Header.Get("location")
|
||||
u, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
u.Path = "/oauth2/callback"
|
||||
v := &url.Values{}
|
||||
|
||||
mapping := map[string]string{
|
||||
router.NonceCookieName: "nonce",
|
||||
router.StateCookieName: "state",
|
||||
router.CodeVerifierCookieName: "code_verifier",
|
||||
}
|
||||
for _, cookie := range req.Cookies() {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
plaintext, err := h.Crypter.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
v.Set(mapping[cookie.Name], string(plaintext))
|
||||
}
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
// Follow redirect to authorize with idporten
|
||||
req, err = client.Get(u.String())
|
||||
assert.NoError(t, err)
|
||||
defer req.Body.Close()
|
||||
|
||||
// Get callback URL after successful auth
|
||||
location = req.Header.Get("location")
|
||||
callbackURL, err := url.Parse(location)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to callback
|
||||
req, err = client.Get(callbackURL.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
cookies := client.Jar.Cookies(callbackURL)
|
||||
var sessionCookie *http.Cookie
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == router.SessionCookieName {
|
||||
sessionCookie = cookie
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotNil(t, sessionCookie)
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user