diff --git a/pkg/auth/assertion.go b/pkg/auth/assertion.go deleted file mode 100644 index baacb34..0000000 --- a/pkg/auth/assertion.go +++ /dev/null @@ -1,48 +0,0 @@ -package auth - -import ( - "fmt" - "time" - - "github.com/google/uuid" - "github.com/lestrrat-go/jwx/jwa" - "github.com/lestrrat-go/jwx/jwk" - "github.com/lestrrat-go/jwx/jwt" - - "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/token" -) - -func ClientAssertion(cfg config.IDPorten, expiration time.Duration) (string, error) { - key, err := jwk.ParseKey([]byte(cfg.ClientJWK)) - if err != nil { - return "", fmt.Errorf("parsing client JWK: %w", err) - } - - iat := time.Now() - exp := iat.Add(expiration) - - errs := make([]error, 0) - - tok := jwt.New() - errs = append(errs, tok.Set(jwt.IssuerKey, cfg.ClientID)) - errs = append(errs, tok.Set(jwt.SubjectKey, cfg.ClientID)) - errs = append(errs, tok.Set(jwt.AudienceKey, cfg.WellKnown.Issuer)) - errs = append(errs, tok.Set("scope", token.ScopeOpenID)) - errs = append(errs, tok.Set(jwt.IssuedAtKey, iat)) - errs = append(errs, tok.Set(jwt.ExpirationKey, exp)) - errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String())) - - for _, err := range errs { - if err != nil { - return "", fmt.Errorf("setting claim for client assertion: %w", err) - } - } - - encoded, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key) - if err != nil { - return "", fmt.Errorf("signing client assertion: %w", err) - } - - return string(encoded), nil -} diff --git a/pkg/openid/client_configuration.go b/pkg/openid/client_configuration.go new file mode 100644 index 0000000..d8b2ca4 --- /dev/null +++ b/pkg/openid/client_configuration.go @@ -0,0 +1,23 @@ +package openid + +import ( + "github.com/lestrrat-go/jwx/jwk" + + "github.com/nais/wonderwall/pkg/scopes" +) + +type ClientConfiguration interface { + GetClientID() string + GetClientJWK() jwk.Key + GetPostLogoutRedirectURI() string + GetRedirectURI() string + GetScopes() scopes.Scopes + GetACRValues() OptionalConfiguration + GetUILocales() OptionalConfiguration + GetWellKnownURL() string +} + +type OptionalConfiguration struct { + Enabled bool `json:"enabled"` + Value string `json:"value"` +} diff --git a/pkg/config/wellknown.go b/pkg/openid/configuration.go similarity index 74% rename from pkg/config/wellknown.go rename to pkg/openid/configuration.go index c3f48d6..a1bc17e 100644 --- a/pkg/config/wellknown.go +++ b/pkg/openid/configuration.go @@ -1,11 +1,15 @@ -package config +package openid import ( + "context" "encoding/json" + "fmt" "net/http" + + "github.com/lestrrat-go/jwx/jwk" ) -type IDPortenWellKnown struct { +type Configuration struct { Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint"` PushedAuthorizationRequestEndpoint string `json:"pushed_authorization_request_endpoint"` @@ -42,15 +46,28 @@ func (in Supported) Contains(value string) bool { return false } -func (c *Config) FetchWellKnownConfig() error { - response, err := http.Get(c.IDPorten.WellKnownURL) +func FetchWellKnownConfig(wellKnownURI string) (*Configuration, error) { + response, err := http.Get(wellKnownURI) if err != nil { - return err + return nil, fmt.Errorf("fetching well known configuration: %w", err) } - // can this play with viper in any way? - if err := json.NewDecoder(response.Body).Decode(&c.IDPorten.WellKnown); err != nil { - return err + var cfg Configuration + if err := json.NewDecoder(response.Body).Decode(&cfg); err != nil { + return nil, fmt.Errorf("decoding well known configuration: %w", err) } - return nil + + return &cfg, nil +} + +func (c *Configuration) FetchJwkSet(ctx context.Context) (*jwk.Set, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + jwkSet, err := jwk.Fetch(ctx, c.JwksURI) + if err != nil { + return nil, fmt.Errorf("fetching jwks: %w", err) + } + + return &jwkSet, nil } diff --git a/pkg/auth/login.go b/pkg/openid/login.go similarity index 98% rename from pkg/auth/login.go rename to pkg/openid/login.go index d785a40..6e819f5 100644 --- a/pkg/auth/login.go +++ b/pkg/openid/login.go @@ -1,4 +1,4 @@ -package auth +package openid import ( "crypto/rand" diff --git a/pkg/request/request.go b/pkg/request/request.go index 8779122..fe185a5 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -9,6 +9,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/openid" ) var ( @@ -40,7 +41,7 @@ func CanonicalRedirectURL(r *http.Request) string { // LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found. // The value must exist in the supplied list of supported values. -func LoginURLParameter(r *http.Request, parameter, fallback string, supported config.Supported) (string, error) { +func LoginURLParameter(r *http.Request, parameter, fallback string, supported openid.Supported) (string, error) { value := r.URL.Query().Get(parameter) if len(value) == 0 { diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go index 5427cda..a5d814a 100644 --- a/pkg/request/request_test.go +++ b/pkg/request/request_test.go @@ -1,13 +1,15 @@ package request_test import ( - "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/request" - "github.com/stretchr/testify/assert" "net/http" "net/url" "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/request" ) func TestCanonicalRedirectURL(t *testing.T) { @@ -33,7 +35,7 @@ func TestLoginURLParameter(t *testing.T) { name string parameter string fallback string - supported config.Supported + supported openid.Supported url string expectErr error expected string @@ -67,7 +69,7 @@ func TestLoginURLParameter(t *testing.T) { { name: "no supported values should return error", url: "http://localhost:8080/oauth2/login", - supported: config.Supported{""}, + supported: openid.Supported{""}, expectErr: request.InvalidLoginParameterError, }, } { @@ -78,7 +80,7 @@ func TestLoginURLParameter(t *testing.T) { // default test values parameter := "param" fallback := "valid" - supported := config.Supported{"valid", "valid2"} + supported := openid.Supported{"valid", "valid2"} if len(test.parameter) > 0 { parameter = test.parameter diff --git a/pkg/router/handler_login.go b/pkg/router/handler_login.go index 62645e0..8fb2faa 100644 --- a/pkg/router/handler_login.go +++ b/pkg/router/handler_login.go @@ -3,15 +3,15 @@ package router import ( "errors" "fmt" - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/request" "net/http" - "github.com/nais/wonderwall/pkg/auth" + "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/request" ) func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { - params, err := auth.GenerateLoginParameters() + params, err := openid.GenerateLoginParameters() if err != nil { h.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err)) return diff --git a/pkg/router/login_url.go b/pkg/router/login_url.go index 761574a..d75d931 100644 --- a/pkg/router/login_url.go +++ b/pkg/router/login_url.go @@ -3,12 +3,11 @@ package router import ( "errors" "fmt" - "github.com/nais/wonderwall/pkg/request" "net/http" "net/url" - "github.com/nais/wonderwall/pkg/auth" - "github.com/nais/wonderwall/pkg/token" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/request" ) var ( @@ -16,8 +15,8 @@ var ( InvalidLocaleError = errors.New("InvalidLocale") ) -func (h *Handler) LoginURL(r *http.Request, params *auth.Parameters) (string, error) { - u, err := url.Parse(h.Config.IDPorten.WellKnown.AuthorizationEndpoint) +func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string, error) { + u, err := url.Parse(h.Provider.GetOpenIDConfiguration().AuthorizationEndpoint) if err != nil { return "", err } diff --git a/pkg/router/login_url_test.go b/pkg/router/login_url_test.go index 5568b7c..ea3690a 100644 --- a/pkg/router/login_url_test.go +++ b/pkg/router/login_url_test.go @@ -2,11 +2,14 @@ package router_test import ( "errors" - "github.com/nais/wonderwall/pkg/auth" - "github.com/nais/wonderwall/pkg/router" - "github.com/stretchr/testify/assert" "net/http" "testing" + + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router" ) func TestLoginURL(t *testing.T) { @@ -48,7 +51,7 @@ func TestLoginURL(t *testing.T) { req, err := http.NewRequest("GET", test.url, nil) assert.NoError(t, err) - params, err := auth.GenerateLoginParameters() + params, err := openid.GenerateLoginParameters() assert.NoError(t, err) handler := handler(cfg) diff --git a/pkg/router/session.go b/pkg/router/session.go index 32fdf61..f20dc36 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -45,12 +45,12 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return nil, fmt.Errorf("session not found in store: %w", err) } - log.Warnf("get session: store is unavailable; using cookie fallback: %+v", err) fallbackSessionData, err := h.GetSessionFallback(r) if err != nil { return nil, fmt.Errorf("fallback session not found: %w", err) } + log.Warnf("get session: store is unavailable: %+v; using cookie fallback", err) return fallbackSessionData, nil } @@ -97,11 +97,12 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external return nil } - log.Warnf("create session: store is unavailable; using cookie fallback: %+v", err) err = h.SetSessionFallback(w, sessionData, sessionLifetime) if err != nil { return fmt.Errorf("writing session to fallback store: %w", err) } + + log.Warnf("create session: store is unavailable: %+v; using cookie fallback", err) return nil }