Lock token cache file before authentication (#1126)

* Lock token cache file in authentication

* Fix tests

* make generate

* Lock before FindByKey

* Fix test
This commit is contained in:
Hidetake Iwata
2024-09-21 14:54:32 +09:00
committed by GitHub
parent f0c3628f2a
commit 3d114bfeba
18 changed files with 290 additions and 515 deletions

View File

@@ -6,3 +6,6 @@ packages:
config: config:
all: true all: true
recursive: true recursive: true
io:
interfaces:
Closer:

2
go.mod
View File

@@ -3,9 +3,9 @@ module github.com/int128/kubelogin
go 1.22.2 go 1.22.2
require ( require (
github.com/alexflint/go-filemutex v1.3.0
github.com/chromedp/chromedp v0.10.0 github.com/chromedp/chromedp v0.10.0
github.com/coreos/go-oidc/v3 v3.11.0 github.com/coreos/go-oidc/v3 v3.11.0
github.com/gofrs/flock v0.12.1
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/google/wire v0.6.0 github.com/google/wire v0.6.0

7
go.sum
View File

@@ -33,8 +33,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chromedp/cdproto v0.0.0-20240801214329-3f85d328b335 h1:bATMoZLH2QGct1kzDxfmeBUQI/QhQvB0mBrOTct+YlQ= github.com/chromedp/cdproto v0.0.0-20240801214329-3f85d328b335 h1:bATMoZLH2QGct1kzDxfmeBUQI/QhQvB0mBrOTct+YlQ=
github.com/chromedp/cdproto v0.0.0-20240801214329-3f85d328b335/go.mod h1:GKljq0VrfU4D5yc+2qA6OVr8pmO/MBbPEWqWQ/oqGEs= github.com/chromedp/cdproto v0.0.0-20240801214329-3f85d328b335/go.mod h1:GKljq0VrfU4D5yc+2qA6OVr8pmO/MBbPEWqWQ/oqGEs=
@@ -81,6 +79,8 @@ github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
@@ -514,8 +514,9 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=

View File

@@ -3,6 +3,8 @@ package main
import ( import (
"context" "context"
"os" "os"
"os/signal"
"syscall"
"github.com/int128/kubelogin/pkg/di" "github.com/int128/kubelogin/pkg/di"
) )
@@ -10,5 +12,8 @@ import (
var version = "HEAD" var version = "HEAD"
func main() { func main() {
os.Exit(di.NewCmd().Run(context.Background(), os.Args, version)) ctx := context.Background()
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()
os.Exit(di.NewCmd().Run(ctx, os.Args, version))
} }

View File

@@ -1,142 +0,0 @@
// Code generated by mockery v2.46.0. DO NOT EDIT.
package mutex_mock
import (
context "context"
mutex "github.com/int128/kubelogin/pkg/infrastructure/mutex"
mock "github.com/stretchr/testify/mock"
)
// MockInterface is an autogenerated mock type for the Interface type
type MockInterface struct {
mock.Mock
}
type MockInterface_Expecter struct {
mock *mock.Mock
}
func (_m *MockInterface) EXPECT() *MockInterface_Expecter {
return &MockInterface_Expecter{mock: &_m.Mock}
}
// Acquire provides a mock function with given fields: ctx, name
func (_m *MockInterface) Acquire(ctx context.Context, name string) (*mutex.Lock, error) {
ret := _m.Called(ctx, name)
if len(ret) == 0 {
panic("no return value specified for Acquire")
}
var r0 *mutex.Lock
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*mutex.Lock, error)); ok {
return rf(ctx, name)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *mutex.Lock); ok {
r0 = rf(ctx, name)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*mutex.Lock)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, name)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockInterface_Acquire_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Acquire'
type MockInterface_Acquire_Call struct {
*mock.Call
}
// Acquire is a helper method to define mock.On call
// - ctx context.Context
// - name string
func (_e *MockInterface_Expecter) Acquire(ctx interface{}, name interface{}) *MockInterface_Acquire_Call {
return &MockInterface_Acquire_Call{Call: _e.mock.On("Acquire", ctx, name)}
}
func (_c *MockInterface_Acquire_Call) Run(run func(ctx context.Context, name string)) *MockInterface_Acquire_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockInterface_Acquire_Call) Return(_a0 *mutex.Lock, _a1 error) *MockInterface_Acquire_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockInterface_Acquire_Call) RunAndReturn(run func(context.Context, string) (*mutex.Lock, error)) *MockInterface_Acquire_Call {
_c.Call.Return(run)
return _c
}
// Release provides a mock function with given fields: lock
func (_m *MockInterface) Release(lock *mutex.Lock) error {
ret := _m.Called(lock)
if len(ret) == 0 {
panic("no return value specified for Release")
}
var r0 error
if rf, ok := ret.Get(0).(func(*mutex.Lock) error); ok {
r0 = rf(lock)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockInterface_Release_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Release'
type MockInterface_Release_Call struct {
*mock.Call
}
// Release is a helper method to define mock.On call
// - lock *mutex.Lock
func (_e *MockInterface_Expecter) Release(lock interface{}) *MockInterface_Release_Call {
return &MockInterface_Release_Call{Call: _e.mock.On("Release", lock)}
}
func (_c *MockInterface_Release_Call) Run(run func(lock *mutex.Lock)) *MockInterface_Release_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*mutex.Lock))
})
return _c
}
func (_c *MockInterface_Release_Call) Return(_a0 error) *MockInterface_Release_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockInterface_Release_Call) RunAndReturn(run func(*mutex.Lock) error) *MockInterface_Release_Call {
_c.Call.Return(run)
return _c
}
// NewMockInterface creates a new instance of MockInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockInterface(t interface {
mock.TestingT
Cleanup(func())
}) *MockInterface {
mock := &MockInterface{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@@ -3,6 +3,8 @@
package repository_mock package repository_mock
import ( import (
io "io"
oidc "github.com/int128/kubelogin/pkg/oidc" oidc "github.com/int128/kubelogin/pkg/oidc"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
@@ -81,6 +83,65 @@ func (_c *MockInterface_FindByKey_Call) RunAndReturn(run func(string, tokencache
return _c return _c
} }
// Lock provides a mock function with given fields: dir, key
func (_m *MockInterface) Lock(dir string, key tokencache.Key) (io.Closer, error) {
ret := _m.Called(dir, key)
if len(ret) == 0 {
panic("no return value specified for Lock")
}
var r0 io.Closer
var r1 error
if rf, ok := ret.Get(0).(func(string, tokencache.Key) (io.Closer, error)); ok {
return rf(dir, key)
}
if rf, ok := ret.Get(0).(func(string, tokencache.Key) io.Closer); ok {
r0 = rf(dir, key)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(io.Closer)
}
}
if rf, ok := ret.Get(1).(func(string, tokencache.Key) error); ok {
r1 = rf(dir, key)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockInterface_Lock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Lock'
type MockInterface_Lock_Call struct {
*mock.Call
}
// Lock is a helper method to define mock.On call
// - dir string
// - key tokencache.Key
func (_e *MockInterface_Expecter) Lock(dir interface{}, key interface{}) *MockInterface_Lock_Call {
return &MockInterface_Lock_Call{Call: _e.mock.On("Lock", dir, key)}
}
func (_c *MockInterface_Lock_Call) Run(run func(dir string, key tokencache.Key)) *MockInterface_Lock_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(tokencache.Key))
})
return _c
}
func (_c *MockInterface_Lock_Call) Return(_a0 io.Closer, _a1 error) *MockInterface_Lock_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockInterface_Lock_Call) RunAndReturn(run func(string, tokencache.Key) (io.Closer, error)) *MockInterface_Lock_Call {
_c.Call.Return(run)
return _c
}
// Save provides a mock function with given fields: dir, key, tokenSet // Save provides a mock function with given fields: dir, key, tokenSet
func (_m *MockInterface) Save(dir string, key tokencache.Key, tokenSet oidc.TokenSet) error { func (_m *MockInterface) Save(dir string, key tokencache.Key, tokenSet oidc.TokenSet) error {
ret := _m.Called(dir, key, tokenSet) ret := _m.Called(dir, key, tokenSet)

View File

@@ -0,0 +1,77 @@
// Code generated by mockery v2.46.0. DO NOT EDIT.
package io_mock
import mock "github.com/stretchr/testify/mock"
// MockCloser is an autogenerated mock type for the Closer type
type MockCloser struct {
mock.Mock
}
type MockCloser_Expecter struct {
mock *mock.Mock
}
func (_m *MockCloser) EXPECT() *MockCloser_Expecter {
return &MockCloser_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with given fields:
func (_m *MockCloser) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockCloser_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockCloser_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockCloser_Expecter) Close() *MockCloser_Close_Call {
return &MockCloser_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockCloser_Close_Call) Run(run func()) *MockCloser_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCloser_Close_Call) Return(_a0 error) *MockCloser_Close_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCloser_Close_Call) RunAndReturn(run func() error) *MockCloser_Close_Call {
_c.Call.Return(run)
return _c
}
// NewMockCloser creates a new instance of MockCloser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCloser(t interface {
mock.TestingT
Cleanup(func())
}) *MockCloser {
mock := &MockCloser{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/int128/kubelogin/pkg/infrastructure/browser" "github.com/int128/kubelogin/pkg/infrastructure/browser"
"github.com/int128/kubelogin/pkg/infrastructure/clock" "github.com/int128/kubelogin/pkg/infrastructure/clock"
"github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/infrastructure/logger"
"github.com/int128/kubelogin/pkg/infrastructure/mutex"
"github.com/int128/kubelogin/pkg/infrastructure/reader" "github.com/int128/kubelogin/pkg/infrastructure/reader"
"github.com/int128/kubelogin/pkg/infrastructure/stdio" "github.com/int128/kubelogin/pkg/infrastructure/stdio"
kubeconfigLoader "github.com/int128/kubelogin/pkg/kubeconfig/loader" kubeconfigLoader "github.com/int128/kubelogin/pkg/kubeconfig/loader"
@@ -57,7 +56,6 @@ func NewCmdForHeadless(clock.Interface, stdio.Stdin, stdio.Stdout, logger.Interf
client.Set, client.Set,
loader.Set, loader.Set,
writer.Set, writer.Set,
mutex.Set,
) )
return nil return nil
} }

