Refactor PKCE implementation (#1239)

This commit is contained in:
Hidetake Iwata
2025-01-12 21:41:20 +09:00
committed by GitHub
parent 606f1cd0b6
commit 898e8a12de
15 changed files with 145 additions and 157 deletions

View File

@@ -30,7 +30,7 @@ func (o *getTokenOptions) addFlags(f *pflag.FlagSet) {
f.StringVar(&o.ClientID, "oidc-client-id", "", "Client ID of the provider (mandatory)")
f.StringVar(&o.ClientSecret, "oidc-client-secret", "", "Client secret of the provider")
f.StringSliceVar(&o.ExtraScopes, "oidc-extra-scope", nil, "Scopes to request to the provider")
f.BoolVar(&o.UsePKCE, "oidc-use-pkce", false, "Force PKCE usage")
f.BoolVar(&o.UsePKCE, "oidc-use-pkce", false, "Force PKCE even if the provider does not support it")
f.BoolVar(&o.UseAccessToken, "oidc-use-access-token", false, "Instead of using the id_token, use the access_token to authenticate to Kubernetes")
f.BoolVar(&o.ForceRefresh, "force-refresh", false, "If set, refresh the ID token regardless of its expiration time")
o.tokenCacheOptions.addFlags(f)
@@ -81,7 +81,7 @@ func (cmd *GetToken) New() *cobra.Command {
IssuerURL: o.IssuerURL,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
UsePKCE: o.UsePKCE,
ForcePKCE: o.UsePKCE,
UseAccessToken: o.UseAccessToken,
ExtraScopes: o.ExtraScopes,
},

View File

@@ -20,11 +20,11 @@ type Interface interface {
GetAuthCodeURL(in AuthCodeURLInput) string
ExchangeAuthCode(ctx context.Context, in ExchangeAuthCodeInput) (*oidc.TokenSet, error)
GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeInput, localServerReadyChan chan<- string) (*oidc.TokenSet, error)
NegotiatedPKCEMethod() pkce.Method
GetTokenByROPC(ctx context.Context, username, password string) (*oidc.TokenSet, error)
GetDeviceAuthorization(ctx context.Context) (*oauth2dev.AuthorizationResponse, error)
ExchangeDeviceCode(ctx context.Context, authResponse *oauth2dev.AuthorizationResponse) (*oidc.TokenSet, error)
Refresh(ctx context.Context, refreshToken string) (*oidc.TokenSet, error)
SupportedPKCEMethods() []string
}
type AuthCodeURLInput struct {
@@ -60,7 +60,7 @@ type client struct {
oauth2Config oauth2.Config
clock clock.Interface
logger logger.Interface
supportedPKCEMethods []string
negotiatedPKCEMethod pkce.Method
deviceAuthorizationEndpoint string
useAccessToken bool
}
@@ -116,34 +116,33 @@ func (c *client) ExchangeAuthCode(ctx context.Context, in ExchangeAuthCodeInput)
return c.verifyToken(ctx, token, in.Nonce)
}
func authorizationRequestOptions(n string, p pkce.Params, e map[string]string) []oauth2.AuthCodeOption {
o := []oauth2.AuthCodeOption{
func authorizationRequestOptions(nonce string, pkceParams pkce.Params, extraParams map[string]string) []oauth2.AuthCodeOption {
opts := []oauth2.AuthCodeOption{
oauth2.AccessTypeOffline,
gooidc.Nonce(n),
gooidc.Nonce(nonce),
}
if !p.IsZero() {
o = append(o,
oauth2.SetAuthURLParam("code_challenge", p.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", p.CodeChallengeMethod),
)
if pkceParams.CodeChallenge != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_challenge", pkceParams.CodeChallenge))
}
for key, value := range e {
o = append(o, oauth2.SetAuthURLParam(key, value))
if pkceParams.CodeChallengeMethod != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", pkceParams.CodeChallengeMethod))
}
return o
for key, value := range extraParams {
opts = append(opts, oauth2.SetAuthURLParam(key, value))
}
return opts
}
func tokenRequestOptions(p pkce.Params) (o []oauth2.AuthCodeOption) {
if !p.IsZero() {
o = append(o, oauth2.SetAuthURLParam("code_verifier", p.CodeVerifier))
func tokenRequestOptions(pkceParams pkce.Params) []oauth2.AuthCodeOption {
var opts []oauth2.AuthCodeOption
if pkceParams.CodeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", pkceParams.CodeVerifier))
}
return
return opts
}
// SupportedPKCEMethods returns the PKCE methods supported by the provider.
// This may return nil if PKCE is not supported.
func (c *client) SupportedPKCEMethods() []string {
return c.supportedPKCEMethods
func (c *client) NegotiatedPKCEMethod() pkce.Method {
return c.negotiatedPKCEMethod
}
// GetTokenByROPC performs the resource owner password credentials flow.

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net/http"
"slices"
gooidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/wire"
@@ -24,7 +25,7 @@ var Set = wire.NewSet(
)
type FactoryInterface interface {
New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error)
New(ctx context.Context, prov oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error)
}
type Factory struct {
@@ -34,7 +35,7 @@ type Factory struct {
}
// New returns an instance of infrastructure.Interface with the given configuration.
func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error) {
func (f *Factory) New(ctx context.Context, prov oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error) {
rawTLSClientConfig, err := f.Loader.Load(tlsClientConfig)
if err != nil {
return nil, fmt.Errorf("could not load the TLS client config: %w", err)
@@ -52,7 +53,7 @@ func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsc
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
provider, err := gooidc.NewProvider(ctx, p.IssuerURL)
provider, err := gooidc.NewProvider(ctx, prov.IssuerURL)
if err != nil {
return nil, fmt.Errorf("oidc discovery error: %w", err)
}
@@ -60,9 +61,6 @@ func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsc
if err != nil {
return nil, fmt.Errorf("could not determine supported PKCE methods: %w", err)
}
if len(supportedPKCEMethods) == 0 && p.UsePKCE {
supportedPKCEMethods = []string{pkce.MethodS256}
}
deviceAuthorizationEndpoint, err := extractDeviceAuthorizationEndpoint(provider)
if err != nil {
return nil, fmt.Errorf("could not determine device authorization endpoint: %w", err)
@@ -72,34 +70,44 @@ func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsc
provider: provider,
oauth2Config: oauth2.Config{
Endpoint: provider.Endpoint(),
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Scopes: append(p.ExtraScopes, gooidc.ScopeOpenID),
ClientID: prov.ClientID,
ClientSecret: prov.ClientSecret,
Scopes: append(prov.ExtraScopes, gooidc.ScopeOpenID),
},
clock: f.Clock,
logger: f.Logger,
supportedPKCEMethods: supportedPKCEMethods,
negotiatedPKCEMethod: determinePKCEMethod(supportedPKCEMethods, prov.ForcePKCE),
deviceAuthorizationEndpoint: deviceAuthorizationEndpoint,
useAccessToken: p.UseAccessToken,
useAccessToken: prov.UseAccessToken,
}, nil
}
func determinePKCEMethod(supportedPKCEMethods []string, forcePKCE bool) pkce.Method {
if forcePKCE {
return pkce.MethodS256
}
if slices.Contains(supportedPKCEMethods, "S256") {
return pkce.MethodS256
}
return pkce.NoMethod
}
func extractSupportedPKCEMethods(provider *gooidc.Provider) ([]string, error) {
var d struct {
var claims struct {
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
}
if err := provider.Claims(&d); err != nil {
if err := provider.Claims(&claims); err != nil {
return nil, fmt.Errorf("invalid discovery document: %w", err)
}
return d.CodeChallengeMethodsSupported, nil
return claims.CodeChallengeMethodsSupported, nil
}
func extractDeviceAuthorizationEndpoint(provider *gooidc.Provider) (string, error) {
var d struct {
var claims struct {
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"`
}
if err := provider.Claims(&d); err != nil {
if err := provider.Claims(&claims); err != nil {
return "", fmt.Errorf("invalid discovery document: %w", err)
}
return d.DeviceAuthorizationEndpoint, nil
return claims.DeviceAuthorizationEndpoint, nil
}

View File

@@ -15,7 +15,7 @@ type Provider struct {
ClientID string
ClientSecret string // optional
ExtraScopes []string // optional
UsePKCE bool // optional
ForcePKCE bool // optional
UseAccessToken bool // optional
}

View File

@@ -10,11 +10,13 @@ import (
"fmt"
)
var Plain Params
type Method string
const (
// code challenge methods defined as https://tools.ietf.org/html/rfc7636#section-4.3
MethodS256 = "S256"
NoMethod Method = ""
// Code challenge methods defined as https://tools.ietf.org/html/rfc7636#section-4.3
MethodS256 Method = "S256"
)
// Params represents a set of the PKCE parameters.
@@ -24,27 +26,21 @@ type Params struct {
CodeVerifier string
}
func (p Params) IsZero() bool {
return p == Params{}
}
// New returns a parameters supported by the provider.
// You need to pass the code challenge methods defined in RFC7636.
// It returns Plain if no method is available.
func New(methods []string) (Params, error) {
for _, method := range methods {
if method == MethodS256 {
return NewS256()
}
// It returns a zero value if no method is available.
func New(method Method) (Params, error) {
if method == MethodS256 {
return NewS256()
}
return Plain, nil
return Params{}, nil
}
// NewS256 generates a parameters for S256.
func NewS256() (Params, error) {
b, err := random32()
if err != nil {
return Plain, fmt.Errorf("could not generate a random: %w", err)
return Params{}, fmt.Errorf("could not generate a random: %w", err)
}
return computeS256(b), nil
}
@@ -63,7 +59,7 @@ func computeS256(b []byte) Params {
_, _ = s.Write([]byte(v))
return Params{
CodeChallenge: base64URLEncode(s.Sum(nil)),
CodeChallengeMethod: MethodS256,
CodeChallengeMethod: string(MethodS256),
CodeVerifier: v,
}
}

View File

@@ -2,40 +2,33 @@ package pkce
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestNew(t *testing.T) {
t.Run("S256", func(t *testing.T) {
p, err := New([]string{"plain", "S256"})
params, err := New(MethodS256)
if err != nil {
t.Fatalf("New error: %s", err)
}
if p.CodeChallengeMethod != "S256" {
t.Errorf("CodeChallengeMethod wants S256 but was %s", p.CodeChallengeMethod)
if params.CodeChallengeMethod != "S256" {
t.Errorf("CodeChallengeMethod wants S256 but was %s", params.CodeChallengeMethod)
}
if p.CodeChallenge == "" {
if params.CodeChallenge == "" {
t.Errorf("CodeChallenge wants non-empty but was empty")
}
if p.CodeVerifier == "" {
if params.CodeVerifier == "" {
t.Errorf("CodeVerifier wants non-empty but was empty")
}
})
t.Run("plain", func(t *testing.T) {
p, err := New([]string{"plain"})
t.Run("NoMethod", func(t *testing.T) {
params, err := New(NoMethod)
if err != nil {
t.Fatalf("New error: %s", err)
}
if !p.IsZero() {
t.Errorf("IsZero wants true but was false")
}
})
t.Run("nil", func(t *testing.T) {
p, err := New(nil)
if err != nil {
t.Fatalf("New error: %s", err)
}
if !p.IsZero() {
t.Errorf("IsZero wants true but was false")
if diff := cmp.Diff(Params{}, params); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
}

View File

@@ -41,9 +41,9 @@ func (u *Browser) Do(ctx context.Context, o *BrowserOption, oidcClient client.In
if err != nil {
return nil, fmt.Errorf("could not generate a nonce: %w", err)
}
p, err := pkce.New(oidcClient.SupportedPKCEMethods())
pkceParams, err := pkce.New(oidcClient.NegotiatedPKCEMethod())
if err != nil {
return nil, fmt.Errorf("could not generate PKCE parameters: %w", err)
return nil, fmt.Errorf("could not generate the PKCE parameters: %w", err)
}
successHTML := BrowserSuccessHTML
if o.OpenURLAfterAuthentication != "" {
@@ -53,7 +53,7 @@ func (u *Browser) Do(ctx context.Context, o *BrowserOption, oidcClient client.In
BindAddress: o.BindAddress,
State: state,
Nonce: nonce,
PKCEParams: p,
PKCEParams: pkceParams,
RedirectURLHostname: o.RedirectURLHostname,
AuthRequestExtraParams: o.AuthRequestExtraParams,
LocalServerSuccessHTML: successHTML,

View File

@@ -10,6 +10,7 @@ import (
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/oidc/client_mock"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/oidc/client"
"github.com/int128/kubelogin/pkg/pkce"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/stretchr/testify/mock"
)
@@ -31,9 +32,7 @@ func TestBrowser_Do(t *testing.T) {
AuthRequestExtraParams: map[string]string{"ttl": "86400", "reauth": "true"},
}
mockClient := client_mock.NewMockInterface(t)
mockClient.EXPECT().
SupportedPKCEMethods().
Return(nil)
mockClient.EXPECT().NegotiatedPKCEMethod().Return(pkce.NoMethod)
mockClient.EXPECT().
GetTokenByAuthCode(mock.Anything, mock.Anything, mock.Anything).
Run(func(_ context.Context, in client.GetTokenByAuthCodeInput, readyChan chan<- string) {
@@ -85,9 +84,7 @@ func TestBrowser_Do(t *testing.T) {
AuthenticationTimeout: 10 * time.Second,
}
mockClient := client_mock.NewMockInterface(t)
mockClient.EXPECT().
SupportedPKCEMethods().
Return(nil)
mockClient.EXPECT().NegotiatedPKCEMethod().Return(pkce.NoMethod)
mockClient.EXPECT().
GetTokenByAuthCode(mock.Anything, mock.Anything, mock.Anything).
Run(func(_ context.Context, _ client.GetTokenByAuthCodeInput, readyChan chan<- string) {
@@ -127,9 +124,7 @@ func TestBrowser_Do(t *testing.T) {
AuthenticationTimeout: 10 * time.Second,
}
mockClient := client_mock.NewMockInterface(t)
mockClient.EXPECT().
SupportedPKCEMethods().
Return(nil)
mockClient.EXPECT().NegotiatedPKCEMethod().Return(pkce.NoMethod)
mockClient.EXPECT().
GetTokenByAuthCode(mock.Anything, mock.Anything, mock.Anything).
Run(func(_ context.Context, _ client.GetTokenByAuthCodeInput, readyChan chan<- string) {

View File

@@ -34,15 +34,14 @@ func (u *Keyboard) Do(ctx context.Context, o *KeyboardOption, oidcClient client.
if err != nil {
return nil, fmt.Errorf("could not generate a nonce: %w", err)
}
p, err := pkce.New(oidcClient.SupportedPKCEMethods())
pkceParams, err := pkce.New(oidcClient.NegotiatedPKCEMethod())
if err != nil {
return nil, fmt.Errorf("could not generate PKCE parameters: %w", err)
return nil, fmt.Errorf("could not generate the PKCE parameters: %w", err)
}
authCodeURL := oidcClient.GetAuthCodeURL(client.AuthCodeURLInput{
State: state,
Nonce: nonce,
PKCEParams: p,
PKCEParams: pkceParams,
RedirectURI: o.RedirectURL,
AuthRequestExtraParams: o.AuthRequestExtraParams,
})
@@ -55,7 +54,7 @@ func (u *Keyboard) Do(ctx context.Context, o *KeyboardOption, oidcClient client.
u.Logger.V(1).Infof("exchanging the code and token")
tokenSet, err := oidcClient.ExchangeAuthCode(ctx, client.ExchangeAuthCodeInput{
Code: code,
PKCEParams: p,
PKCEParams: pkceParams,
Nonce: nonce,
RedirectURI: o.RedirectURL,
})

View File

@@ -10,6 +10,7 @@ import (
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/oidc/client_mock"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/oidc/client"
"github.com/int128/kubelogin/pkg/pkce"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/stretchr/testify/mock"
)
@@ -24,9 +25,7 @@ func TestKeyboard_Do(t *testing.T) {
AuthRequestExtraParams: map[string]string{"ttl": "86400", "reauth": "true"},
}
mockClient := client_mock.NewMockInterface(t)
mockClient.EXPECT().
SupportedPKCEMethods().
Return(nil)
mockClient.EXPECT().NegotiatedPKCEMethod().Return(pkce.NoMethod)
mockClient.EXPECT().
GetAuthCodeURL(mock.Anything).
Run(func(in client.AuthCodeURLInput) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/oidc/client_mock"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/oidc/client"
"github.com/int128/kubelogin/pkg/pkce"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
testingLogger "github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
@@ -96,9 +97,7 @@ func TestAuthentication_Do(t *testing.T) {
},
}
mockClient := client_mock.NewMockInterface(t)
mockClient.EXPECT().
SupportedPKCEMethods().
Return(nil)
mockClient.EXPECT().NegotiatedPKCEMethod().Return(pkce.NoMethod)
mockClient.EXPECT().
Refresh(ctx, "EXPIRED_REFRESH_TOKEN").
Return(nil, errors.New("token has expired"))

View File

@@ -88,7 +88,7 @@ func (u *Setup) DoStage2(ctx context.Context, in Stage2Input) error {
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
UsePKCE: in.UsePKCE,
ForcePKCE: in.UsePKCE,
UseAccessToken: in.UseAccessToken,
},
GrantOptionSet: in.GrantOptionSet,