Refactor: extract tlsclientconfig.Config (#409)

This commit is contained in:
Hidetake Iwata
2020-11-03 14:37:24 +09:00
committed by GitHub
parent 878847f937
commit 34762216c1
32 changed files with 372 additions and 492 deletions

View File

@@ -1,71 +0,0 @@
// Package certpool provides loading certificates from files or base64 encoded string.
package certpool
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"io/ioutil"
"github.com/google/wire"
"golang.org/x/xerrors"
)
//go:generate mockgen -destination mock_certpool/mock_certpool.go github.com/int128/kubelogin/pkg/adaptors/certpool Interface
// Set provides an implementation and interface.
var Set = wire.NewSet(
wire.Value(NewFunc(New)),
wire.Struct(new(CertPool), "*"),
wire.Bind(new(Interface), new(*CertPool)),
)
type NewFunc func() Interface
// New returns an instance which implements the Interface.
func New() Interface {
return &CertPool{pool: x509.NewCertPool()}
}
type Interface interface {
AddFile(filename string) error
AddBase64Encoded(s string) error
SetRootCAs(cfg *tls.Config)
}
// CertPool represents a pool of certificates.
type CertPool struct {
pool *x509.CertPool
}
// SetRootCAs sets cfg.RootCAs if it has any certificate.
// Otherwise do nothing.
func (p *CertPool) SetRootCAs(cfg *tls.Config) {
if len(p.pool.Subjects()) > 0 {
cfg.RootCAs = p.pool
}
}
// AddFile loads the certificate from the file.
func (p *CertPool) AddFile(filename string) error {
b, err := ioutil.ReadFile(filename)
if err != nil {
return xerrors.Errorf("could not read %s: %w", filename, err)
}
if !p.pool.AppendCertsFromPEM(b) {
return xerrors.Errorf("could not append certificate from %s", filename)
}
return nil
}
// AddBase64Encoded loads the certificate from the base64 encoded string.
func (p *CertPool) AddBase64Encoded(s string) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return xerrors.Errorf("could not decode base64: %w", err)
}
if !p.pool.AppendCertsFromPEM(b) {
return xerrors.Errorf("could not append certificate")
}
return nil
}

View File

@@ -1,58 +0,0 @@
package certpool
import (
"crypto/tls"
"io/ioutil"
"testing"
)
func TestCertPool_AddFile(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
p := New()
if err := p.AddFile("testdata/ca1.crt"); err != nil {
t.Errorf("AddFile error: %s", err)
}
var cfg tls.Config
p.SetRootCAs(&cfg)
if n := len(cfg.RootCAs.Subjects()); n != 1 {
t.Errorf("n wants 1 but was %d", n)
}
})
t.Run("Invalid", func(t *testing.T) {
p := New()
err := p.AddFile("testdata/Makefile")
if err == nil {
t.Errorf("AddFile wants an error but was nil")
}
})
}
func TestCertPool_AddBase64Encoded(t *testing.T) {
p := New()
if err := p.AddBase64Encoded(readFile(t, "testdata/ca2.crt.base64")); err != nil {
t.Errorf("AddBase64Encoded error: %s", err)
}
var cfg tls.Config
p.SetRootCAs(&cfg)
if n := len(cfg.RootCAs.Subjects()); n != 1 {
t.Errorf("n wants 1 but was %d", n)
}
}
func TestCertPool_SetRootCAs(t *testing.T) {
p := New()
var cfg tls.Config
p.SetRootCAs(&cfg)
if cfg.RootCAs != nil {
t.Errorf("cfg.RootCAs wants nil but was %+v", cfg.RootCAs)
}
}
func readFile(t *testing.T, filename string) string {
t.Helper()
b, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("ReadFile error: %s", err)
}
return string(b)
}

View File

