add context to long running operations

commit 471ab88240 (parent 55586431bd)
Author: Fan Shang Xiang
Date: 2023-06-30 10:17:13 +08:00

12 changed files with 82 additions and 90 deletions

View File

@@ -17,6 +17,8 @@ limitations under the License.
package main
import (
"context"
"github.com/golang/glog"
_ "k8s.io/node-problem-detector/cmd/nodeproblemdetector/exporterplugins"
@@ -31,16 +33,7 @@ import (
"k8s.io/node-problem-detector/pkg/version"
)
-func npdInteractive(npdo *options.NodeProblemDetectorOptions) {
-termCh := make(chan error, 1)
-defer close(termCh)
-if err := npdMain(npdo, termCh); err != nil {
-glog.Fatalf("Problem detector failed with error: %v", err)
-}
-}
-func npdMain(npdo *options.NodeProblemDetectorOptions, termCh <-chan error) error {
+func npdMain(ctx context.Context, npdo *options.NodeProblemDetectorOptions) error {
if npdo.PrintVersion {
version.PrintVersion()
return nil
@@ -58,7 +51,7 @@ func npdMain(npdo *options.NodeProblemDetectorOptions, termCh <-chan error) erro
// Initialize exporters.
defaultExporters := []types.Exporter{}
-if ke := k8sexporter.NewExporterOrDie(npdo); ke != nil {
+if ke := k8sexporter.NewExporterOrDie(ctx, npdo); ke != nil {
defaultExporters = append(defaultExporters, ke)
glog.Info("K8s exporter started.")
}
@@ -79,5 +72,5 @@ func npdMain(npdo *options.NodeProblemDetectorOptions, termCh <-chan error) erro
// Initialize NPD core.
p := problemdetector.NewProblemDetector(problemDaemons, npdExporters)
-return p.Run(termCh)
+return p.Run(ctx)
}
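
The change above replaces npdMain's dedicated termination channel (termCh <-chan error) with a context.Context, Go's standard cancellation mechanism. For readers new to the pattern, here is a minimal, self-contained sketch of the shape npdMain now has; the names are illustrative, not from this repository:

package main

import (
	"context"
	"fmt"
	"time"
)

// run stands in for a long-running entry point such as npdMain: it works
// until the caller cancels the context, then returns cleanly.
func run(ctx context.Context) error {
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil // cancellation is an orderly shutdown, not an error
		case <-ticker.C:
			fmt.Println("working...")
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	time.AfterFunc(time.Second, cancel) // simulate an external stop request
	if err := run(ctx); err != nil {
		fmt.Println("run failed:", err)
	}
}

The advantage over a bespoke channel is composition: any callee that receives ctx can pass it on or derive a child context, and a single cancel() stops all of them.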

View File

@@ -17,6 +17,9 @@ limitations under the License.
package main
import (
"context"
"github.com/golang/glog"
"github.com/spf13/pflag"
"k8s.io/node-problem-detector/cmd/options"
)
@@ -26,5 +29,7 @@ func main() {
npdo.AddFlags(pflag.CommandLine)
pflag.Parse()
-npdInteractive(npdo)
+if err := npdMain(context.Background(), npdo); err != nil {
+glog.Fatalf("Problem detector failed with error: %v", err)
+}
}
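
main now hands npdMain a root context created with context.Background(), which is never cancelled, so termination still happens through glog.Fatalf or normal process exit. A common extension of this setup, not part of this commit, is to derive the root context from OS signals so that SIGINT/SIGTERM cancel in-flight work. A sketch using only the standard library (signal.NotifyContext requires Go 1.16+):

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	// Cancel the root context when SIGINT or SIGTERM arrives, so
	// long-running work can shut down gracefully instead of being killed.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	<-ctx.Done() // real code would run npdMain-style work here
	fmt.Println("shutting down:", ctx.Err())
}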

View File

@@ -20,7 +20,7 @@ limitations under the License.
package main
import (
"errors"
"context"
"fmt"
"os"
"strings"
@@ -81,11 +81,9 @@ func TestNPDMain(t *testing.T) {
npdo, cleanup := setupNPD(t)
defer cleanup()
-termCh := make(chan error, 2)
-termCh <- errors.New("close")
-defer close(termCh)
-if err := npdMain(npdo, termCh); err != nil {
+ctx, cancelFunc := context.WithCancel(context.Background())
+cancelFunc()
+if err := npdMain(ctx, npdo); err != nil {
t.Errorf("termination signal should not return error got, %v", err)
}
}
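
The test now cancels the context before calling npdMain, which drives the shutdown path deterministically; previously it had to push a sentinel error into termCh. The same idiom in isolation, with a stub run standing in for npdMain (hypothetical names):

package main

import (
	"context"
	"testing"
)

// run stands in for npdMain: it returns nil once ctx is cancelled.
func run(ctx context.Context) error {
	<-ctx.Done()
	return nil
}

func TestRunStopsOnCancelledContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel first: the call below must return promptly, without error
	if err := run(ctx); err != nil {
		t.Errorf("expected clean shutdown on cancelled context, got %v", err)
	}
}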

View File

@@ -17,7 +17,7 @@ limitations under the License.
package main
import (
"errors"
"context"
"fmt"
"sync"
"time"
@@ -102,26 +102,20 @@ type npdService struct {
}
func (s *npdService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
-appTermCh := make(chan error, 1)
-svcLoopTermCh := make(chan error, 1)
-defer func() {
-close(appTermCh)
-close(svcLoopTermCh)
-}()
changes <- svc.Status{State: svc.StartPending}
changes <- svc.Status{State: svc.Running, Accepts: svcCommandsAccepted}
var appWG sync.WaitGroup
var svcWG sync.WaitGroup
options := s.options
+ctx, cancelFunc := context.WithCancel(context.Background())
// NPD application goroutine.
appWG.Add(1)
go func() {
defer appWG.Done()
-if err := npdMain(options, appTermCh); err != nil {
+if err := npdMain(ctx, options); err != nil {
elog.Warning(windowsEventLogID, err.Error())
}
@@ -132,16 +126,36 @@ func (s *npdService) Execute(args []string, r <-chan svc.ChangeRequest, changes
svcWG.Add(1)
go func() {
defer svcWG.Done()
-serviceLoop(r, changes, appTermCh, svcLoopTermCh)
+for {
+select {
+case <-ctx.Done():
+return
+case c := <-r:
+switch c.Cmd {
+case svc.Interrogate:
+changes <- c.CurrentStatus
+// Testing deadlock from https://code.google.com/p/winsvc/issues/detail?id=4
+time.Sleep(100 * time.Millisecond)
+changes <- c.CurrentStatus
+case svc.Stop, svc.Shutdown:
+elog.Info(windowsEventLogID, fmt.Sprintf("Stopping %s service, %v", svcName, c.Context))
+cancelFunc()
+case svc.Pause:
+elog.Info(windowsEventLogID, "ignoring pause command from Windows service control, not supported")
+changes <- svc.Status{State: svc.Paused, Accepts: svcCommandsAccepted}
+case svc.Continue:
+elog.Info(windowsEventLogID, "ignoring continue command from Windows service control, not supported")
+changes <- svc.Status{State: svc.Running, Accepts: svcCommandsAccepted}
+default:
+elog.Error(windowsEventLogID, fmt.Sprintf("unexpected control request #%d", c))
+}
+}
+}
}()
// Wait for the application go routine to die.
appWG.Wait()
// Ensure that the service control loop is killed.
-svcLoopTermCh <- nil
// Wait for the service control loop to terminate.
// Otherwise it's possible that the channel closures cause the application to panic.
svcWG.Wait()
@@ -151,31 +165,3 @@ func (s *npdService) Execute(args []string, r <-chan svc.ChangeRequest, changes
return false, uint32(0)
}
-func serviceLoop(r <-chan svc.ChangeRequest, changes chan<- svc.Status, appTermCh chan error, svcLoopTermCh chan error) {
-for {
-select {
-case <-svcLoopTermCh:
-return
-case c := <-r:
-switch c.Cmd {
-case svc.Interrogate:
-changes <- c.CurrentStatus
-// Testing deadlock from https://code.google.com/p/winsvc/issues/detail?id=4
-time.Sleep(100 * time.Millisecond)
-changes <- c.CurrentStatus
-case svc.Stop, svc.Shutdown:
-elog.Info(windowsEventLogID, fmt.Sprintf("Stopping %s service, %v", svcName, c.Context))
-appTermCh <- errors.New("stopping service")
-case svc.Pause:
-elog.Info(windowsEventLogID, "ignoring pause command from Windows service control, not supported")
-changes <- svc.Status{State: svc.Paused, Accepts: svcCommandsAccepted}
-case svc.Continue:
-elog.Info(windowsEventLogID, "ignoring continue command from Windows service control, not supported")
-changes <- svc.Status{State: svc.Running, Accepts: svcCommandsAccepted}
-default:
-elog.Error(windowsEventLogID, fmt.Sprintf("unexpected control request #%d", c))
-}
-}
-}
-}
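
To summarize the Windows service change: the two bespoke channels (appTermCh, svcLoopTermCh) are replaced by one cancellable context, the serviceLoop helper is inlined into a goroutine inside Execute, and svc.Stop/svc.Shutdown now call cancelFunc, which stops npdMain and the control loop together. A stripped-down sketch of that coordination, with the Windows service plumbing omitted:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	var appWG, loopWG sync.WaitGroup

	// "Application" goroutine, standing in for npdMain.
	appWG.Add(1)
	go func() {
		defer appWG.Done()
		<-ctx.Done()
		fmt.Println("app exited")
	}()

	// "Control loop" goroutine, standing in for the svc.ChangeRequest loop.
	loopWG.Add(1)
	go func() {
		defer loopWG.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case <-time.After(50 * time.Millisecond):
				// handle control requests here
			}
		}
	}()

	time.AfterFunc(200*time.Millisecond, cancel) // a Stop request arrives

	appWG.Wait()
	cancel() // ensure the control loop also exits if the app died on its own
	loopWG.Wait()
	fmt.Println("service stopped")
}

The second cancel() after appWG.Wait() plays the role the old svcLoopTermCh send did: it guarantees the control loop terminates even when the application goroutine is the first to die.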

View File

@@ -20,6 +20,7 @@ limitations under the License.
package main
import (
"context"
"testing"
"golang.org/x/sys/windows/svc"

View File

@@ -17,6 +17,7 @@ limitations under the License.
package condition
import (
"context"
"reflect"
"sync"
"time"
@@ -49,7 +50,7 @@ const (
// not. This addresses 3).
type ConditionManager interface {
// Start starts the condition manager.
-Start()
+Start(ctx context.Context)
// UpdateCondition updates a specific condition.
UpdateCondition(types.Condition)
// GetConditions returns all current conditions.
@@ -88,8 +89,8 @@ func NewConditionManager(client problemclient.Client, clock clock.Clock, heartbe
}
}
-func (c *conditionManager) Start() {
-go c.syncLoop()
+func (c *conditionManager) Start(ctx context.Context) {
+go c.syncLoop(ctx)
}
func (c *conditionManager) UpdateCondition(condition types.Condition) {
@@ -110,15 +111,17 @@ func (c *conditionManager) GetConditions() []types.Condition {
return conditions
}
-func (c *conditionManager) syncLoop() {
+func (c *conditionManager) syncLoop(ctx context.Context) {
ticker := c.clock.NewTicker(updatePeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C():
if c.needUpdates() || c.needResync() || c.needHeartbeat() {
-c.sync()
+c.sync(ctx)
}
+case <-ctx.Done():
+return // "break" here would only exit the select, not the for loop
}
}
}
@@ -150,14 +153,14 @@ func (c *conditionManager) needHeartbeat() bool {
}
// sync synchronizes node conditions with the apiserver.
-func (c *conditionManager) sync() {
+func (c *conditionManager) sync(ctx context.Context) {
c.latestTry = c.clock.Now()
c.resyncNeeded = false
conditions := []v1.NodeCondition{}
for i := range c.conditions {
conditions = append(conditions, problemutil.ConvertToAPICondition(c.conditions[i]))
}
-if err := c.client.SetConditions(conditions); err != nil {
+if err := c.client.SetConditions(ctx, conditions); err != nil {
// The conditions will be updated again in future sync
glog.Errorf("failed to update node conditions: %v", err)
c.resyncNeeded = true
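
syncLoop above is the standard ticker-plus-context loop. One Go subtlety is worth spelling out: inside a select, break exits only the select statement, so the ctx.Done() arm must return to actually stop the goroutine; otherwise the for loop keeps spinning after cancellation. The shape in isolation, standard library only:

package main

import (
	"context"
	"fmt"
	"time"
)

// pollUntilCancelled runs work every period until ctx is cancelled.
func pollUntilCancelled(ctx context.Context, period time.Duration, work func(context.Context)) {
	ticker := time.NewTicker(period)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			work(ctx)
		case <-ctx.Done():
			return // "break" here would only exit the select
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	pollUntilCancelled(ctx, 100*time.Millisecond, func(context.Context) {
		fmt.Println("sync")
	})
}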

View File

@@ -17,6 +17,7 @@ limitations under the License.
package condition
import (
"context"
"fmt"
"testing"
"time"
@@ -109,7 +110,7 @@ func TestResync(t *testing.T) {
m, fakeClient, fakeClock := newTestManager()
condition := newTestCondition("TestCondition")
m.conditions = map[string]types.Condition{condition.Type: condition}
-m.sync()
+m.sync(context.Background())
expected := []v1.NodeCondition{problemutil.ConvertToAPICondition(condition)}
assert.Nil(t, fakeClient.AssertConditions(expected), "Condition should be updated via client")
@@ -118,7 +119,7 @@ func TestResync(t *testing.T) {
assert.False(t, m.needResync(), "Should not resync after resync period without resync needed")
fakeClient.InjectError("SetConditions", fmt.Errorf("injected error"))
-m.sync()
+m.sync(context.Background())
assert.False(t, m.needResync(), "Should not resync before resync period")
fakeClock.Step(resyncPeriod)
@@ -129,7 +130,7 @@ func TestHeartbeat(t *testing.T) {
m, fakeClient, fakeClock := newTestManager()
condition := newTestCondition("TestCondition")
m.conditions = map[string]types.Condition{condition.Type: condition}
-m.sync()
+m.sync(context.Background())
expected := []v1.NodeCondition{problemutil.ConvertToAPICondition(condition)}
assert.Nil(t, fakeClient.AssertConditions(expected), "Condition should be updated via client")

View File

@@ -17,6 +17,7 @@ limitations under the License.
package k8sexporter
import (
"context"
"net"
"net/http"
_ "net/http/pprof"
@@ -44,7 +45,7 @@ type k8sExporter struct {
//
// Note that this function may block (until a timeout occurs) before
// kube-apiserver becomes ready.
-func NewExporterOrDie(npdo *options.NodeProblemDetectorOptions) types.Exporter {
+func NewExporterOrDie(ctx context.Context, npdo *options.NodeProblemDetectorOptions) types.Exporter {
if !npdo.EnableK8sExporter {
return nil
}
@@ -52,7 +53,7 @@ func NewExporterOrDie(npdo *options.NodeProblemDetectorOptions) types.Exporter {
c := problemclient.NewClientOrDie(npdo)
glog.Infof("Waiting for kube-apiserver to be ready (timeout %v)...", npdo.APIServerWaitTimeout)
-if err := waitForAPIServerReadyWithTimeout(c, npdo); err != nil {
+if err := waitForAPIServerReadyWithTimeout(ctx, c, npdo); err != nil {
glog.Warningf("kube-apiserver did not become ready: timed out waiting for kube-apiserver to return the node object: %v", err)
}
@@ -62,7 +63,7 @@ func NewExporterOrDie(npdo *options.NodeProblemDetectorOptions) types.Exporter {
}
ke.startHTTPReporting(npdo)
-ke.conditionManager.Start()
+ke.conditionManager.Start(ctx)
return &ke
}
@@ -103,11 +104,11 @@ func (ke *k8sExporter) startHTTPReporting(npdo *options.NodeProblemDetectorOptio
}()
}
-func waitForAPIServerReadyWithTimeout(c problemclient.Client, npdo *options.NodeProblemDetectorOptions) error {
+func waitForAPIServerReadyWithTimeout(ctx context.Context, c problemclient.Client, npdo *options.NodeProblemDetectorOptions) error {
return wait.PollImmediate(npdo.APIServerWaitInterval, npdo.APIServerWaitTimeout, func() (done bool, err error) {
// If NPD can get the node object from kube-apiserver, the server is
// ready and the RBAC permission is set correctly.
-if _, err := c.GetNode(); err != nil {
+if _, err := c.GetNode(ctx); err != nil {
glog.Errorf("Can't get node object: %v", err)
return false, nil
}
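
Note that waitForAPIServerReadyWithTimeout now threads ctx into each GetNode call, but wait.PollImmediate itself still runs purely on interval and timeout and never observes the context between attempts. If the vendored apimachinery is new enough, the poll itself can honor cancellation. A sketch assuming k8s.io/apimachinery v0.23+ (which added wait.PollImmediateWithContext) and the imports already present in this file:

func waitForAPIServerReadyWithTimeout(ctx context.Context, c problemclient.Client, npdo *options.NodeProblemDetectorOptions) error {
	// PollImmediateWithContext stops early if ctx is cancelled, in addition
	// to the usual interval/timeout behavior.
	return wait.PollImmediateWithContext(ctx, npdo.APIServerWaitInterval, npdo.APIServerWaitTimeout,
		func(ctx context.Context) (bool, error) {
			if _, err := c.GetNode(ctx); err != nil {
				glog.Errorf("Can't get node object: %v", err)
				return false, nil // not ready yet; keep polling
			}
			return true, nil
		})
}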

View File

@@ -17,6 +17,7 @@ limitations under the License.
package problemclient
import (
"context"
"fmt"
"reflect"
"sync"
@@ -60,7 +61,7 @@ func (f *FakeProblemClient) AssertConditions(expected []v1.NodeCondition) error
}
// SetConditions is a fake mimic of SetConditions; it only updates the internal condition cache.
-func (f *FakeProblemClient) SetConditions(conditions []v1.NodeCondition) error {
+func (f *FakeProblemClient) SetConditions(ctx context.Context, conditions []v1.NodeCondition) error {
f.Lock()
defer f.Unlock()
if err, ok := f.errors["SetConditions"]; ok {
@@ -73,7 +74,7 @@ func (f *FakeProblemClient) SetConditions(conditions []v1.NodeCondition) error {
}
// GetConditions is a fake mimic of GetConditions; it returns the conditions cached internally.
-func (f *FakeProblemClient) GetConditions(types []v1.NodeConditionType) ([]*v1.NodeCondition, error) {
+func (f *FakeProblemClient) GetConditions(ctx context.Context, types []v1.NodeConditionType) ([]*v1.NodeCondition, error) {
f.Lock()
defer f.Unlock()
if err, ok := f.errors["GetConditions"]; ok {
@@ -93,6 +94,6 @@ func (f *FakeProblemClient) GetConditions(types []v1.NodeConditionType) ([]*v1.N
func (f *FakeProblemClient) Eventf(eventType string, source, reason, messageFmt string, args ...interface{}) {
}
-func (f *FakeProblemClient) GetNode() (*v1.Node, error) {
+func (f *FakeProblemClient) GetNode(ctx context.Context) (*v1.Node, error) {
return nil, fmt.Errorf("GetNode() not implemented")
}
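
Every method of FakeProblemClient has to gain the ctx parameter in lockstep with the Client interface, otherwise the package stops compiling wherever the fake is used as a Client. A cheap way to make such drift fail fast, a hypothetical addition rather than part of this commit, is a compile-time assertion:

// Fails to compile if FakeProblemClient ever stops satisfying Client.
var _ Client = (*FakeProblemClient)(nil)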

View File

@@ -17,6 +17,7 @@ limitations under the License.
package problemclient
import (
"context"
"encoding/json"
"fmt"
"net/url"
@@ -40,14 +41,14 @@ import (
// Client is the interface of problem client
type Client interface {
// GetConditions gets all specified conditions of the current node.
-GetConditions(conditionTypes []v1.NodeConditionType) ([]*v1.NodeCondition, error)
+GetConditions(ctx context.Context, conditionTypes []v1.NodeConditionType) ([]*v1.NodeCondition, error)
// SetConditions sets or updates the conditions of the current node.
-SetConditions(conditions []v1.NodeCondition) error
+SetConditions(ctx context.Context, conditions []v1.NodeCondition) error
// Eventf reports the event.
Eventf(eventType string, source, reason, messageFmt string, args ...interface{})
// GetNode returns the Node object of the node on which the
// node-problem-detector runs.
-GetNode() (*v1.Node, error)
+GetNode(ctx context.Context) (*v1.Node, error)
}
type nodeProblemClient struct {
@@ -81,8 +82,8 @@ func NewClientOrDie(npdo *options.NodeProblemDetectorOptions) Client {
return c
}
-func (c *nodeProblemClient) GetConditions(conditionTypes []v1.NodeConditionType) ([]*v1.NodeCondition, error) {
-node, err := c.GetNode()
+func (c *nodeProblemClient) GetConditions(ctx context.Context, conditionTypes []v1.NodeConditionType) ([]*v1.NodeCondition, error) {
+node, err := c.GetNode(ctx)
if err != nil {
return nil, err
}
@@ -97,7 +98,7 @@ func (c *nodeProblemClient) GetConditions(conditionTypes []v1.NodeConditionType)
return conditions, nil
}
-func (c *nodeProblemClient) SetConditions(newConditions []v1.NodeCondition) error {
+func (c *nodeProblemClient) SetConditions(ctx context.Context, newConditions []v1.NodeCondition) error {
for i := range newConditions {
// Each time we update the conditions, we update the heart beat time
newConditions[i].LastHeartbeatTime = metav1.NewTime(c.clock.Now())
@@ -119,7 +120,7 @@ func (c *nodeProblemClient) Eventf(eventType, source, reason, messageFmt string,
recorder.Eventf(c.nodeRef, eventType, reason, messageFmt, args...)
}
-func (c *nodeProblemClient) GetNode() (*v1.Node, error) {
+func (c *nodeProblemClient) GetNode(ctx context.Context) (*v1.Node, error) {
return c.client.Nodes().Get(c.nodeName, metav1.GetOptions{})
}
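
As rendered here, GetNode accepts a context but the body still calls c.client.Nodes().Get(c.nodeName, metav1.GetOptions{}) without it, so the new parameter is plumbing only. Generated client-go clients take a context as their first argument from v0.18 onward; under that assumption (the vendored client-go version is not visible in this view), the fully wired method would read:

func (c *nodeProblemClient) GetNode(ctx context.Context) (*v1.Node, error) {
	// With client-go v0.18+, ctx flows into the underlying HTTP request,
	// so cancelling it aborts an in-flight apiserver call.
	return c.client.Nodes().Get(ctx, c.nodeName, metav1.GetOptions{})
}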

View File

@@ -17,6 +17,7 @@ limitations under the License.
package problemdetector
import (
"context"
"fmt"
"github.com/golang/glog"
@@ -26,7 +27,7 @@ import (
// ProblemDetector collects statuses from all problem daemons, updates the node conditions, and sends node events.
type ProblemDetector interface {
-Run(termCh <-chan error) error
+Run(context.Context) error
}
type problemDetector struct {
@@ -44,7 +45,7 @@ func NewProblemDetector(monitors []types.Monitor, exporters []types.Exporter) Pr
}
// Run starts the problem detector.
-func (p *problemDetector) Run(termCh <-chan error) error {
+func (p *problemDetector) Run(ctx context.Context) error {
// Start the log monitors one by one.
var chans []<-chan *types.Status
failureCount := 0
@@ -77,7 +78,7 @@ func (p *problemDetector) Run(termCh <-chan error) error {
for {
select {
-case <-termCh:
+case <-ctx.Done():
return nil
case status := <-ch:
for _, exporter := range p.exporters {
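
Run fans status updates from every monitor into this select loop and now exits when the caller cancels the context rather than when a termination channel fires. A generic sketch of that fan-in shape, standard library only, with hypothetical names (node-problem-detector's own channel merging is not shown in this view):

package main

import (
	"context"
	"fmt"
	"sync"
)

// merge drains several input channels into one until ctx is cancelled;
// the output channel is closed once all inputs are done.
func merge(ctx context.Context, inputs ...<-chan string) <-chan string {
	out := make(chan string)
	var wg sync.WaitGroup
	for _, in := range inputs {
		wg.Add(1)
		go func(in <-chan string) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				case v, ok := <-in:
					if !ok {
						return
					}
					select {
					case out <- v: // forward, unless cancelled first
					case <-ctx.Done():
						return
					}
				}
			}
		}(in)
	}
	go func() { wg.Wait(); close(out) }()
	return out
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ch := make(chan string, 1)
	ch <- "status: ok"
	close(ch)
	for s := range merge(ctx, ch) {
		fmt.Println(s)
	}
}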

View File

@@ -17,6 +17,7 @@ limitations under the License.
package problemdetector
import (
"context"
"testing"
"k8s.io/node-problem-detector/pkg/types"
@@ -24,7 +25,7 @@ import (
func TestEmpty(t *testing.T) {
pd := NewProblemDetector([]types.Monitor{}, []types.Exporter{})
-if err := pd.Run(nil); err == nil {
+if err := pd.Run(context.Background()); err == nil {
t.Error("expected error when running an empty problem detector")
}
}