diff --git a/docs/configuration.md b/docs/configuration.md index f70be93..62a99a4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -36,6 +36,7 @@ The following flags are available: | `openid.client-jwk` | string | | JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over `openid.client-secret`. | | `openid.client-secret` | string | | Client secret for the OpenID client. Overridden by `openid.client-jwk`, if configured. | | `openid.id-token-signing-alg` | string | `RS256` | Expected JWA value (as defined in RFC 7518) of public keys for validating id_token signatures. This only applies where the key's `alg` header is not set. | +| `openid.new-client-auth-jwt-type` | bool | `false` | When enabled, sets the value of the \"typ\" header of the JWT used for client authentication equal to "client-authentication+jwt" in accordance with RFC7523bis. If not enabled, the value is set to "JWT". | | `openid.post-logout-redirect-uri` | string | | URI for redirecting the user after successful logout at the Identity Provider. | | `openid.provider` | string | `openid` | Provider configuration to load and use, either `openid`, `azure`, `idporten`. | | `openid.resource-indicator` | string | | OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens. | diff --git a/pkg/config/openid.go b/pkg/config/openid.go index 07bfcde..5b240e6 100644 --- a/pkg/config/openid.go +++ b/pkg/config/openid.go @@ -25,6 +25,7 @@ type OpenID struct { ClientJWK string `json:"client-jwk"` ClientSecret string `json:"client-secret"` IDTokenSigningAlg string `json:"id-token-signing-alg"` + NewClientAuthJWTType bool `json:"new-client-auth-jwt-type"` PostLogoutRedirectURI string `json:"post-logout-redirect-uri"` Provider Provider `json:"provider"` ResourceIndicator string `json:"resource-indicator"` @@ -58,6 +59,7 @@ const ( OpenIDClientID = "openid.client-id" OpenIDClientJWK = "openid.client-jwk" OpenIDClientSecret = "openid.client-secret" + OpenIDNewClientAuthJWTType = "openid.new-client-auth-jwt-type" OpenIDIDTokenSigningAlg = "openid.id-token-signing-alg" OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri" OpenIDProvider = "openid.provider" @@ -74,6 +76,7 @@ func openidFlags() { flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'.") flag.String(OpenIDClientSecret, "", "Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured.") flag.String(OpenIDIDTokenSigningAlg, jwa.RS256().String(), "Expected JWA value (as defined in RFC 7518) of public keys for validating id_token signatures. This only applies where the key's 'alg' header is not set.") + flag.Bool(OpenIDNewClientAuthJWTType, false, "When enabled, sets the value of the \"typ\" header of the JWT used for client authentication equal to \"client-authentication+jwt\" in accordance with RFC7523bis. If not enabled, the value is set to \"JWT\".") flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.") flag.String(OpenIDProvider, string(ProviderOpenID), "Provider configuration to load and use, either 'openid', 'azure', 'idporten'.") flag.String(OpenIDResourceIndicator, "", "OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens.") diff --git a/pkg/mock/client.go b/pkg/mock/client.go index 98443f2..5a5f821 100644 --- a/pkg/mock/client.go +++ b/pkg/mock/client.go @@ -14,6 +14,8 @@ type TestClientConfiguration struct { trustedAudiences map[string]bool } +var _ openidconfig.Client = (*TestClientConfiguration)(nil) + func (c *TestClientConfiguration) ACRValues() string { return c.Config.OpenID.ACRValues } @@ -38,6 +40,10 @@ func (c *TestClientConfiguration) ClientSecret() string { return c.Config.OpenID.ClientSecret } +func (c *TestClientConfiguration) NewClientAuthJWTType() bool { + return c.Config.OpenID.NewClientAuthJWTType +} + func (c *TestClientConfiguration) SetPostLogoutRedirectURI(uri string) { c.Config.OpenID.PostLogoutRedirectURI = uri } diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index f05f320..db9b36f 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -144,7 +144,7 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken, previousIDToken func (c *Client) ClientAuthenticationParams() (openid.RequestParams, error) { switch c.cfg.Client().AuthMethod() { case openidconfig.AuthMethodPrivateKeyJWT: - assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime) + assertion, err := c.ClientAuthenticationAssertion(DefaultClientAssertionLifetime) if err != nil { return nil, fmt.Errorf("creating client assertion: %w", err) } @@ -158,12 +158,12 @@ func (c *Client) ClientAuthenticationParams() (openid.RequestParams, error) { return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod()) } -func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { +func (c *Client) ClientAuthenticationAssertion(expiration time.Duration) (string, error) { clientCfg := c.cfg.Client() providerCfg := c.cfg.Provider() key := clientCfg.ClientJWK() - iat := time.Now().Add(-5 * time.Second).Truncate(time.Second) + iat := time.Now() exp := iat.Add(expiration) tok, err := jwt.NewBuilder(). @@ -183,7 +183,16 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { return "", fmt.Errorf("missing algorithm on client key") } - encoded, err := jwt.Sign(tok, jwt.WithKey(alg, key)) + opts := make([]jwt.Option, 0) + if c.cfg.Client().NewClientAuthJWTType() { + hdrs := jws.NewHeaders() + if err := hdrs.Set(jws.TypeKey, "client-authentication+jwt"); err != nil { + return "", fmt.Errorf("setting type header on client assertion: %w", err) + } + opts = append(opts, jws.WithProtectedHeaders(hdrs)) + } + + encoded, err := jwt.Sign(tok, jwt.WithKey(alg, key, opts...)) if err != nil { return "", fmt.Errorf("signing client assertion: %w", err) } diff --git a/pkg/openid/client/client_test.go b/pkg/openid/client/client_test.go index 50edbe9..7ff576f 100644 --- a/pkg/openid/client/client_test.go +++ b/pkg/openid/client/client_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jws" "github.com/lestrrat-go/jwx/v3/jwt" "github.com/stretchr/testify/assert" @@ -14,7 +16,7 @@ import ( "github.com/nais/wonderwall/pkg/openid/client" ) -func TestMakeAssertion(t *testing.T) { +func TestClientAuthenticationAssertion(t *testing.T) { cfg := mock.Config() cfg.OpenID.ClientID = "some-client-id" @@ -23,7 +25,7 @@ func TestMakeAssertion(t *testing.T) { c := newTestClientWithConfig(openidConfig) expiry := 30 * time.Second - jwtAssertion, err := c.MakeAssertion(expiry) + jwtAssertion, err := c.ClientAuthenticationAssertion(expiry) assert.NoError(t, err) assertFlattenedAudience(t, jwtAssertion) @@ -64,6 +66,48 @@ func TestMakeAssertion(t *testing.T) { assert.True(t, ok) assert.True(t, exp.After(time.Now())) assert.True(t, exp.Before(time.Now().Add(expiry))) + + msg, err := jws.ParseString(jwtAssertion) + assert.NoError(t, err) + assert.Len(t, msg.Signatures(), 1) + headers := msg.Signatures()[0].ProtectedHeaders() + + typ, ok := headers.Type() + assert.True(t, ok) + assert.Equal(t, "JWT", typ) + + alg, ok = headers.Algorithm() + assert.True(t, ok) + assert.Equal(t, jwa.RS256(), alg) + + expectedKid, ok := key.KeyID() + assert.True(t, ok) + kid, ok := headers.KeyID() + assert.True(t, ok) + assert.Equal(t, expectedKid, kid) +} + +func TestClientAuthenticationAssertionHeader(t *testing.T) { + cfg := mock.Config() + cfg.OpenID.ClientID = "some-client-id" + cfg.OpenID.NewClientAuthJWTType = true + + openidConfig := mock.NewTestConfiguration(cfg) + openidConfig.TestProvider.SetIssuer("some-issuer") + c := newTestClientWithConfig(openidConfig) + + expiry := 30 * time.Second + jwtAssertion, err := c.ClientAuthenticationAssertion(expiry) + assert.NoError(t, err) + + msg, err := jws.ParseString(jwtAssertion) + assert.NoError(t, err) + assert.Len(t, msg.Signatures(), 1) + headers := msg.Signatures()[0].ProtectedHeaders() + + typ, ok := headers.Type() + assert.True(t, ok) + assert.Equal(t, "client-authentication+jwt", typ) } // assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value. diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index 6d485cd..1b5731e 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -24,6 +24,7 @@ type Client interface { ClientID() string ClientJWK() jwk.Key ClientSecret() string + NewClientAuthJWTType() bool PostLogoutRedirectURI() string ResourceIndicator() string Scopes() scopes.Scopes @@ -38,6 +39,8 @@ type client struct { trustedAudiences map[string]bool } +var _ Client = (*client)(nil) + func (in *client) ACRValues() string { return in.OpenID.ACRValues } @@ -62,6 +65,10 @@ func (in *client) ClientSecret() string { return in.OpenID.ClientSecret } +func (in *client) NewClientAuthJWTType() bool { + return in.OpenID.NewClientAuthJWTType +} + func (in *client) PostLogoutRedirectURI() string { return in.OpenID.PostLogoutRedirectURI }