View File

@@ -12,7 +12,6 @@ import (
"github.com/int128/kubelogin/pkg/infrastructure/browser" "github.com/int128/kubelogin/pkg/infrastructure/browser"
"github.com/int128/kubelogin/pkg/infrastructure/clock" "github.com/int128/kubelogin/pkg/infrastructure/clock"
"github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/infrastructure/logger"
"github.com/int128/kubelogin/pkg/infrastructure/mutex"
"github.com/int128/kubelogin/pkg/infrastructure/reader" "github.com/int128/kubelogin/pkg/infrastructure/reader"
"github.com/int128/kubelogin/pkg/infrastructure/stdio" "github.com/int128/kubelogin/pkg/infrastructure/stdio"
loader2 "github.com/int128/kubelogin/pkg/kubeconfig/loader" loader2 "github.com/int128/kubelogin/pkg/kubeconfig/loader"
@@ -78,7 +77,6 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
authenticationAuthentication := &authentication.Authentication{ authenticationAuthentication := &authentication.Authentication{
ClientFactory: factory, ClientFactory: factory,
Logger: loggerInterface, Logger: loggerInterface,
Clock: clockInterface,
AuthCodeBrowser: authcodeBrowser, AuthCodeBrowser: authcodeBrowser,
AuthCodeKeyboard: keyboard, AuthCodeKeyboard: keyboard,
ROPC: ropcROPC, ROPC: ropcROPC,
@@ -91,6 +89,7 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
KubeconfigLoader: loader3, KubeconfigLoader: loader3,
KubeconfigWriter: writerWriter, KubeconfigWriter: writerWriter,
Logger: loggerInterface, Logger: loggerInterface,
Clock: clockInterface,
} }
root := &cmd.Root{ root := &cmd.Root{
Standalone: standaloneStandalone, Standalone: standaloneStandalone,
@@ -100,15 +99,12 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout
writer3 := &writer2.Writer{ writer3 := &writer2.Writer{
Stdout: stdout, Stdout: stdout,
} }
mutexMutex := &mutex.Mutex{
Logger: loggerInterface,
}
getToken := &credentialplugin.GetToken{ getToken := &credentialplugin.GetToken{
Authentication: authenticationAuthentication, Authentication: authenticationAuthentication,
TokenCacheRepository: repositoryRepository, TokenCacheRepository: repositoryRepository,
Writer: writer3, Writer: writer3,
Mutex: mutexMutex,
Logger: loggerInterface, Logger: loggerInterface,
Clock: clockInterface,
} }
cmdGetToken := &cmd.GetToken{ cmdGetToken := &cmd.GetToken{
GetToken: getToken, GetToken: getToken,

View File

@@ -1,87 +0,0 @@
package mutex
import (
"context"
"fmt"
"os"
"path"
"github.com/alexflint/go-filemutex"
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/infrastructure/logger"
)
var Set = wire.NewSet(
wire.Struct(new(Mutex), "*"),
wire.Bind(new(Interface), new(*Mutex)),
)
type Interface interface {
Acquire(ctx context.Context, name string) (*Lock, error)
Release(lock *Lock) error
}
// Lock holds the lock data.
type Lock struct {
Data interface{}
Name string
}
type Mutex struct {
Logger logger.Interface
}
// internalAcquire wait for acquisition of the lock
func internalAcquire(fm *filemutex.FileMutex) chan error {
result := make(chan error)
go func() {
if err := fm.Lock(); err != nil {
result <- err
}
close(result)
}()
return result
}
// internalRelease disposes of resources associated with a lock
func internalRelease(fm *filemutex.FileMutex, lfn string, log logger.Interface) error {
err := fm.Close()
if err != nil {
log.V(1).Infof("Error closing lock file %s: %s", lfn, err)
}
return err
}
// LockFileName get the lock file name from the lock name.
func LockFileName(name string) string {
return path.Join(os.TempDir(), fmt.Sprintf(".kubelogin.%s.lock", name))
}
// Acquire acquire a lock for the specified name. The context could be used to set a timeout.
func (m *Mutex) Acquire(ctx context.Context, name string) (*Lock, error) {
lfn := LockFileName(name)
fm, err := filemutex.New(lfn)
if err != nil {
return nil, fmt.Errorf("error creating mutex file %s: %w", lfn, err)
}
lockChan := internalAcquire(fm)
select {
case <-ctx.Done():
_ = internalRelease(fm, lfn, m.Logger)
return nil, ctx.Err()
case err := <-lockChan:
if err != nil {
_ = internalRelease(fm, lfn, m.Logger)
return nil, fmt.Errorf("error acquiring lock on file %s: %w", lfn, err)
}
return &Lock{Data: fm, Name: name}, nil
}
}
// Release release the specified lock
func (m *Mutex) Release(lock *Lock) error {
fm := lock.Data.(*filemutex.FileMutex)
lfn := LockFileName(lock.Name)
return internalRelease(fm, lfn, m.Logger)
}

View File

@@ -1,65 +0,0 @@
package mutex
import (
"context"
"fmt"
"math/rand"
"sync"
"testing"
"time"
"github.com/int128/kubelogin/pkg/infrastructure/logger"
)
func TestMutex(t *testing.T) {
t.Run("Test successful parallel acquisition with no reentry allowed", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nbConcurrency := 20
wg := sync.WaitGroup{}
events := make(chan int, nbConcurrency*2)
errors := make(chan error, nbConcurrency)
doLockUnlock := func() {
defer wg.Done()
m := Mutex{
Logger: logger.New(),
}
if mutex, err := m.Acquire(ctx, "test"); err == nil {
events <- 1
var dur = time.Duration(rand.Intn(5000))
time.Sleep(dur * time.Microsecond)
events <- -1
if err := m.Release(mutex); err != nil {
errors <- fmt.Errorf("Release error: %w", err)
}
} else {
errors <- fmt.Errorf("Acquire error: %w", err)
}
}
for i := 0; i < nbConcurrency; i++ {
wg.Add(1)
go doLockUnlock()
}
wg.Wait()
close(events)
close(errors)
countConcurrent := 0
for delta := range events {
countConcurrent += delta
if countConcurrent > 1 {
t.Errorf("The mutex did not prevented reentry: %d", countConcurrent)
}
}
for anError := range errors {
t.Errorf("The gorouting returned an error: %s", anError)
}
})
}

View File

@@ -6,9 +6,11 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"github.com/gofrs/flock"
"github.com/google/wire" "github.com/google/wire"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tokencache" "github.com/int128/kubelogin/pkg/tokencache"
@@ -23,6 +25,7 @@ var Set = wire.NewSet(
type Interface interface { type Interface interface {
FindByKey(dir string, key tokencache.Key) (*oidc.TokenSet, error) FindByKey(dir string, key tokencache.Key) (*oidc.TokenSet, error)
Save(dir string, key tokencache.Key, tokenSet oidc.TokenSet) error Save(dir string, key tokencache.Key, tokenSet oidc.TokenSet) error
Lock(dir string, key tokencache.Key) (io.Closer, error)
} }
type entity struct { type entity struct {
@@ -80,6 +83,22 @@ func (r *Repository) Save(dir string, key tokencache.Key, tokenSet oidc.TokenSet
return nil return nil
} }
func (r *Repository) Lock(dir string, key tokencache.Key) (io.Closer, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("could not create directory %s: %w", dir, err)
}
filename, err := computeFilename(key)
if err != nil {
return nil, fmt.Errorf("could not compute the key: %w", err)
}
p := filepath.Join(dir, filename)
f := flock.New(p)
if err := f.Lock(); err != nil {
return nil, fmt.Errorf("could not lock the cache file %s: %w", p, err)
}
return f, nil
}
func computeFilename(key tokencache.Key) (string, error) { func computeFilename(key tokencache.Key) (string, error) {
s := sha256.New() s := sha256.New()
e := gob.NewEncoder(s) e := gob.NewEncoder(s)

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"github.com/google/wire" "github.com/google/wire"
"github.com/int128/kubelogin/pkg/infrastructure/clock"
"github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/infrastructure/logger"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/oidc/client" "github.com/int128/kubelogin/pkg/oidc/client"
@@ -48,8 +47,7 @@ type GrantOptionSet struct {
// Output represents an output DTO of the Authentication use-case. // Output represents an output DTO of the Authentication use-case.
type Output struct { type Output struct {
AlreadyHasValidIDToken bool TokenSet oidc.TokenSet
TokenSet oidc.TokenSet
} }
// Authentication provides the internal use-case of authentication. // Authentication provides the internal use-case of authentication.
@@ -67,7 +65,6 @@ type Output struct {
type Authentication struct { type Authentication struct {
ClientFactory client.FactoryInterface ClientFactory client.FactoryInterface
Logger logger.Interface Logger logger.Interface
Clock clock.Interface
AuthCodeBrowser *authcode.Browser AuthCodeBrowser *authcode.Browser
AuthCodeKeyboard *authcode.Keyboard AuthCodeKeyboard *authcode.Keyboard
ROPC *ropc.ROPC ROPC *ropc.ROPC
@@ -75,29 +72,6 @@ type Authentication struct {
} }
func (u *Authentication) Do(ctx context.Context, in Input) (*Output, error) { func (u *Authentication) Do(ctx context.Context, in Input) (*Output, error) {
if in.CachedTokenSet != nil {
if in.ForceRefresh {
u.Logger.V(1).Infof("forcing refresh of the existing token")
} else {
u.Logger.V(1).Infof("checking expiration of the existing token")
// Skip verification of the token to reduce time of a discovery request.
// Here it trusts the signature and claims and checks only expiration,
// because the token has been verified before caching.
claims, err := in.CachedTokenSet.DecodeWithoutVerify()
if err != nil {
return nil, fmt.Errorf("invalid token cache (you may need to remove): %w", err)
}
if !claims.IsExpired(u.Clock) {
u.Logger.V(1).Infof("you already have a valid token until %s", claims.Expiry)
return &Output{
AlreadyHasValidIDToken: true,
TokenSet: *in.CachedTokenSet,
}, nil
}
u.Logger.V(1).Infof("you have an expired token at %s", claims.Expiry)
}
}
u.Logger.V(1).Infof("initializing an OpenID Connect client") u.Logger.V(1).Infof("initializing an OpenID Connect client")
oidcClient, err := u.ClientFactory.New(ctx, in.Provider, in.TLSClientConfig, in.UseAccessToken) oidcClient, err := u.ClientFactory.New(ctx, in.Provider, in.TLSClientConfig, in.UseAccessToken)
if err != nil { if err != nil {

View File

@@ -11,7 +11,6 @@ import (
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/oidc/client_mock" "github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/oidc/client_mock"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/oidc/client" "github.com/int128/kubelogin/pkg/oidc/client"
"github.com/int128/kubelogin/pkg/testing/clock"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt" testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
testingLogger "github.com/int128/kubelogin/pkg/testing/logger" testingLogger "github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig" "github.com/int128/kubelogin/pkg/tlsclientconfig"
@@ -37,35 +36,6 @@ func TestAuthentication_Do(t *testing.T) {
claims.ExpiresAt = jwt.NewNumericDate(expiryTime) claims.ExpiresAt = jwt.NewNumericDate(expiryTime)
}) })
t.Run("HasValidIDToken", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
in := Input{
Provider: dummyProvider,
TLSClientConfig: dummyTLSClientConfig,
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
},
}
u := Authentication{
Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
}
got, err := u.Do(ctx, in)
if err != nil {
t.Errorf("Do returned error: %+v", err)
}
want := &Output{
AlreadyHasValidIDToken: true,
TokenSet: oidc.TokenSet{
IDToken: issuedIDToken,
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
t.Run("HasValidRefreshToken", func(t *testing.T) { t.Run("HasValidRefreshToken", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout) ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel() defer cancel()
@@ -91,7 +61,6 @@ func TestAuthentication_Do(t *testing.T) {
u := Authentication{ u := Authentication{
ClientFactory: mockClientFactory, ClientFactory: mockClientFactory,
Logger: testingLogger.New(t), Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
} }
got, err := u.Do(ctx, in) got, err := u.Do(ctx, in)
if err != nil { if err != nil {
@@ -149,7 +118,6 @@ func TestAuthentication_Do(t *testing.T) {
u := Authentication{ u := Authentication{
ClientFactory: mockClientFactory, ClientFactory: mockClientFactory,
Logger: testingLogger.New(t), Logger: testingLogger.New(t),
Clock: clock.Fake(expiryTime.Add(+time.Hour)),
AuthCodeBrowser: &authcode.Browser{ AuthCodeBrowser: &authcode.Browser{
Logger: testingLogger.New(t), Logger: testingLogger.New(t),
}, },

View File

@@ -6,20 +6,18 @@ package credentialplugin
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"strings" "strings"
"github.com/google/wire" "github.com/google/wire"
"github.com/int128/kubelogin/pkg/credentialplugin" "github.com/int128/kubelogin/pkg/credentialplugin"
"github.com/int128/kubelogin/pkg/credentialplugin/writer" "github.com/int128/kubelogin/pkg/credentialplugin/writer"
"github.com/int128/kubelogin/pkg/infrastructure/clock"
"github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/infrastructure/logger"
"github.com/int128/kubelogin/pkg/infrastructure/mutex"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/tlsclientconfig" "github.com/int128/kubelogin/pkg/tlsclientconfig"
"github.com/int128/kubelogin/pkg/tokencache" "github.com/int128/kubelogin/pkg/tokencache"
"github.com/int128/kubelogin/pkg/tokencache/repository" "github.com/int128/kubelogin/pkg/tokencache/repository"
"github.com/int128/kubelogin/pkg/usecases/authentication" "github.com/int128/kubelogin/pkg/usecases/authentication"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
) )
var Set = wire.NewSet( var Set = wire.NewSet(
@@ -45,30 +43,13 @@ type GetToken struct {
Authentication authentication.Interface Authentication authentication.Interface
TokenCacheRepository repository.Interface TokenCacheRepository repository.Interface
Writer writer.Interface Writer writer.Interface
Mutex mutex.Interface
Logger logger.Interface Logger logger.Interface
Clock clock.Interface
} }
func (u *GetToken) Do(ctx context.Context, in Input) error { func (u *GetToken) Do(ctx context.Context, in Input) error {
u.Logger.V(1).Infof("WARNING: log may contain your secrets such as token or password") u.Logger.V(1).Infof("WARNING: log may contain your secrets such as token or password")
// Prevent multiple concurrent port binding using a file mutex.
// See https://github.com/int128/kubelogin/issues/389
bindPorts := extractBindAddressPorts(in.GrantOptionSet.AuthCodeBrowserOption)
if bindPorts != nil {
key := fmt.Sprintf("get-token-%s", strings.Join(bindPorts, "-"))
u.Logger.V(1).Infof("acquiring a lock %s", key)
lock, err := u.Mutex.Acquire(ctx, key)
if err != nil {
return fmt.Errorf("could not acquire a lock: %w", err)
}
defer func() {
if err := u.Mutex.Release(lock); err != nil {
u.Logger.V(1).Infof("could not release the lock: %s", err)
}
}()
}
u.Logger.V(1).Infof("finding a token from cache directory %s", in.TokenCacheDir) u.Logger.V(1).Infof("finding a token from cache directory %s", in.TokenCacheDir)
tokenCacheKey := tokencache.Key{ tokenCacheKey := tokencache.Key{
IssuerURL: in.Provider.IssuerURL, IssuerURL: in.Provider.IssuerURL,
@@ -82,10 +63,49 @@ func (u *GetToken) Do(ctx context.Context, in Input) error {
if in.GrantOptionSet.ROPCOption != nil { if in.GrantOptionSet.ROPCOption != nil {
tokenCacheKey.Username = in.GrantOptionSet.ROPCOption.Username tokenCacheKey.Username = in.GrantOptionSet.ROPCOption.Username
} }
u.Logger.V(1).Infof("acquiring the lock of token cache")
lock, err := u.TokenCacheRepository.Lock(in.TokenCacheDir, tokenCacheKey)
if err != nil {
return fmt.Errorf("could not lock the token cache: %w", err)
}
defer func() {
u.Logger.V(1).Infof("releasing the lock of token cache")
if err := lock.Close(); err != nil {
u.Logger.Printf("could not unlock the token cache: %s", err)
}
}()
cachedTokenSet, err := u.TokenCacheRepository.FindByKey(in.TokenCacheDir, tokenCacheKey) cachedTokenSet, err := u.TokenCacheRepository.FindByKey(in.TokenCacheDir, tokenCacheKey)
if err != nil { if err != nil {
u.Logger.V(1).Infof("could not find a token cache: %s", err) u.Logger.V(1).Infof("could not find a token cache: %s", err)
} }
if cachedTokenSet != nil {
if in.ForceRefresh {
u.Logger.V(1).Infof("forcing refresh of the existing token")
} else {
u.Logger.V(1).Infof("checking expiration of the existing token")
// Skip verification of the token to reduce time of a discovery request.
// Here it trusts the signature and claims and checks only expiration,
// because the token has been verified before caching.
claims, err := cachedTokenSet.DecodeWithoutVerify()
if err != nil {
return fmt.Errorf("invalid token cache (you may need to remove): %w", err)
}
if !claims.IsExpired(u.Clock) {
u.Logger.V(1).Infof("you already have a valid token until %s", claims.Expiry)
out := credentialplugin.Output{
Token: cachedTokenSet.IDToken,
Expiry: claims.Expiry,
}
if err := u.Writer.Write(out); err != nil {
return fmt.Errorf("could not write the token to client-go: %w", err)
}
return nil
}
u.Logger.V(1).Infof("you have an expired token at %s", claims.Expiry)
}
}
authenticationInput := authentication.Input{ authenticationInput := authentication.Input{
Provider: in.Provider, Provider: in.Provider,
@@ -104,14 +124,9 @@ func (u *GetToken) Do(ctx context.Context, in Input) error {
return fmt.Errorf("you got an invalid token: %w", err) return fmt.Errorf("you got an invalid token: %w", err)
} }
u.Logger.V(1).Infof("you got a token: %s", idTokenClaims.Pretty) u.Logger.V(1).Infof("you got a token: %s", idTokenClaims.Pretty)
u.Logger.V(1).Infof("you got a valid token until %s", idTokenClaims.Expiry)
if authenticationOutput.AlreadyHasValidIDToken { if err := u.TokenCacheRepository.Save(in.TokenCacheDir, tokenCacheKey, authenticationOutput.TokenSet); err != nil {
u.Logger.V(1).Infof("you already have a valid token until %s", idTokenClaims.Expiry) return fmt.Errorf("could not write the token cache: %w", err)
} else {
u.Logger.V(1).Infof("you got a valid token until %s", idTokenClaims.Expiry)
if err := u.TokenCacheRepository.Save(in.TokenCacheDir, tokenCacheKey, authenticationOutput.TokenSet); err != nil {
return fmt.Errorf("could not write the token cache: %w", err)
}
} }
u.Logger.V(1).Infof("writing the token to client-go") u.Logger.V(1).Infof("writing the token to client-go")
out := credentialplugin.Output{ out := credentialplugin.Output{
@@ -123,21 +138,3 @@ func (u *GetToken) Do(ctx context.Context, in Input) error {
} }
return nil return nil
} }
func extractBindAddressPorts(o *authcode.BrowserOption) []string {
if o == nil {
return nil
}
var ports []string
for _, addr := range o.BindAddress {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil // invalid address
}
if port == "0" {
return nil // any port
}
ports = append(ports, port)
}
return ports
}

View File

@@ -8,11 +8,11 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/credentialplugin/writer_mock" "github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/credentialplugin/writer_mock"
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/infrastructure/mutex_mock"
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/tokencache/repository_mock" "github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/tokencache/repository_mock"
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/usecases/authentication_mock" "github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/usecases/authentication_mock"
"github.com/int128/kubelogin/mocks/io_mock"
"github.com/int128/kubelogin/pkg/credentialplugin" "github.com/int128/kubelogin/pkg/credentialplugin"
"github.com/int128/kubelogin/pkg/infrastructure/mutex" "github.com/int128/kubelogin/pkg/testing/clock"
"github.com/int128/kubelogin/pkg/usecases/authentication/authcode" "github.com/int128/kubelogin/pkg/usecases/authentication/authcode"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
@@ -29,11 +29,11 @@ func TestGetToken_Do(t *testing.T) {
ClientID: "YOUR_CLIENT_ID", ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET", ClientSecret: "YOUR_CLIENT_SECRET",
} }
issuedIDTokenExpiration := time.Now().Add(1 * time.Hour).Round(time.Second) expiryTime := time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC).Local()
issuedIDToken := testingJWT.EncodeF(t, func(claims *testingJWT.Claims) { issuedIDToken := testingJWT.EncodeF(t, func(claims *testingJWT.Claims) {
claims.Issuer = "https://accounts.google.com" claims.Issuer = "https://accounts.google.com"
claims.Subject = "YOUR_SUBJECT" claims.Subject = "YOUR_SUBJECT"
claims.ExpiresAt = jwt.NewNumericDate(issuedIDTokenExpiration) claims.ExpiresAt = jwt.NewNumericDate(expiryTime)
}) })
issuedTokenSet := oidc.TokenSet{ issuedTokenSet := oidc.TokenSet{
IDToken: issuedIDToken, IDToken: issuedIDToken,
@@ -41,7 +41,7 @@ func TestGetToken_Do(t *testing.T) {
} }
issuedOutput := credentialplugin.Output{ issuedOutput := credentialplugin.Output{
Token: issuedIDToken, Token: issuedIDToken,
Expiry: issuedIDTokenExpiration, Expiry: expiryTime,
} }
grantOptionSet := authentication.GrantOptionSet{ grantOptionSet := authentication.GrantOptionSet{
AuthCodeBrowserOption: &authcode.BrowserOption{ AuthCodeBrowserOption: &authcode.BrowserOption{
@@ -68,7 +68,14 @@ func TestGetToken_Do(t *testing.T) {
GrantOptionSet: grantOptionSet, GrantOptionSet: grantOptionSet,
}). }).
Return(&authentication.Output{TokenSet: issuedTokenSet}, nil) Return(&authentication.Output{TokenSet: issuedTokenSet}, nil)
mockCloser := io_mock.NewMockCloser(t)
mockCloser.EXPECT().
Close().
Return(nil)
mockRepository := repository_mock.NewMockInterface(t) mockRepository := repository_mock.NewMockInterface(t)
mockRepository.EXPECT().
Lock("/path/to/token-cache", tokenCacheKey).
Return(mockCloser, nil)
mockRepository.EXPECT(). mockRepository.EXPECT().
FindByKey("/path/to/token-cache", tokenCacheKey). FindByKey("/path/to/token-cache", tokenCacheKey).
Return(nil, errors.New("file not found")) Return(nil, errors.New("file not found"))
@@ -83,63 +90,8 @@ func TestGetToken_Do(t *testing.T) {
Authentication: mockAuthentication, Authentication: mockAuthentication,
TokenCacheRepository: mockRepository, TokenCacheRepository: mockRepository,
Writer: mockWriter, Writer: mockWriter,
Mutex: mutex_mock.NewMockInterface(t),
Logger: logger.New(t),
}
if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err)
}
})
t.Run("NeedBindPortMutex", func(t *testing.T) {
grantOptionSet := authentication.GrantOptionSet{
AuthCodeBrowserOption: &authcode.BrowserOption{
BindAddress: []string{"127.0.0.1:8080"},
},
}
tokenCacheKey := tokencache.Key{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
ctx := context.TODO()
in := Input{
Provider: dummyProvider,
TokenCacheDir: "/path/to/token-cache",
GrantOptionSet: grantOptionSet,
}
mockAuthentication := authentication_mock.NewMockInterface(t)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: dummyProvider,
GrantOptionSet: grantOptionSet,
}).
Return(&authentication.Output{TokenSet: issuedTokenSet}, nil)
mockRepository := repository_mock.NewMockInterface(t)
mockRepository.EXPECT().
FindByKey("/path/to/token-cache", tokenCacheKey).
Return(nil, errors.New("file not found"))
mockRepository.EXPECT().
Save("/path/to/token-cache", tokenCacheKey, issuedTokenSet).
Return(nil)
mockWriter := writer_mock.NewMockInterface(t)
mockWriter.EXPECT().
Write(issuedOutput).
Return(nil)
mockMutex := mutex_mock.NewMockInterface(t)
mockMutex.EXPECT().
Acquire(ctx, "get-token-8080").
Return(&mutex.Lock{Data: "testData"}, nil)
mockMutex.EXPECT().
Release(&mutex.Lock{Data: "testData"}).
Return(nil)
u := GetToken{
Authentication: mockAuthentication,
TokenCacheRepository: mockRepository,
Writer: mockWriter,
Mutex: mockMutex,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err != nil { if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err) t.Errorf("Do returned error: %+v", err)
@@ -170,7 +122,14 @@ func TestGetToken_Do(t *testing.T) {
GrantOptionSet: grantOptionSet, GrantOptionSet: grantOptionSet,
}). }).
Return(&authentication.Output{TokenSet: issuedTokenSet}, nil) Return(&authentication.Output{TokenSet: issuedTokenSet}, nil)
mockCloser := io_mock.NewMockCloser(t)
mockCloser.EXPECT().
Close().
Return(nil)
mockRepository := repository_mock.NewMockInterface(t) mockRepository := repository_mock.NewMockInterface(t)
mockRepository.EXPECT().
Lock("/path/to/token-cache", tokenCacheKey).
Return(mockCloser, nil)
mockRepository.EXPECT(). mockRepository.EXPECT().
FindByKey("/path/to/token-cache", tokenCacheKey). FindByKey("/path/to/token-cache", tokenCacheKey).
Return(nil, errors.New("file not found")) Return(nil, errors.New("file not found"))
@@ -185,8 +144,8 @@ func TestGetToken_Do(t *testing.T) {
Authentication: mockAuthentication, Authentication: mockAuthentication,
TokenCacheRepository: mockRepository, TokenCacheRepository: mockRepository,
Writer: mockWriter, Writer: mockWriter,
Mutex: mutex_mock.NewMockInterface(t),
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err != nil { if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err) t.Errorf("Do returned error: %+v", err)
@@ -194,24 +153,26 @@ func TestGetToken_Do(t *testing.T) {
}) })
t.Run("HasValidIDToken", func(t *testing.T) { t.Run("HasValidIDToken", func(t *testing.T) {
tokenCacheKey := tokencache.Key{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
ctx := context.TODO() ctx := context.TODO()
in := Input{ in := Input{
Provider: dummyProvider, Provider: dummyProvider,
TokenCacheDir: "/path/to/token-cache", TokenCacheDir: "/path/to/token-cache",
GrantOptionSet: grantOptionSet, GrantOptionSet: grantOptionSet,
} }
mockAuthentication := authentication_mock.NewMockInterface(t) mockCloser := io_mock.NewMockCloser(t)
mockAuthentication.EXPECT(). mockCloser.EXPECT().
Do(ctx, authentication.Input{ Close().
Provider: dummyProvider, Return(nil)
CachedTokenSet: &issuedTokenSet,
GrantOptionSet: grantOptionSet,
}).
Return(&authentication.Output{
AlreadyHasValidIDToken: true,
TokenSet: issuedTokenSet,
}, nil)
mockRepository := repository_mock.NewMockInterface(t) mockRepository := repository_mock.NewMockInterface(t)
mockRepository.EXPECT().
Lock("/path/to/token-cache", tokenCacheKey).
Return(mockCloser, nil)
mockRepository.EXPECT(). mockRepository.EXPECT().
FindByKey("/path/to/token-cache", tokencache.Key{ FindByKey("/path/to/token-cache", tokencache.Key{
IssuerURL: "https://accounts.google.com", IssuerURL: "https://accounts.google.com",
@@ -224,11 +185,11 @@ func TestGetToken_Do(t *testing.T) {
Write(issuedOutput). Write(issuedOutput).
Return(nil) Return(nil)
u := GetToken{ u := GetToken{
Authentication: mockAuthentication, Authentication: authentication_mock.NewMockInterface(t),
TokenCacheRepository: mockRepository, TokenCacheRepository: mockRepository,
Writer: mockWriter, Writer: mockWriter,
Mutex: mutex_mock.NewMockInterface(t),
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err != nil { if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err) t.Errorf("Do returned error: %+v", err)
@@ -236,6 +197,11 @@ func TestGetToken_Do(t *testing.T) {
}) })
t.Run("AuthenticationError", func(t *testing.T) { t.Run("AuthenticationError", func(t *testing.T) {
tokenCacheKey := tokencache.Key{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
}
ctx := context.TODO() ctx := context.TODO()
in := Input{ in := Input{
Provider: dummyProvider, Provider: dummyProvider,
@@ -249,7 +215,14 @@ func TestGetToken_Do(t *testing.T) {
GrantOptionSet: grantOptionSet, GrantOptionSet: grantOptionSet,
}). }).
Return(nil, errors.New("authentication error")) Return(nil, errors.New("authentication error"))
mockCloser := io_mock.NewMockCloser(t)
mockCloser.EXPECT().
Close().
Return(nil)
mockRepository := repository_mock.NewMockInterface(t) mockRepository := repository_mock.NewMockInterface(t)
mockRepository.EXPECT().
Lock("/path/to/token-cache", tokenCacheKey).
Return(mockCloser, nil)
mockRepository.EXPECT(). mockRepository.EXPECT().
FindByKey("/path/to/token-cache", tokencache.Key{ FindByKey("/path/to/token-cache", tokencache.Key{
IssuerURL: "https://accounts.google.com", IssuerURL: "https://accounts.google.com",
@@ -261,8 +234,8 @@ func TestGetToken_Do(t *testing.T) {
Authentication: mockAuthentication, Authentication: mockAuthentication,
TokenCacheRepository: mockRepository, TokenCacheRepository: mockRepository,
Writer: writer_mock.NewMockInterface(t), Writer: writer_mock.NewMockInterface(t),
Mutex: mutex_mock.NewMockInterface(t),
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err == nil { if err := u.Do(ctx, in); err == nil {
t.Errorf("err wants non-nil but nil") t.Errorf("err wants non-nil but nil")

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/google/wire" "github.com/google/wire"
"github.com/int128/kubelogin/pkg/infrastructure/clock"
"github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/infrastructure/logger"
"github.com/int128/kubelogin/pkg/kubeconfig" "github.com/int128/kubelogin/pkg/kubeconfig"
"github.com/int128/kubelogin/pkg/kubeconfig/loader" "github.com/int128/kubelogin/pkg/kubeconfig/loader"
@@ -52,6 +53,7 @@ type Standalone struct {
KubeconfigLoader loader.Interface KubeconfigLoader loader.Interface
KubeconfigWriter writer.Interface KubeconfigWriter writer.Interface
Logger logger.Interface Logger logger.Interface
Clock clock.Interface
} }
func (u *Standalone) Do(ctx context.Context, in Input) error { func (u *Standalone) Do(ctx context.Context, in Input) error {
@@ -78,6 +80,18 @@ func (u *Standalone) Do(ctx context.Context, in Input) error {
IDToken: authProvider.IDToken, IDToken: authProvider.IDToken,
RefreshToken: authProvider.RefreshToken, RefreshToken: authProvider.RefreshToken,
} }
u.Logger.V(1).Infof("checking expiration of the existing token")
// Skip verification of the token to reduce time of a discovery request.
// Here it trusts the signature and claims and checks only expiration,
// because the token has been verified before caching.
claims, err := cachedTokenSet.DecodeWithoutVerify()
if err != nil {
return fmt.Errorf("invalid token cache (you may need to remove): %w", err)
}
if !claims.IsExpired(u.Clock) {
u.Logger.V(1).Infof("you already have a valid token until %s", claims.Expiry)
return nil
}
} }
authenticationInput := authentication.Input{ authenticationInput := authentication.Input{
@@ -101,11 +115,6 @@ func (u *Standalone) Do(ctx context.Context, in Input) error {
return fmt.Errorf("you got an invalid token: %w", err) return fmt.Errorf("you got an invalid token: %w", err)
} }
u.Logger.V(1).Infof("you got a token: %s", idTokenClaims.Pretty) u.Logger.V(1).Infof("you got a token: %s", idTokenClaims.Pretty)
if authenticationOutput.AlreadyHasValidIDToken {
u.Logger.Printf("You already have a valid token until %s", idTokenClaims.Expiry)
return nil
}
u.Logger.Printf("You got a valid token until %s", idTokenClaims.Expiry) u.Logger.Printf("You got a valid token until %s", idTokenClaims.Expiry)
authProvider.IDToken = authenticationOutput.TokenSet.IDToken authProvider.IDToken = authenticationOutput.TokenSet.IDToken
authProvider.RefreshToken = authenticationOutput.TokenSet.RefreshToken authProvider.RefreshToken = authenticationOutput.TokenSet.RefreshToken

View File

@@ -12,6 +12,7 @@ import (
"github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/usecases/authentication_mock" "github.com/int128/kubelogin/mocks/github.com/int128/kubelogin/pkg/usecases/authentication_mock"
"github.com/int128/kubelogin/pkg/kubeconfig" "github.com/int128/kubelogin/pkg/kubeconfig"
"github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/oidc"
"github.com/int128/kubelogin/pkg/testing/clock"
testingJWT "github.com/int128/kubelogin/pkg/testing/jwt" testingJWT "github.com/int128/kubelogin/pkg/testing/jwt"
"github.com/int128/kubelogin/pkg/testing/logger" "github.com/int128/kubelogin/pkg/testing/logger"
"github.com/int128/kubelogin/pkg/tlsclientconfig" "github.com/int128/kubelogin/pkg/tlsclientconfig"
@@ -19,11 +20,11 @@ import (
) )
func TestStandalone_Do(t *testing.T) { func TestStandalone_Do(t *testing.T) {
issuedIDTokenExpiration := time.Now().Add(1 * time.Hour).Round(time.Second) expiryTime := time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)
issuedIDToken := testingJWT.EncodeF(t, func(claims *testingJWT.Claims) { issuedIDToken := testingJWT.EncodeF(t, func(claims *testingJWT.Claims) {
claims.Issuer = "https://accounts.google.com" claims.Issuer = "https://accounts.google.com"
claims.Subject = "YOUR_SUBJECT" claims.Subject = "YOUR_SUBJECT"
claims.ExpiresAt = jwt.NewNumericDate(issuedIDTokenExpiration) claims.ExpiresAt = jwt.NewNumericDate(expiryTime)
}) })
t.Run("FullOptions", func(t *testing.T) { t.Run("FullOptions", func(t *testing.T) {
@@ -87,6 +88,7 @@ func TestStandalone_Do(t *testing.T) {
KubeconfigLoader: mockLoader, KubeconfigLoader: mockLoader,
KubeconfigWriter: mockWriter, KubeconfigWriter: mockWriter,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err != nil { if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err) t.Errorf("Do returned error: %+v", err)
@@ -108,28 +110,11 @@ func TestStandalone_Do(t *testing.T) {
mockLoader.EXPECT(). mockLoader.EXPECT().
GetCurrentAuthProvider("", kubeconfig.ContextName(""), kubeconfig.UserName("")). GetCurrentAuthProvider("", kubeconfig.ContextName(""), kubeconfig.UserName("")).
Return(currentAuthProvider, nil) Return(currentAuthProvider, nil)
mockAuthentication := authentication_mock.NewMockInterface(t)
mockAuthentication.EXPECT().
Do(ctx, authentication.Input{
Provider: oidc.Provider{
IssuerURL: "https://accounts.google.com",
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
},
CachedTokenSet: &oidc.TokenSet{
IDToken: issuedIDToken,
},
}).
Return(&authentication.Output{
AlreadyHasValidIDToken: true,
TokenSet: oidc.TokenSet{
IDToken: issuedIDToken,
},
}, nil)
u := Standalone{ u := Standalone{
Authentication: mockAuthentication, Authentication: authentication_mock.NewMockInterface(t),
KubeconfigLoader: mockLoader, KubeconfigLoader: mockLoader,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err != nil { if err := u.Do(ctx, in); err != nil {
t.Errorf("Do returned error: %+v", err) t.Errorf("Do returned error: %+v", err)
@@ -148,6 +133,7 @@ func TestStandalone_Do(t *testing.T) {
Authentication: mockAuthentication, Authentication: mockAuthentication,
KubeconfigLoader: mockLoader, KubeconfigLoader: mockLoader,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err == nil { if err := u.Do(ctx, in); err == nil {
t.Errorf("err wants non-nil but nil") t.Errorf("err wants non-nil but nil")
@@ -182,6 +168,7 @@ func TestStandalone_Do(t *testing.T) {
Authentication: mockAuthentication, Authentication: mockAuthentication,
KubeconfigLoader: mockLoader, KubeconfigLoader: mockLoader,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err == nil { if err := u.Do(ctx, in); err == nil {
t.Errorf("err wants non-nil but nil") t.Errorf("err wants non-nil but nil")
@@ -234,6 +221,7 @@ func TestStandalone_Do(t *testing.T) {
KubeconfigLoader: mockLoader, KubeconfigLoader: mockLoader,
KubeconfigWriter: mockWriter, KubeconfigWriter: mockWriter,
Logger: logger.New(t), Logger: logger.New(t),
Clock: clock.Fake(expiryTime.Add(-time.Hour)),
} }
if err := u.Do(ctx, in); err == nil { if err := u.Do(ctx, in); err == nil {
t.Errorf("err wants non-nil but nil") t.Errorf("err wants non-nil but nil")