refactor: move openid related structs to own pkg

This commit is contained in:
Trong Huu Nguyen
2021-10-16 10:38:32 +02:00
parent e7d5a6073c
commit 2f0243b69a
10 changed files with 79 additions and 81 deletions

View File

@@ -1,48 +0,0 @@
package auth
import (
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/token"
)
func ClientAssertion(cfg config.IDPorten, expiration time.Duration) (string, error) {
key, err := jwk.ParseKey([]byte(cfg.ClientJWK))
if err != nil {
return "", fmt.Errorf("parsing client JWK: %w", err)
}
iat := time.Now()
exp := iat.Add(expiration)
errs := make([]error, 0)
tok := jwt.New()
errs = append(errs, tok.Set(jwt.IssuerKey, cfg.ClientID))
errs = append(errs, tok.Set(jwt.SubjectKey, cfg.ClientID))
errs = append(errs, tok.Set(jwt.AudienceKey, cfg.WellKnown.Issuer))
errs = append(errs, tok.Set("scope", token.ScopeOpenID))
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))
for _, err := range errs {
if err != nil {
return "", fmt.Errorf("setting claim for client assertion: %w", err)
}
}
encoded, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key)
if err != nil {
return "", fmt.Errorf("signing client assertion: %w", err)
}
return string(encoded), nil
}

View File

@@ -0,0 +1,23 @@
package openid
import (
"github.com/lestrrat-go/jwx/jwk"
"github.com/nais/wonderwall/pkg/scopes"
)
type ClientConfiguration interface {
GetClientID() string
GetClientJWK() jwk.Key
GetPostLogoutRedirectURI() string
GetRedirectURI() string
GetScopes() scopes.Scopes
GetACRValues() OptionalConfiguration
GetUILocales() OptionalConfiguration
GetWellKnownURL() string
}
type OptionalConfiguration struct {
Enabled bool `json:"enabled"`
Value string `json:"value"`
}

View File

@@ -1,11 +1,15 @@
package config
package openid
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/lestrrat-go/jwx/jwk"
)
type IDPortenWellKnown struct {
type Configuration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
PushedAuthorizationRequestEndpoint string `json:"pushed_authorization_request_endpoint"`
@@ -42,15 +46,28 @@ func (in Supported) Contains(value string) bool {
return false
}
func (c *Config) FetchWellKnownConfig() error {
response, err := http.Get(c.IDPorten.WellKnownURL)
func FetchWellKnownConfig(wellKnownURI string) (*Configuration, error) {
response, err := http.Get(wellKnownURI)
if err != nil {
return err
return nil, fmt.Errorf("fetching well known configuration: %w", err)
}
// can this play with viper in any way?
if err := json.NewDecoder(response.Body).Decode(&c.IDPorten.WellKnown); err != nil {
return err
var cfg Configuration
if err := json.NewDecoder(response.Body).Decode(&cfg); err != nil {
return nil, fmt.Errorf("decoding well known configuration: %w", err)
}
return nil
return &cfg, nil
}
func (c *Configuration) FetchJwkSet(ctx context.Context) (*jwk.Set, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
jwkSet, err := jwk.Fetch(ctx, c.JwksURI)
if err != nil {
return nil, fmt.Errorf("fetching jwks: %w", err)
}
return &jwkSet, nil
}

View File

@@ -1,4 +1,4 @@
package auth
package openid
import (
"crypto/rand"

View File

@@ -9,6 +9,7 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
)
var (
@@ -40,7 +41,7 @@ func CanonicalRedirectURL(r *http.Request) string {
// LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found.
// The value must exist in the supplied list of supported values.
func LoginURLParameter(r *http.Request, parameter, fallback string, supported config.Supported) (string, error) {
func LoginURLParameter(r *http.Request, parameter, fallback string, supported openid.Supported) (string, error) {
value := r.URL.Query().Get(parameter)
if len(value) == 0 {

View File

@@ -1,13 +1,15 @@
package request_test
import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/request"
"github.com/stretchr/testify/assert"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/request"
)
func TestCanonicalRedirectURL(t *testing.T) {
@@ -33,7 +35,7 @@ func TestLoginURLParameter(t *testing.T) {
name string
parameter string
fallback string
supported config.Supported
supported openid.Supported
url string
expectErr error
expected string
@@ -67,7 +69,7 @@ func TestLoginURLParameter(t *testing.T) {
{
name: "no supported values should return error",
url: "http://localhost:8080/oauth2/login",
supported: config.Supported{""},
supported: openid.Supported{""},
expectErr: request.InvalidLoginParameterError,
},
} {
@@ -78,7 +80,7 @@ func TestLoginURLParameter(t *testing.T) {
// default test values
parameter := "param"
fallback := "valid"
supported := config.Supported{"valid", "valid2"}
supported := openid.Supported{"valid", "valid2"}
if len(test.parameter) > 0 {
parameter = test.parameter

View File

@@ -3,15 +3,15 @@ package router
import (
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"github.com/nais/wonderwall/pkg/auth"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/request"
)
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
params, err := auth.GenerateLoginParameters()
params, err := openid.GenerateLoginParameters()
if err != nil {
h.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err))
return

View File

@@ -3,12 +3,11 @@ package router
import (
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"net/url"
"github.com/nais/wonderwall/pkg/auth"
"github.com/nais/wonderwall/pkg/token"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/request"
)
var (
@@ -16,8 +15,8 @@ var (
InvalidLocaleError = errors.New("InvalidLocale")
)
func (h *Handler) LoginURL(r *http.Request, params *auth.Parameters) (string, error) {
u, err := url.Parse(h.Config.IDPorten.WellKnown.AuthorizationEndpoint)
func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string, error) {
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().AuthorizationEndpoint)
if err != nil {
return "", err
}

View File

@@ -2,11 +2,14 @@ package router_test
import (
"errors"
"github.com/nais/wonderwall/pkg/auth"
"github.com/nais/wonderwall/pkg/router"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/router"
)
func TestLoginURL(t *testing.T) {
@@ -48,7 +51,7 @@ func TestLoginURL(t *testing.T) {
req, err := http.NewRequest("GET", test.url, nil)
assert.NoError(t, err)
params, err := auth.GenerateLoginParameters()
params, err := openid.GenerateLoginParameters()
assert.NoError(t, err)
handler := handler(cfg)

View File

@@ -45,12 +45,12 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return nil, fmt.Errorf("session not found in store: %w", err)
}
log.Warnf("get session: store is unavailable; using cookie fallback: %+v", err)
fallbackSessionData, err := h.GetSessionFallback(r)
if err != nil {
return nil, fmt.Errorf("fallback session not found: %w", err)
}
log.Warnf("get session: store is unavailable: %+v; using cookie fallback", err)
return fallbackSessionData, nil
}
@@ -97,11 +97,12 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
return nil
}
log.Warnf("create session: store is unavailable; using cookie fallback: %+v", err)
err = h.SetSessionFallback(w, sessionData, sessionLifetime)
if err != nil {
return fmt.Errorf("writing session to fallback store: %w", err)
}
log.Warnf("create session: store is unavailable: %+v; using cookie fallback", err)
return nil
}