@@ -1,74 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/int128/kubelogin/pkg/adaptors/certpool (interfaces: Interface)
// Package mock_certpool is a generated GoMock package.
package mock_certpool
import (
tls "crypto/tls"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockInterface is a mock of Interface interface.
type MockInterface struct {
ctrl *gomock.Controller
recorder *MockInterfaceMockRecorder
}
// MockInterfaceMockRecorder is the mock recorder for MockInterface.
type MockInterfaceMockRecorder struct {
mock *MockInterface
}
// NewMockInterface creates a new mock instance.
func NewMockInterface(ctrl *gomock.Controller) *MockInterface {
mock := &MockInterface{ctrl: ctrl}
mock.recorder = &MockInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder {
return m.recorder
}
// AddBase64Encoded mocks base method.
func (m *MockInterface) AddBase64Encoded(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddBase64Encoded", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddBase64Encoded indicates an expected call of AddBase64Encoded.
func (mr *MockInterfaceMockRecorder) AddBase64Encoded(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBase64Encoded", reflect.TypeOf((*MockInterface)(nil).AddBase64Encoded), arg0)
}
// AddFile mocks base method.
func (m *MockInterface) AddFile(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddFile", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddFile indicates an expected call of AddFile.
func (mr *MockInterfaceMockRecorder) AddFile(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFile", reflect.TypeOf((*MockInterface)(nil).AddFile), arg0)
}
// SetRootCAs mocks base method.
func (m *MockInterface) SetRootCAs(arg0 *tls.Config) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetRootCAs", arg0)
}
// SetRootCAs indicates an expected call of SetRootCAs.
func (mr *MockInterfaceMockRecorder) SetRootCAs(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRootCAs", reflect.TypeOf((*MockInterface)(nil).SetRootCAs), arg0)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
"github.com/int128/kubelogin/pkg/usecases/authentication/ropc"
@@ -95,9 +96,6 @@ func TestCmd_Run(t *testing.T) {
KubeconfigFilename: "/path/to/kubeconfig",
KubeconfigContext: "hello.k8s.local",
KubeconfigUser: "google",
CACertFilename: "/path/to/cacert",
CACertData: "BASE64ENCODED",
SkipTLSVerify: true,
GrantOptionSet: authentication.GrantOptionSet{
AuthCodeBrowserOption: &authcode.BrowserOption{
BindAddress: []string{"127.0.0.1:10080", "127.0.0.1:20080"},
@@ -109,6 +107,11 @@ func TestCmd_Run(t *testing.T) {
RedirectURLHostname: "localhost",
},
},
TLSClientConfig: tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cacert"},
CACertData: []string{"BASE64ENCODED"},
SkipTLSVerify: true,
},
},
},
"GrantType=authcode-keyboard": {
@@ -244,14 +247,11 @@ func TestCmd_Run(t *testing.T) {
"--password", "PASS",
},
in: credentialplugin.Input{
TokenCacheDir: defaultTokenCacheDir,
IssuerURL: "https://issuer.example.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email", "profile"},
CACertFilename: "/path/to/cacert",
CACertData: "BASE64ENCODED",
SkipTLSVerify: true,
TokenCacheDir: defaultTokenCacheDir,
IssuerURL: "https://issuer.example.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email", "profile"},
GrantOptionSet: authentication.GrantOptionSet{
AuthCodeBrowserOption: &authcode.BrowserOption{
BindAddress: []string{"127.0.0.1:10080", "127.0.0.1:20080"},
@@ -264,6 +264,11 @@ func TestCmd_Run(t *testing.T) {
AuthRequestExtraParams: map[string]string{"ttl": "86400", "reauth": "true"},
},
},
TLSClientConfig: tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cacert"},
CACertData: []string{"BASE64ENCODED"},
SkipTLSVerify: true,
},
},
},
"GrantType=authcode-keyboard": {

View File

@@ -57,15 +57,13 @@ func (cmd *GetToken) New() *cobra.Command {
return xerrors.Errorf("get-token: %w", err)
}
in := credentialplugin.Input{
IssuerURL: o.IssuerURL,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
ExtraScopes: o.ExtraScopes,
CACertFilename: o.tlsOptions.CACertFilename,
CACertData: o.tlsOptions.CACertData,
SkipTLSVerify: o.tlsOptions.SkipTLSVerify,
TokenCacheDir: o.TokenCacheDir,
GrantOptionSet: grantOptionSet,
IssuerURL: o.IssuerURL,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
ExtraScopes: o.ExtraScopes,
TokenCacheDir: o.TokenCacheDir,
GrantOptionSet: grantOptionSet,
TLSClientConfig: o.tlsOptions.tlsClientConfig(),
}
if err := cmd.GetToken.Do(c.Context(), in); err != nil {
return xerrors.Errorf("get-token: %w", err)

View File

@@ -57,10 +57,8 @@ func (cmd *Root) New() *cobra.Command {
KubeconfigFilename: o.Kubeconfig,
KubeconfigContext: kubeconfig.ContextName(o.Context),
KubeconfigUser: kubeconfig.UserName(o.User),
CACertFilename: o.tlsOptions.CACertFilename,
CACertData: o.tlsOptions.CACertData,
SkipTLSVerify: o.tlsOptions.SkipTLSVerify,
GrantOptionSet: grantOptionSet,
TLSClientConfig: o.tlsOptions.tlsClientConfig(),
}
if err := cmd.Standalone.Do(c.Context(), in); err != nil {
return xerrors.Errorf("login: %w", err)

View File

@@ -42,14 +42,12 @@ func (cmd *Setup) New() *cobra.Command {
return xerrors.Errorf("setup: %w", err)
}
in := setup.Stage2Input{
IssuerURL: o.IssuerURL,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
ExtraScopes: o.ExtraScopes,
CACertFilename: o.tlsOptions.CACertFilename,
CACertData: o.tlsOptions.CACertData,
SkipTLSVerify: o.tlsOptions.SkipTLSVerify,
GrantOptionSet: grantOptionSet,
IssuerURL: o.IssuerURL,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
ExtraScopes: o.ExtraScopes,
GrantOptionSet: grantOptionSet,
TLSClientConfig: o.tlsOptions.tlsClientConfig(),
}
if c.Flags().Lookup("listen-address").Changed {
in.ListenAddressArgs = o.authenticationOptions.ListenAddress

View File

@@ -1,15 +1,26 @@
package cmd
import "github.com/spf13/pflag"
import (
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/spf13/pflag"
)
type tlsOptions struct {
CACertFilename string
CACertData string
CACertFilename []string
CACertData []string
SkipTLSVerify bool
}
func (o *tlsOptions) addFlags(f *pflag.FlagSet) {
f.StringVar(&o.CACertFilename, "certificate-authority", "", "Path to a cert file for the certificate authority")
f.StringVar(&o.CACertData, "certificate-authority-data", "", "Base64 encoded cert for the certificate authority")
f.StringArrayVar(&o.CACertFilename, "certificate-authority", nil, "Path to a cert file for the certificate authority")
f.StringArrayVar(&o.CACertData, "certificate-authority-data", nil, "Base64 encoded cert for the certificate authority")
f.BoolVar(&o.SkipTLSVerify, "insecure-skip-tls-verify", false, "If set, the server's certificate will not be checked for validity. This will make your HTTPS connections insecure")
}
func (o *tlsOptions) tlsClientConfig() tlsclientconfig.Config {
return tlsclientconfig.Config{
CACertFilename: o.CACertFilename,
CACertData: o.CACertData,
SkipTLSVerify: o.SkipTLSVerify,
}
}

View File

@@ -3,7 +3,6 @@ package oidcclient
import (
"context"
"crypto/tls"
"fmt"
"net/http"
@@ -13,33 +12,37 @@ import (
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/adaptors/oidcclient/logging"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/tlsclientconfig/loader"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
)
//go:generate mockgen -destination mock_oidcclient/mock_factory.go github.com/int128/kubelogin/pkg/adaptors/oidcclient FactoryInterface
var Set = wire.NewSet(
wire.Struct(new(Factory), "*"),
wire.Bind(new(FactoryInterface), new(*Factory)),
)
type FactoryInterface interface {
New(ctx context.Context, p oidc.Provider) (Interface, error)
New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error)
}
type Factory struct {
Loader loader.Loader
Clock clock.Interface
Logger logger.Interface
}
// New returns an instance of adaptors.Interface with the given configuration.
func (f *Factory) New(ctx context.Context, p oidc.Provider) (Interface, error) {
var tlsConfig tls.Config
tlsConfig.InsecureSkipVerify = p.SkipTLSVerify
if p.CertPool != nil {
p.CertPool.SetRootCAs(&tlsConfig)
func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsclientconfig.Config) (Interface, error) {
rawTLSClientConfig, err := f.Loader.Load(tlsClientConfig)
if err != nil {
return nil, xerrors.Errorf("could not load the TLS client config: %w", err)
}
baseTransport := &http.Transport{
TLSClientConfig: &tlsConfig,
TLSClientConfig: rawTLSClientConfig,
Proxy: http.ProxyFromEnvironment,
}
loggingTransport := &logging.Transport{

View File

@@ -0,0 +1,52 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/int128/kubelogin/pkg/adaptors/oidcclient (interfaces: FactoryInterface)
// Package mock_oidcclient is a generated GoMock package.
package mock_oidcclient
import (
context "context"
gomock "github.com/golang/mock/gomock"
oidcclient "github.com/int128/kubelogin/pkg/adaptors/oidcclient"
oidc "github.com/int128/kubelogin/pkg/oidc"
tlsclientconfig "github.com/int128/kubelogin/pkg/tlsclientconfig"
reflect "reflect"
)
// MockFactoryInterface is a mock of FactoryInterface interface.
type MockFactoryInterface struct {
ctrl *gomock.Controller
recorder *MockFactoryInterfaceMockRecorder
}
// MockFactoryInterfaceMockRecorder is the mock recorder for MockFactoryInterface.
type MockFactoryInterfaceMockRecorder struct {
mock *MockFactoryInterface
}
// NewMockFactoryInterface creates a new mock instance.
func NewMockFactoryInterface(ctrl *gomock.Controller) *MockFactoryInterface {
mock := &MockFactoryInterface{ctrl: ctrl}
mock.recorder = &MockFactoryInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFactoryInterface) EXPECT() *MockFactoryInterfaceMockRecorder {
return m.recorder
}
// New mocks base method.
func (m *MockFactoryInterface) New(arg0 context.Context, arg1 oidc.Provider, arg2 tlsclientconfig.Config) (oidcclient.Interface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "New", arg0, arg1, arg2)
ret0, _ := ret[0].(oidcclient.Interface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// New indicates an expected call of New.
func (mr *MockFactoryInterfaceMockRecorder) New(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockFactoryInterface)(nil).New), arg0, arg1, arg2)
}

View File

@@ -6,7 +6,6 @@ package di
import (
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/adaptors/browser"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/clock"
"github.com/int128/kubelogin/pkg/adaptors/cmd"
"github.com/int128/kubelogin/pkg/adaptors/credentialpluginwriter"
@@ -17,6 +16,7 @@ import (
"github.com/int128/kubelogin/pkg/adaptors/reader"
"github.com/int128/kubelogin/pkg/adaptors/stdio"
"github.com/int128/kubelogin/pkg/adaptors/tokencache"
"github.com/int128/kubelogin/pkg/tlsclientconfig/loader"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/credentialplugin"
"github.com/int128/kubelogin/pkg/usecases/setup"
@@ -52,7 +52,7 @@ func NewCmdForHeadless(clock.Interface, stdio.Stdin, stdio.Stdout, logger.Interf
kubeconfig.Set,
tokencache.Set,
oidcclient.Set,
certpool.Set,
loader.Set,
credentialpluginwriter.Set,
mutex.Set,
)

View File

@@ -7,7 +7,6 @@ package di
import (
"github.com/int128/kubelogin/pkg/adaptors/browser"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/clock"
"github.com/int128/kubelogin/pkg/adaptors/cmd"
"github.com/int128/kubelogin/pkg/adaptors/credentialpluginwriter"
@@ -18,6 +17,7 @@ import (
"github.com/int128/kubelogin/pkg/adaptors/reader"
"github.com/int128/kubelogin/pkg/adaptors/stdio"
"github.com/int128/kubelogin/pkg/adaptors/tokencache"
"github.com/int128/kubelogin/pkg/tlsclientconfig/loader"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
"github.com/int128/kubelogin/pkg/usecases/authentication/ropc"
@@ -29,6 +29,7 @@ import (
// Injectors from di.go:
// NewCmd returns an instance of adaptors.Cmd.
func NewCmd() cmd.Interface {
clockReal := &clock.Real{}
stdin := _wireFileValue
@@ -44,8 +45,11 @@ var (
_wireOsFileValue = os.Stdout
)
// NewCmdForHeadless returns an instance of adaptors.Cmd for headless testing.
func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout stdio.Stdout, loggerInterface logger.Interface, browserInterface browser.Interface) cmd.Interface {
loaderLoader := loader.Loader{}
factory := &oidcclient.Factory{
Loader: loaderLoader,
Clock: clockInterface,
Logger: loggerInterface,
}
@@ -75,11 +79,9 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
kubeconfigKubeconfig := &kubeconfig.Kubeconfig{
Logger: loggerInterface,
}
newFunc := _wireNewFuncValue
standaloneStandalone := &standalone.Standalone{
Authentication: authenticationAuthentication,
Kubeconfig: kubeconfigKubeconfig,
NewCertPool: newFunc,
Logger: loggerInterface,
}
root := &cmd.Root{
@@ -96,7 +98,6 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
getToken := &credentialplugin.GetToken{
Authentication: authenticationAuthentication,
TokenCacheRepository: repository,
NewCertPool: newFunc,
Writer: writer,
Mutex: mutexMutex,
Logger: loggerInterface,
@@ -107,7 +108,6 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
}
setupSetup := &setup.Setup{
Authentication: authenticationAuthentication,
NewCertPool: newFunc,
Logger: loggerInterface,
}
cmdSetup := &cmd.Setup{
@@ -121,7 +121,3 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
}
return cmdCmd
}
var (
_wireNewFuncValue = certpool.NewFunc(certpool.New)
)

View File

@@ -5,19 +5,16 @@ import (
"encoding/base64"
"encoding/binary"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/jwt"
"golang.org/x/xerrors"
)
// Provider represents an OIDC provider.
type Provider struct {
IssuerURL string
ClientID string
ClientSecret string // optional
ExtraScopes []string // optional
CertPool certpool.Interface // optional
SkipTLSVerify bool // optional
IssuerURL string
ClientID string
ClientSecret string // optional
ExtraScopes []string // optional
}
// TokenSet represents a set of ID token and refresh token.

View File

@@ -0,0 +1,8 @@
package tlsclientconfig
// Config represents a config for TLS client.
type Config struct {
CACertFilename []string
CACertData []string
SkipTLSVerify bool
}

View File

@@ -0,0 +1,66 @@
// Package loader provides loading certificates from files or base64 encoded string.
package loader
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"io/ioutil"
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"golang.org/x/xerrors"
)
// Set provides an implementation and interface.
var Set = wire.NewSet(
wire.Struct(new(Loader), "*"),
wire.Bind(new(Interface), new(*Loader)),
)
type Interface interface {
Load(config tlsclientconfig.Config) (*tls.Config, error)
}
// Loader represents a pool of certificates.
type Loader struct{}
func (l *Loader) Load(config tlsclientconfig.Config) (*tls.Config, error) {
rootCAs := x509.NewCertPool()
for _, f := range config.CACertFilename {
if err := addFile(rootCAs, f); err != nil {
return nil, xerrors.Errorf("could not load the certificate from %s: %w", f, err)
}
}
for _, d := range config.CACertData {
if err := addBase64Encoded(rootCAs, d); err != nil {
return nil, xerrors.Errorf("could not load the certificate: %w", err)
}
}
return &tls.Config{
RootCAs: rootCAs,
InsecureSkipVerify: config.SkipTLSVerify,
}, nil
}
func addFile(p *x509.CertPool, filename string) error {
b, err := ioutil.ReadFile(filename)
if err != nil {
return xerrors.Errorf("could not read: %w", err)
}
if !p.AppendCertsFromPEM(b) {
return xerrors.New("invalid certificate")
}
return nil
}
func addBase64Encoded(p *x509.CertPool, s string) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return xerrors.Errorf("could not decode base64: %w", err)
}
if !p.AppendCertsFromPEM(b) {
return xerrors.New("invalid certificate")
}
return nil
}

View File

@@ -0,0 +1,51 @@
package loader
import (
"io/ioutil"
"testing"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
)
func TestLoader_Load(t *testing.T) {
var loader Loader
t.Run("ValidFile", func(t *testing.T) {
cfg, err := loader.Load(tlsclientconfig.Config{
CACertFilename: []string{"testdata/ca1.crt"},
})
if err != nil {
t.Errorf("Load error: %s", err)
}
if n := len(cfg.RootCAs.Subjects()); n != 1 {
t.Errorf("n wants 1 but was %d", n)
}
})
t.Run("InvalidFile", func(t *testing.T) {
_, err := loader.Load(tlsclientconfig.Config{
CACertFilename: []string{"testdata/Makefile"},
})
if err == nil {
t.Errorf("AddFile wants an error but was nil")
}
})
t.Run("ValidBase64", func(t *testing.T) {
cfg, err := loader.Load(tlsclientconfig.Config{
CACertData: []string{readFile(t, "testdata/ca2.crt.base64")},
})
if err != nil {
t.Errorf("Load error: %s", err)
}
if n := len(cfg.RootCAs.Subjects()); n != 1 {
t.Errorf("n wants 1 but was %d", n)
}
})
}
func readFile(t *testing.T, filename string) string {
t.Helper()
b, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("ReadFile error: %s", err)
}
return string(b)
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/adaptors/oidcclient"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
"github.com/int128/kubelogin/pkg/usecases/authentication/ropc"
"golang.org/x/xerrors"
@@ -30,9 +31,10 @@ type Interface interface {
// Input represents an input DTO of the Authentication use-case.
type Input struct {
Provider oidc.Provider
GrantOptionSet GrantOptionSet
CachedTokenSet *oidc.TokenSet // optional
Provider oidc.Provider
GrantOptionSet GrantOptionSet
CachedTokenSet *oidc.TokenSet // optional
TLSClientConfig tlsclientconfig.Config
}
type GrantOptionSet struct {
@@ -90,7 +92,7 @@ func (u *Authentication) Do(ctx context.Context, in Input) (*Output, error) {
}
u.Logger.V(1).Infof("initializing an OpenID Connect client")
client, err := u.OIDCClient.New(ctx, in.Provider)
client, err := u.OIDCClient.New(ctx, in.Provider, in.TLSClientConfig)
if err != nil {
return nil, xerrors.Errorf("oidc error: %w", err)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/int128/kubelogin/pkg/testing/clock"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
testingLogger "github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
"github.com/int128/kubelogin/pkg/usecases/authentication/ropc"
"golang.org/x/xerrors"
@@ -26,6 +27,9 @@ func TestAuthentication_Do(t *testing.T) {
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
dummyTLSClientConfig := tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cert"},
}
issuedIDToken := testingJWT.EncodeF(t, func(claims *testingJWT.Claims) {
claims.Issuer = "https://accounts.google.com"
claims.Subject = "YOUR_SUBJECT"
@@ -38,7 +42,8 @@ func TestAuthentication_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
in := Input{
Provider: dummyProvider,
Provider: dummyProvider,
TLSClientConfig: dummyTLSClientConfig,
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
},
@@ -68,7 +73,8 @@ func TestAuthentication_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
in := Input{
Provider: dummyProvider,
Provider: dummyProvider,
TLSClientConfig: dummyTLSClientConfig,
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
RefreshToken: "VALID_REFRESH_TOKEN",
@@ -81,18 +87,14 @@ func TestAuthentication_Do(t *testing.T) {
IDToken: "NEW_ID_TOKEN",
RefreshToken: "NEW_REFRESH_TOKEN",
}, nil)
mockOIDCClientFactory := mock_oidcclient.NewMockFactoryInterface(ctrl)
mockOIDCClientFactory.EXPECT().
New(ctx, dummyProvider, dummyTLSClientConfig).
Return(mockOIDCClient, nil)
u := Authentication{
OIDCClient: &oidcclientFactory{
t: t,
client: mockOIDCClient,
want: oidc.Provider{
IssuerURL: "https://issuer.example.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
},
Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
OIDCClient: mockOIDCClientFactory,
Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
}
got, err := u.Do(ctx, in)
if err != nil {
@@ -115,7 +117,8 @@ func TestAuthentication_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
in := Input{
Provider: dummyProvider,
Provider: dummyProvider,
TLSClientConfig: dummyTLSClientConfig,
GrantOptionSet: GrantOptionSet{
AuthCodeBrowserOption: &authcode.BrowserOption{
BindAddress: []string{"127.0.0.1:8000"},
@@ -142,18 +145,14 @@ func TestAuthentication_Do(t *testing.T) {
IDToken: "NEW_ID_TOKEN",
RefreshToken: "NEW_REFRESH_TOKEN",
}, nil)
mockOIDCClientFactory := mock_oidcclient.NewMockFactoryInterface(ctrl)
mockOIDCClientFactory.EXPECT().
New(ctx, dummyProvider, dummyTLSClientConfig).
Return(mockOIDCClient, nil)
u := Authentication{
OIDCClient: &oidcclientFactory{
t: t,
client: mockOIDCClient,
want: oidc.Provider{
IssuerURL: "https://issuer.example.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
},
Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
OIDCClient: mockOIDCClientFactory,
Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
AuthCodeBrowser: &authcode.Browser{
Logger: testingLogger.New(t),
},
@@ -179,13 +178,14 @@ func TestAuthentication_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
in := Input{
Provider: dummyProvider,
TLSClientConfig: dummyTLSClientConfig,
GrantOptionSet: GrantOptionSet{
ROPCOption: &ropc.Option{
Username: "USER",
Password: "PASS",
},
},
Provider: dummyProvider,
}
mockOIDCClient := mock_oidcclient.NewMockInterface(ctrl)
mockOIDCClient.EXPECT().
@@ -194,17 +194,13 @@ func TestAuthentication_Do(t *testing.T) {
IDToken: "YOUR_ID_TOKEN",
RefreshToken: "YOUR_REFRESH_TOKEN",
}, nil)
mockOIDCClientFactory := mock_oidcclient.NewMockFactoryInterface(ctrl)
mockOIDCClientFactory.EXPECT().
New(ctx, dummyProvider, dummyTLSClientConfig).
Return(mockOIDCClient, nil)
u := Authentication{
OIDCClient: &oidcclientFactory{
t: t,
client: mockOIDCClient,
want: oidc.Provider{
IssuerURL: "https://issuer.example.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
},
Logger: testingLogger.New(t),
OIDCClient: mockOIDCClientFactory,
Logger: testingLogger.New(t),
ROPC: &ropc.ROPC{
Logger: testingLogger.New(t),
},
@@ -224,16 +220,3 @@ func TestAuthentication_Do(t *testing.T) {
}
})
}
type oidcclientFactory struct {
t *testing.T
client oidcclient.Interface
want oidc.Provider
}
func (f *oidcclientFactory) New(_ context.Context, got oidc.Provider) (oidcclient.Interface, error) {
if diff := cmp.Diff(f.want, got); diff != "" {
f.t.Errorf("mismatch (-want +got):\n%s", diff)
}
return f.client, nil
}

View File

@@ -5,14 +5,16 @@ package credentialplugin
import (
"context"
"strings"
"github.com/int128/kubelogin/pkg/adaptors/mutex"
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/credentialpluginwriter"
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/adaptors/tokencache"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"golang.org/x/xerrors"
)
@@ -30,21 +32,18 @@ type Interface interface {
// Input represents an input DTO of the GetToken use-case.
type Input struct {
IssuerURL string
ClientID string
ClientSecret string
ExtraScopes []string // optional
CACertFilename string // optional
CACertData string // optional
SkipTLSVerify bool
TokenCacheDir string
GrantOptionSet authentication.GrantOptionSet
IssuerURL string
ClientID string
ClientSecret string
ExtraScopes []string // optional
TokenCacheDir string
GrantOptionSet authentication.GrantOptionSet
TLSClientConfig tlsclientconfig.Config
}
type GetToken struct {
Authentication authentication.Interface
TokenCacheRepository tokencache.Interface
NewCertPool certpool.NewFunc
Writer credentialpluginwriter.Interface
Mutex mutex.Interface
Logger logger.Interface
@@ -68,9 +67,9 @@ func (u *GetToken) Do(ctx context.Context, in Input) error {
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
CACertFilename: in.CACertFilename,
CACertData: in.CACertData,
SkipTLSVerify: in.SkipTLSVerify,
CACertFilename: strings.Join(in.TLSClientConfig.CACertFilename, ","),
CACertData: strings.Join(in.TLSClientConfig.CACertData, ","),
SkipTLSVerify: in.TLSClientConfig.SkipTLSVerify,
}
if in.GrantOptionSet.ROPCOption != nil {
tokenCacheKey.Username = in.GrantOptionSet.ROPCOption.Username
@@ -80,28 +79,16 @@ func (u *GetToken) Do(ctx context.Context, in Input) error {
u.Logger.V(1).Infof("could not find a token cache: %s", err)
}
certPool := u.NewCertPool()
if in.CACertFilename != "" {
if err := certPool.AddFile(in.CACertFilename); err != nil {
return xerrors.Errorf("could not load the certificate file: %w", err)
}
}
if in.CACertData != "" {
if err := certPool.AddBase64Encoded(in.CACertData); err != nil {
return xerrors.Errorf("could not load the certificate data: %w", err)
}
}
authenticationInput := authentication.Input{
Provider: oidc.Provider{
IssuerURL: in.IssuerURL,
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
CertPool: certPool,
SkipTLSVerify: in.SkipTLSVerify,
IssuerURL: in.IssuerURL,
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
},
GrantOptionSet: in.GrantOptionSet,
CachedTokenSet: cachedTokenSet,
GrantOptionSet: in.GrantOptionSet,
CachedTokenSet: cachedTokenSet,
TLSClientConfig: in.TLSClientConfig,
}
authenticationOutput, err := u.Authentication.Do(ctx, authenticationInput)
if err != nil {

View File

@@ -9,8 +9,6 @@ import (
"github.com/int128/kubelogin/pkg/adaptors/mutex/mock_mutex"
"github.com/golang/mock/gomock"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/certpool/mock_certpool"
"github.com/int128/kubelogin/pkg/adaptors/credentialpluginwriter"
"github.com/int128/kubelogin/pkg/adaptors/credentialpluginwriter/mock_credentialpluginwriter"
"github.com/int128/kubelogin/pkg/adaptors/tokencache"
@@ -18,6 +16,7 @@ import (
"github.com/int128/kubelogin/pkg/oidc"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/mock_authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/ropc"
@@ -52,14 +51,12 @@ func TestGetToken_Do(t *testing.T) {
TokenCacheDir: "/path/to/token-cache",
GrantOptionSet: grantOptionSet,
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockAuthentication := mock_authentication.NewMockInterface(ctrl)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
CertPool: mockCertPool,
},
GrantOptionSet: grantOptionSet,
}).
@@ -79,7 +76,6 @@ func TestGetToken_Do(t *testing.T) {
u := GetToken{
Authentication: mockAuthentication,
TokenCacheRepository: tokenCacheRepository,
NewCertPool: func() certpool.Interface { return mockCertPool },
Writer: credentialPluginWriter,
Mutex: setupMutexMock(ctrl),
Logger: logger.New(t),
@@ -106,36 +102,33 @@ func TestGetToken_Do(t *testing.T) {
CACertData: "BASE64ENCODED",
SkipTLSVerify: true,
}
tlsClientConfig := tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cert"},
CACertData: []string{"BASE64ENCODED"},
SkipTLSVerify: true,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ctx := context.TODO()
in := Input{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
TokenCacheDir: "/path/to/token-cache",
CACertFilename: "/path/to/cert",
CACertData: "BASE64ENCODED",
SkipTLSVerify: true,
GrantOptionSet: grantOptionSet,
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
TokenCacheDir: "/path/to/token-cache",
GrantOptionSet: grantOptionSet,
TLSClientConfig: tlsClientConfig,
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockCertPool.EXPECT().
AddFile("/path/to/cert")
mockCertPool.EXPECT().
AddBase64Encoded("BASE64ENCODED")
mockAuthentication := mock_authentication.NewMockInterface(ctrl)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
SkipTLSVerify: true,
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
GrantOptionSet: grantOptionSet,
GrantOptionSet: grantOptionSet,
TLSClientConfig: tlsClientConfig,
}).
Return(&authentication.Output{TokenSet: tokenSet}, nil)
tokenCacheRepository := mock_tokencache.NewMockInterface(ctrl)
@@ -153,7 +146,6 @@ func TestGetToken_Do(t *testing.T) {
u := GetToken{
Authentication: mockAuthentication,
TokenCacheRepository: tokenCacheRepository,
NewCertPool: func() certpool.Interface { return mockCertPool },
Writer: credentialPluginWriter,
Mutex: setupMutexMock(ctrl),
Logger: logger.New(t),
@@ -173,7 +165,6 @@ func TestGetToken_Do(t *testing.T) {
ClientSecret: "YOUR_CLIENT_SECRET",
TokenCacheDir: "/path/to/token-cache",
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockAuthentication := mock_authentication.NewMockInterface(ctrl)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
@@ -181,7 +172,6 @@ func TestGetToken_Do(t *testing.T) {
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
},
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
@@ -212,7 +202,6 @@ func TestGetToken_Do(t *testing.T) {
u := GetToken{
Authentication: mockAuthentication,
TokenCacheRepository: tokenCacheRepository,
NewCertPool: func() certpool.Interface { return mockCertPool },
Writer: credentialPluginWriter,
Mutex: setupMutexMock(ctrl),
Logger: logger.New(t),
@@ -232,7 +221,6 @@ func TestGetToken_Do(t *testing.T) {
ClientSecret: "YOUR_CLIENT_SECRET",
TokenCacheDir: "/path/to/token-cache",
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockAuthentication := mock_authentication.NewMockInterface(ctrl)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
@@ -240,7 +228,6 @@ func TestGetToken_Do(t *testing.T) {
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
},
}).
Return(nil, xerrors.New("authentication error"))
@@ -255,7 +242,6 @@ func TestGetToken_Do(t *testing.T) {
u := GetToken{
Authentication: mockAuthentication,
TokenCacheRepository: tokenCacheRepository,
NewCertPool: func() certpool.Interface { return mockCertPool },
Writer: mock_credentialpluginwriter.NewMockInterface(ctrl),
Mutex: setupMutexMock(ctrl),
Logger: logger.New(t),

View File

@@ -5,7 +5,6 @@ import (
"context"
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/usecases/authentication"
)
@@ -22,6 +21,5 @@ type Interface interface {
type Setup struct {
Authentication authentication.Interface
NewCertPool certpool.NewFunc
Logger logger.Interface
}

View File

@@ -7,6 +7,7 @@ import (
"text/template"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"golang.org/x/xerrors"
)
@@ -72,36 +73,22 @@ type Stage2Input struct {
ClientID string
ClientSecret string
ExtraScopes []string // optional
CACertFilename string // optional
CACertData string // optional
SkipTLSVerify bool
ListenAddressArgs []string // non-nil if set by the command arg
GrantOptionSet authentication.GrantOptionSet
TLSClientConfig tlsclientconfig.Config
}
func (u *Setup) DoStage2(ctx context.Context, in Stage2Input) error {
u.Logger.Printf("authentication in progress...")
certPool := u.NewCertPool()
if in.CACertFilename != "" {
if err := certPool.AddFile(in.CACertFilename); err != nil {
return xerrors.Errorf("could not load the certificate file: %w", err)
}
}
if in.CACertData != "" {
if err := certPool.AddBase64Encoded(in.CACertData); err != nil {
return xerrors.Errorf("could not load the certificate data: %w", err)
}
}
out, err := u.Authentication.Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: in.IssuerURL,
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
CertPool: certPool,
SkipTLSVerify: in.SkipTLSVerify,
IssuerURL: in.IssuerURL,
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
ExtraScopes: in.ExtraScopes,
},
GrantOptionSet: in.GrantOptionSet,
GrantOptionSet: in.GrantOptionSet,
TLSClientConfig: in.TLSClientConfig,
})
if err != nil {
return xerrors.Errorf("authentication error: %w", err)
@@ -136,13 +123,13 @@ func makeCredentialPluginArgs(in Stage2Input) []string {
for _, extraScope := range in.ExtraScopes {
args = append(args, "--oidc-extra-scope="+extraScope)
}
if in.CACertFilename != "" {
args = append(args, "--certificate-authority="+in.CACertFilename)
for _, f := range in.TLSClientConfig.CACertFilename {
args = append(args, "--certificate-authority="+f)
}
if in.CACertData != "" {
args = append(args, "--certificate-authority-data="+in.CACertData)
for _, d := range in.TLSClientConfig.CACertData {
args = append(args, "--certificate-authority-data="+d)
}
if in.SkipTLSVerify {
if in.TLSClientConfig.SkipTLSVerify {
args = append(args, "--insecure-skip-tls-verify")
}

View File

@@ -6,11 +6,10 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/certpool/mock_certpool"
"github.com/int128/kubelogin/pkg/oidc"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/mock_authentication"
)
@@ -21,37 +20,33 @@ func TestSetup_DoStage2(t *testing.T) {
claims.Subject = "YOUR_SUBJECT"
claims.ExpiresAt = time.Now().Add(1 * time.Hour).Unix()
})
dummyTLSClientConfig := tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cert"},
}
var grantOptionSet authentication.GrantOptionSet
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ctx := context.Background()
in := Stage2Input{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email"},
CACertFilename: "/path/to/cert",
SkipTLSVerify: true,
GrantOptionSet: grantOptionSet,
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email"},
GrantOptionSet: grantOptionSet,
TLSClientConfig: dummyTLSClientConfig,
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockCertPool.EXPECT().
AddFile("/path/to/cert")
mockAuthentication := mock_authentication.NewMockInterface(ctrl)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email"},
CertPool: mockCertPool,
SkipTLSVerify: true,
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
ExtraScopes: []string{"email"},
},
GrantOptionSet: grantOptionSet,
GrantOptionSet: grantOptionSet,
TLSClientConfig: dummyTLSClientConfig,
}).
Return(&authentication.Output{
TokenSet: oidc.TokenSet{
@@ -61,7 +56,6 @@ func TestSetup_DoStage2(t *testing.T) {
}, nil)
u := Setup{
Authentication: mockAuthentication,
NewCertPool: func() certpool.Interface { return mockCertPool },
Logger: logger.New(t),
}
if err := u.DoStage2(ctx, in); err != nil {

View File

@@ -4,10 +4,10 @@ import (
"context"
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/kubeconfig"
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"golang.org/x/xerrors"
)
@@ -29,10 +29,8 @@ type Input struct {
KubeconfigFilename string // Default to the environment variable or global config as kubectl
KubeconfigContext kubeconfig.ContextName // Default to the current context but ignored if KubeconfigUser is set
KubeconfigUser kubeconfig.UserName // Default to the user of the context
CACertFilename string // optional
CACertData string // optional
SkipTLSVerify bool
GrantOptionSet authentication.GrantOptionSet
TLSClientConfig tlsclientconfig.Config
}
const oidcConfigErrorMessage = `No configuration found.
@@ -58,7 +56,6 @@ See https://github.com/int128/kubelogin for more.
type Standalone struct {
Authentication authentication.Interface
Kubeconfig kubeconfig.Interface
NewCertPool certpool.NewFunc
Logger logger.Interface
}
@@ -73,26 +70,13 @@ func (u *Standalone) Do(ctx context.Context, in Input) error {
u.Logger.Printf(deprecationMessage)
u.Logger.V(1).Infof("using the authentication provider of the user %s", authProvider.UserName)
u.Logger.V(1).Infof("a token will be written to %s", authProvider.LocationOfOrigin)
certPool := u.NewCertPool()
if authProvider.IDPCertificateAuthority != "" {
if err := certPool.AddFile(authProvider.IDPCertificateAuthority); err != nil {
return xerrors.Errorf("could not load the certificate of idp-certificate-authority: %w", err)
}
u.Logger.V(1).Infof("using the certificate %s", authProvider.IDPCertificateAuthority)
in.TLSClientConfig.CACertFilename = append(in.TLSClientConfig.CACertFilename, authProvider.IDPCertificateAuthority)
}
if authProvider.IDPCertificateAuthorityData != "" {
if err := certPool.AddBase64Encoded(authProvider.IDPCertificateAuthorityData); err != nil {
return xerrors.Errorf("could not load the certificate of idp-certificate-authority-data: %w", err)
}
}
if in.CACertFilename != "" {
if err := certPool.AddFile(in.CACertFilename); err != nil {
return xerrors.Errorf("could not load the certificate file: %w", err)
}
}
if in.CACertData != "" {
if err := certPool.AddBase64Encoded(in.CACertData); err != nil {
return xerrors.Errorf("could not load the certificate data: %w", err)
}
u.Logger.V(1).Infof("using the certificate in %s", authProvider.LocationOfOrigin)
in.TLSClientConfig.CACertData = append(in.TLSClientConfig.CACertData, authProvider.IDPCertificateAuthorityData)
}
var cachedTokenSet *oidc.TokenSet
if authProvider.IDToken != "" {
@@ -101,17 +85,17 @@ func (u *Standalone) Do(ctx context.Context, in Input) error {
RefreshToken: authProvider.RefreshToken,
}
}
authenticationInput := authentication.Input{
Provider: oidc.Provider{
IssuerURL: authProvider.IDPIssuerURL,
ClientID: authProvider.ClientID,
ClientSecret: authProvider.ClientSecret,
ExtraScopes: authProvider.ExtraScopes,
CertPool: certPool,
SkipTLSVerify: in.SkipTLSVerify,
IssuerURL: authProvider.IDPIssuerURL,
ClientID: authProvider.ClientID,
ClientSecret: authProvider.ClientSecret,
ExtraScopes: authProvider.ExtraScopes,
},
GrantOptionSet: in.GrantOptionSet,
CachedTokenSet: cachedTokenSet,
GrantOptionSet: in.GrantOptionSet,
CachedTokenSet: cachedTokenSet,
TLSClientConfig: in.TLSClientConfig,
}
authenticationOutput, err := u.Authentication.Do(ctx, authenticationInput)
if err != nil {

View File

@@ -6,13 +6,12 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/int128/kubelogin/pkg/adaptors/certpool"
"github.com/int128/kubelogin/pkg/adaptors/certpool/mock_certpool"
"github.com/int128/kubelogin/pkg/adaptors/kubeconfig"
"github.com/int128/kubelogin/pkg/adaptors/kubeconfig/mock_kubeconfig"
"github.com/int128/kubelogin/pkg/oidc"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
"github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/mock_authentication"
"golang.org/x/xerrors"
@@ -35,9 +34,6 @@ func TestStandalone_Do(t *testing.T) {
KubeconfigFilename: "/path/to/kubeconfig",
KubeconfigContext: "theContext",
KubeconfigUser: "theUser",
CACertFilename: "/path/to/cert1",
CACertData: "BASE64ENCODED1",
SkipTLSVerify: true,
GrantOptionSet: grantOptionSet,
}
currentAuthProvider := &kubeconfig.AuthProvider{
@@ -49,15 +45,6 @@ func TestStandalone_Do(t *testing.T) {
IDPCertificateAuthority: "/path/to/cert2",
IDPCertificateAuthorityData: "BASE64ENCODED2",
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockCertPool.EXPECT().
AddFile("/path/to/cert1")
mockCertPool.EXPECT().
AddFile("/path/to/cert2")
mockCertPool.EXPECT().
AddBase64Encoded("BASE64ENCODED1")
mockCertPool.EXPECT().
AddBase64Encoded("BASE64ENCODED2")
mockKubeconfig := mock_kubeconfig.NewMockInterface(ctrl)
mockKubeconfig.EXPECT().
GetCurrentAuthProvider("/path/to/kubeconfig", kubeconfig.ContextName("theContext"), kubeconfig.UserName("theUser")).
@@ -78,13 +65,15 @@ func TestStandalone_Do(t *testing.T) {
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
SkipTLSVerify: true,
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
GrantOptionSet: grantOptionSet,
TLSClientConfig: tlsclientconfig.Config{
CACertFilename: []string{"/path/to/cert2"},
CACertData: []string{"BASE64ENCODED2"},
},
}).
Return(&authentication.Output{
TokenSet: oidc.TokenSet{
@@ -95,7 +84,6 @@ func TestStandalone_Do(t *testing.T) {
u := Standalone{
Authentication: mockAuthentication,
Kubeconfig: mockKubeconfig,
NewCertPool: func() certpool.Interface { return mockCertPool },
Logger: logger.New(t),
}
if err := u.Do(ctx, in); err != nil {
@@ -116,7 +104,6 @@ func TestStandalone_Do(t *testing.T) {
ClientSecret: "YOUR_CLIENT_SECRET",
IDToken: issuedIDToken,
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockKubeconfig := mock_kubeconfig.NewMockInterface(ctrl)
mockKubeconfig.EXPECT().
GetCurrentAuthProvider("", kubeconfig.ContextName(""), kubeconfig.UserName("")).
@@ -128,7 +115,6 @@ func TestStandalone_Do(t *testing.T) {
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
},
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
@@ -143,7 +129,6 @@ func TestStandalone_Do(t *testing.T) {
u := Standalone{
Authentication: mockAuthentication,
Kubeconfig: mockKubeconfig,
NewCertPool: func() certpool.Interface { return mockCertPool },
Logger: logger.New(t),
}
if err := u.Do(ctx, in); err != nil {
@@ -183,7 +168,6 @@ func TestStandalone_Do(t *testing.T) {
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockKubeconfig := mock_kubeconfig.NewMockInterface(ctrl)
mockKubeconfig.EXPECT().
GetCurrentAuthProvider("", kubeconfig.ContextName(""), kubeconfig.UserName("")).
@@ -195,14 +179,12 @@ func TestStandalone_Do(t *testing.T) {
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
},
}).
Return(nil, xerrors.New("authentication error"))
u := Standalone{
Authentication: mockAuthentication,
Kubeconfig: mockKubeconfig,
NewCertPool: func() certpool.Interface { return mockCertPool },
Logger: logger.New(t),
}
if err := u.Do(ctx, in); err == nil {
@@ -222,7 +204,6 @@ func TestStandalone_Do(t *testing.T) {
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
mockCertPool := mock_certpool.NewMockInterface(ctrl)
mockKubeconfig := mock_kubeconfig.NewMockInterface(ctrl)
mockKubeconfig.EXPECT().
GetCurrentAuthProvider("", kubeconfig.ContextName(""), kubeconfig.UserName("")).
@@ -245,7 +226,6 @@ func TestStandalone_Do(t *testing.T) {
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
CertPool: mockCertPool,
},
}).
Return(&authentication.Output{
@@ -257,7 +237,6 @@ func TestStandalone_Do(t *testing.T) {
u := Standalone{
Authentication: mockAuthentication,
Kubeconfig: mockKubeconfig,
NewCertPool: func() certpool.Interface { return mockCertPool },
Logger: logger.New(t),
}
if err := u.Do(ctx, in); err == nil {