mirror of
https://github.com/nais/wonderwall.git
synced 2026-02-14 17:49:54 +00:00
chore(deps): bump github.com/lestrrat-go/jwx from v2 to v3
This commit is contained in:
5
go.mod
5
go.mod
@@ -16,7 +16,8 @@ require (
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1
|
||||
github.com/nais/liberator v0.0.0-20250408101050-2ffa1b42f7f2
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/redis/go-redis/extra/redisotel/v9 v9.8.0
|
||||
@@ -58,8 +59,6 @@ require (
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httprc v1.0.6 // indirect
|
||||
github.com/lestrrat-go/iter v1.0.2 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -71,12 +71,10 @@ github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtn
|
||||
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
|
||||
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
|
||||
@@ -3,9 +3,10 @@ package crypto
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
)
|
||||
|
||||
type JwkSet struct {
|
||||
@@ -16,17 +17,28 @@ type JwkSet struct {
|
||||
func NewJwk() (jwk.Key, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, fmt.Errorf("generating key: %w", err)
|
||||
}
|
||||
|
||||
key, err := jwk.FromRaw(privateKey)
|
||||
key, err := jwk.Import(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("importing key: %w", err)
|
||||
}
|
||||
|
||||
key.Set(jwk.AlgorithmKey, jwa.RS256)
|
||||
key.Set(jwk.KeyTypeKey, jwa.RSA)
|
||||
jwk.AssignKeyID(key)
|
||||
err = key.Set(jwk.AlgorithmKey, jwa.RS256().String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting algorithm: %w", err)
|
||||
}
|
||||
|
||||
err = key.Set(jwk.KeyTypeKey, jwa.RSA().String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting key type: %w", err)
|
||||
}
|
||||
|
||||
err = jwk.AssignKeyID(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assigning key id: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
@@ -34,15 +46,18 @@ func NewJwk() (jwk.Key, error) {
|
||||
func NewJwkSet() (*JwkSet, error) {
|
||||
key, err := NewJwk()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("creating jwk: %w", err)
|
||||
}
|
||||
|
||||
privateKeys := jwk.NewSet()
|
||||
privateKeys.AddKey(key)
|
||||
err = privateKeys.AddKey(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding key to set: %w", err)
|
||||
}
|
||||
|
||||
publicKeys, err := jwk.PublicSetOf(privateKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("creating public set: %w", err)
|
||||
}
|
||||
|
||||
return &JwkSet{
|
||||
|
||||
@@ -2,9 +2,8 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
flag "github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
@@ -45,9 +44,9 @@ func (in OpenID) TrustedAudiences() map[string]bool {
|
||||
}
|
||||
|
||||
func (in OpenID) Validate() error {
|
||||
valid := jwa.SignatureAlgorithms()
|
||||
if !slices.Contains(valid, jwa.SignatureAlgorithm(in.IDTokenSigningAlg)) {
|
||||
return fmt.Errorf("invalid id_token signing algorithm: %q, must be one of %s", in.IDTokenSigningAlg, valid)
|
||||
_, ok := jwa.LookupSignatureAlgorithm(in.IDTokenSigningAlg)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid id_token signing algorithm: %q, must be one of %s", in.IDTokenSigningAlg, jwa.SignatureAlgorithms())
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -74,7 +73,7 @@ func openidFlags() {
|
||||
flag.String(OpenIDClientID, "", "Client ID for the OpenID client.")
|
||||
flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'.")
|
||||
flag.String(OpenIDClientSecret, "", "Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured.")
|
||||
flag.String(OpenIDIDTokenSigningAlg, string(jwa.RS256), "Expected JWA value (as defined in RFC 7518) of public keys for validating id_token signatures. This only applies where the key's 'alg' header is not set.")
|
||||
flag.String(OpenIDIDTokenSigningAlg, jwa.RS256().String(), "Expected JWA value (as defined in RFC 7518) of public keys for validating id_token signatures. This only applies where the key's 'alg' header is not set.")
|
||||
flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.")
|
||||
flag.String(OpenIDProvider, string(ProviderOpenID), "Provider configuration to load and use, either 'openid', 'azure', 'idporten'.")
|
||||
flag.String(OpenIDResourceIndicator, "", "OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens.")
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/nais/wonderwall/internal/crypto"
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/internal/crypto"
|
||||
@@ -188,7 +188,7 @@ func (ip *IdentityProviderHandler) signToken(token jwt.Token) (string, error) {
|
||||
return "", fmt.Errorf("could not get signer")
|
||||
}
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, signer))
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), signer))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -567,7 +567,12 @@ func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *h
|
||||
|
||||
iat := time.Now().Truncate(time.Second)
|
||||
exp := iat.Add(ip.TokenDuration)
|
||||
sub := data.OriginalIDToken.Subject()
|
||||
sub, ok := data.OriginalIDToken.Subject()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
oauthError(w, fmt.Errorf("could not get subject from original id token"))
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := jwt.New()
|
||||
accessToken.Set("sub", sub)
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/internal/crypto"
|
||||
@@ -59,7 +59,7 @@ func (t *TestProviderConfiguration) EndSessionEndpointURL() url.URL {
|
||||
}
|
||||
|
||||
func (t *TestProviderConfiguration) IDTokenSigningAlg() jwa.KeyAlgorithm {
|
||||
return jwa.RS256
|
||||
return jwa.RS256()
|
||||
}
|
||||
|
||||
func (t *TestProviderConfiguration) Issuer() string {
|
||||
|
||||
@@ -10,15 +10,14 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
httpinternal "github.com/nais/wonderwall/internal/http"
|
||||
"github.com/nais/wonderwall/internal/o11y/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
httpinternal "github.com/nais/wonderwall/internal/http"
|
||||
"github.com/nais/wonderwall/internal/o11y/otel"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
@@ -157,7 +156,12 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
|
||||
return "", fmt.Errorf("building client assertion: %w", err)
|
||||
}
|
||||
|
||||
encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key))
|
||||
alg, ok := key.Algorithm()
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing algorithm on client key")
|
||||
}
|
||||
|
||||
encoded, err := jwt.Sign(tok, jwt.WithKey(alg, key))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("signing client assertion: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
@@ -31,8 +31,12 @@ func TestMakeAssertion(t *testing.T) {
|
||||
key := openidConfig.Client().ClientJWK()
|
||||
publicKey, err := key.PublicKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
alg, ok := publicKey.Algorithm()
|
||||
assert.True(t, ok)
|
||||
|
||||
opts := []jwt.ParseOption{
|
||||
jwt.WithKey(publicKey.Algorithm(), publicKey),
|
||||
jwt.WithKey(alg, publicKey),
|
||||
jwt.WithRequiredClaim(jwt.IssuedAtKey),
|
||||
jwt.WithRequiredClaim(jwt.ExpirationKey),
|
||||
jwt.WithRequiredClaim(jwt.JwtIDKey),
|
||||
@@ -40,13 +44,26 @@ func TestMakeAssertion(t *testing.T) {
|
||||
assertion, err := jwt.ParseString(jwtAssertion, opts...)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.ElementsMatch(t, []string{"some-issuer"}, assertion.Audience())
|
||||
assert.Equal(t, "some-client-id", assertion.Issuer())
|
||||
assert.Equal(t, "some-client-id", assertion.Subject())
|
||||
aud, ok := assertion.Audience()
|
||||
assert.True(t, ok)
|
||||
assert.ElementsMatch(t, []string{"some-issuer"}, aud)
|
||||
|
||||
assert.True(t, assertion.IssuedAt().Before(time.Now()))
|
||||
assert.True(t, assertion.Expiration().After(time.Now()))
|
||||
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
|
||||
iss, ok := assertion.Issuer()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "some-client-id", iss)
|
||||
|
||||
sub, ok := assertion.Subject()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "some-client-id", sub)
|
||||
|
||||
iat, ok := assertion.IssuedAt()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, iat.Before(time.Now()))
|
||||
|
||||
exp, ok := assertion.Expiration()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, exp.After(time.Now()))
|
||||
assert.True(t, exp.Before(time.Now().Add(expiry)))
|
||||
}
|
||||
|
||||
// assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value.
|
||||
|
||||
@@ -3,7 +3,7 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
@@ -106,9 +106,14 @@ func NewProviderConfig(cfg *config.Config) (Provider, error) {
|
||||
|
||||
providerCfg.Print()
|
||||
|
||||
signingAlg, ok := jwa.LookupSignatureAlgorithm(cfg.OpenID.IDTokenSigningAlg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid id_token signing algorithm: %q, must be one of %s", cfg.OpenID.IDTokenSigningAlg, jwa.SignatureAlgorithms())
|
||||
}
|
||||
|
||||
return &provider{
|
||||
endSessionEndpointURL: endSessionEndpointURL,
|
||||
idTokenSigningAlg: jwa.SignatureAlgorithm(cfg.OpenID.IDTokenSigningAlg),
|
||||
idTokenSigningAlg: signingAlg,
|
||||
metadata: providerCfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/httprc/v3"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/nais/wonderwall/internal/o11y/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,11 +31,16 @@ type jwksLock struct {
|
||||
|
||||
func (p *JwksProvider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
|
||||
url := p.config.JwksURI()
|
||||
set, err := p.jwksCache.Get(ctx, url)
|
||||
set, err := p.jwksCache.Lookup(ctx, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider: fetching jwks: %w", err)
|
||||
}
|
||||
|
||||
set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider: mutating jwks: %w", err)
|
||||
}
|
||||
|
||||
return &set, nil
|
||||
}
|
||||
|
||||
@@ -59,6 +64,12 @@ func (p *JwksProvider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider: refreshing jwks: %w", err)
|
||||
}
|
||||
|
||||
set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("provider: mutating jwks: %w", err)
|
||||
}
|
||||
|
||||
span.SetAttributes(attribute.Bool("jwks.refreshed", true))
|
||||
return &set, nil
|
||||
}
|
||||
@@ -67,17 +78,13 @@ func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksP
|
||||
providerCfg := openidCfg.Provider()
|
||||
|
||||
uri := providerCfg.JwksURI()
|
||||
cache := jwk.NewCache(ctx)
|
||||
|
||||
err := cache.Register(uri, jwk.WithPostFetcher(keySetMutator(providerCfg)))
|
||||
cache, err := jwk.NewCache(ctx, httprc.NewClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering jwks provider uri to cache: %w", err)
|
||||
return nil, fmt.Errorf("creating jwks cache: %w", err)
|
||||
}
|
||||
|
||||
// trigger initial fetch and cache of jwk set
|
||||
_, err = cache.Refresh(ctx, uri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initial fetch of jwks from provider: %w", err)
|
||||
if err := cache.Register(ctx, uri); err != nil {
|
||||
return nil, fmt.Errorf("registering jwks provider uri to cache: %w", err)
|
||||
}
|
||||
|
||||
return &JwksProvider{
|
||||
@@ -87,21 +94,31 @@ func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksP
|
||||
}, nil
|
||||
}
|
||||
|
||||
func keySetMutator(cfg openidconfig.Provider) jwk.PostFetcher {
|
||||
return jwk.PostFetchFunc(func(uri string, set jwk.Set) (jwk.Set, error) {
|
||||
for i := 0; i < set.Len(); i++ {
|
||||
key, ok := set.Key(i)
|
||||
if !ok || key.Algorithm().String() != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// if no "alg" is set on the key, set it to the expected algorithm
|
||||
err := key.Set(jwk.AlgorithmKey, cfg.IDTokenSigningAlg())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting key algorithm: %w", err)
|
||||
}
|
||||
func ensureJwkSetWithAlg(set jwk.Set, expectedAlg jwa.KeyAlgorithm) (jwk.Set, error) {
|
||||
for i := 0; i < set.Len(); i++ {
|
||||
key, ok := set.Key(i)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
return set, nil
|
||||
})
|
||||
alg, ok := key.Algorithm()
|
||||
if ok {
|
||||
// drop keys with "alg=none"
|
||||
if alg == jwa.NoSignature() {
|
||||
if err := set.RemoveKey(key); err != nil {
|
||||
return nil, fmt.Errorf("removing key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// don't mutate keys with a valid algorithm
|
||||
continue
|
||||
}
|
||||
|
||||
// set "alg" to expected algorithm for keys that don't have one set
|
||||
if err := key.Set(jwk.AlgorithmKey, expectedAlg); err != nil {
|
||||
return nil, fmt.Errorf("setting key algorithm: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid/acr"
|
||||
@@ -141,7 +141,10 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie, jwks *
|
||||
// OpenID Connect Core 3.1.3.7, step 3.
|
||||
// The `aud` (audience) Claim MAY contain an array with more than one element.
|
||||
// The ID Token MUST be rejected if the ID Token [...] contains additional audiences not trusted by the Client.
|
||||
audiences := in.Audience()
|
||||
audiences, ok := in.Audience()
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required 'aud' claim in id_token")
|
||||
}
|
||||
if len(audiences) > 1 {
|
||||
trusted := clientConfig.Audiences()
|
||||
untrusted := make([]string, 0)
|
||||
@@ -198,9 +201,9 @@ func (in *IDToken) Claim(claim string) (any, error) {
|
||||
return nil, fmt.Errorf("token is nil")
|
||||
}
|
||||
|
||||
gotClaim, ok := in.Token.Get(claim)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing required '%s' claim in id_token", claim)
|
||||
var gotClaim any
|
||||
if err := in.Token.Get(claim, &gotClaim); err != nil {
|
||||
return nil, fmt.Errorf("missing required '%s' claim in id_token: %w", claim, err)
|
||||
}
|
||||
|
||||
return gotClaim, nil
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -45,13 +45,31 @@ func TestParseIDToken(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, sub, parsed.Subject())
|
||||
assert.Equal(t, "some-issuer", parsed.Issuer())
|
||||
assert.Equal(t, []string{"some-client-id"}, parsed.Audience())
|
||||
actualSub, ok := parsed.Subject()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, sub, actualSub)
|
||||
|
||||
actualIss, ok := parsed.Issuer()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "some-issuer", actualIss)
|
||||
|
||||
actualAud, ok := parsed.Audience()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"some-client-id"}, actualAud)
|
||||
|
||||
assert.Equal(t, "some-nonce", parsed.StringClaimOrEmpty("nonce"))
|
||||
assert.Equal(t, iat, parsed.IssuedAt())
|
||||
assert.Equal(t, exp, parsed.Expiration())
|
||||
assert.NotEmpty(t, parsed.JwtID())
|
||||
|
||||
actualIat, ok := parsed.IssuedAt()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, iat, actualIat)
|
||||
|
||||
actualExp, ok := parsed.Expiration()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, exp, actualExp)
|
||||
|
||||
actualJti, ok := parsed.JwtID()
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, actualJti)
|
||||
}
|
||||
|
||||
func TestIDToken_GetAcrClaim(t *testing.T) {
|
||||
@@ -191,28 +209,28 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
claims: &claims{
|
||||
remove: []string{"sub"},
|
||||
},
|
||||
expectErr: `"sub" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "sub" is missing`,
|
||||
},
|
||||
{
|
||||
name: "missing exp",
|
||||
claims: &claims{
|
||||
remove: []string{"exp"},
|
||||
},
|
||||
expectErr: `"exp" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "exp" is missing`,
|
||||
},
|
||||
{
|
||||
name: "missing iat",
|
||||
claims: &claims{
|
||||
remove: []string{"iat"},
|
||||
},
|
||||
expectErr: `"iat" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "iat" is missing`,
|
||||
},
|
||||
{
|
||||
name: "missing iss",
|
||||
claims: &claims{
|
||||
remove: []string{"iss"},
|
||||
},
|
||||
expectErr: `"iss" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "iss" is missing`,
|
||||
},
|
||||
{
|
||||
name: "iat is in the future",
|
||||
@@ -248,14 +266,14 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
"iss": "https://some-other-issuer",
|
||||
},
|
||||
},
|
||||
expectErr: `"iss" not satisfied: values do not match`,
|
||||
expectErr: `claim "iss" does not have the expected value`,
|
||||
},
|
||||
{
|
||||
name: "missing aud",
|
||||
claims: &claims{
|
||||
remove: []string{"aud"},
|
||||
},
|
||||
expectErr: `"aud" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "aud" is missing`,
|
||||
},
|
||||
{
|
||||
name: "audience mismatch",
|
||||
@@ -297,7 +315,7 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
claims: &claims{
|
||||
remove: []string{"nonce"},
|
||||
},
|
||||
expectErr: `"nonce" not satisfied: claim "nonce" does not exist`,
|
||||
expectErr: `claim "nonce" does not exist`,
|
||||
},
|
||||
{
|
||||
name: "nonce mismatch",
|
||||
@@ -306,7 +324,7 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
"nonce": "invalid-nonce",
|
||||
},
|
||||
},
|
||||
expectErr: `"nonce" not satisfied: values do not match`,
|
||||
expectErr: `claim "nonce" does not have the expected value`,
|
||||
},
|
||||
{
|
||||
name: "sid required",
|
||||
@@ -318,7 +336,7 @@ func TestIDToken_Validate(t *testing.T) {
|
||||
remove: []string{"sid"},
|
||||
},
|
||||
requireSid: true,
|
||||
expectErr: `"sid" not satisfied: required claim not found`,
|
||||
expectErr: `required claim "sid" is missing`,
|
||||
},
|
||||
{
|
||||
name: "acr required",
|
||||
@@ -429,7 +447,7 @@ func makeIDToken(claims *claims) (*openid.IDToken, error) {
|
||||
return nil, fmt.Errorf("no private key found at index 0")
|
||||
}
|
||||
|
||||
jws, err := jwt.Sign(idToken, jwt.WithKey(jwa.RS256, key))
|
||||
jws, err := jwt.Sign(idToken, jwt.WithKey(jwa.RS256(), key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("signing token: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
||||
jwtlib "github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jwtlib "github.com/lestrrat-go/jwx/v2/jwt"
|
||||
jwtlib "github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user