diff --git a/pkg/common/executor_test.go b/pkg/common/executor_test.go index e4a7dfa4..2d36b267 100644 --- a/pkg/common/executor_test.go +++ b/pkg/common/executor_test.go @@ -3,6 +3,7 @@ package common import ( "context" "errors" + "sync/atomic" "testing" "time" @@ -80,37 +81,40 @@ func TestNewParallelExecutor(t *testing.T) { ctx := context.Background() - count := 0 - activeCount := 0 - maxCount := 0 + var count atomic.Int32 + var activeCount atomic.Int32 + var maxCount atomic.Int32 emptyWorkflow := NewPipelineExecutor(func(_ context.Context) error { - count++ + count.Add(1) - activeCount++ - if activeCount > maxCount { - maxCount = activeCount + cur := activeCount.Add(1) + for { + old := maxCount.Load() + if cur <= old || maxCount.CompareAndSwap(old, cur) { + break + } } time.Sleep(2 * time.Second) - activeCount-- + activeCount.Add(-1) return nil }) err := NewParallelExecutor(2, emptyWorkflow, emptyWorkflow, emptyWorkflow)(ctx) - assert.Equal(3, count, "should run all 3 executors") - assert.Equal(2, maxCount, "should run at most 2 executors in parallel") + assert.Equal(int32(3), count.Load(), "should run all 3 executors") + assert.Equal(int32(2), maxCount.Load(), "should run at most 2 executors in parallel") require.NoError(t, err) // Reset to test running the executor with 0 parallelism - count = 0 - activeCount = 0 - maxCount = 0 + count.Store(0) + activeCount.Store(0) + maxCount.Store(0) errSingle := NewParallelExecutor(0, emptyWorkflow, emptyWorkflow, emptyWorkflow)(ctx) - assert.Equal(3, count, "should run all 3 executors") - assert.Equal(1, maxCount, "should run at most 1 executors in parallel") + assert.Equal(int32(3), count.Load(), "should run all 3 executors") + assert.Equal(int32(1), maxCount.Load(), "should run at most 1 executors in parallel") require.NoError(t, errSingle) } @@ -127,7 +131,7 @@ func TestNewParallelExecutorFailed(t *testing.T) { }) err := NewParallelExecutor(1, errorWorkflow)(ctx) assert.Equal(1, count) - assert.ErrorIs(context.Canceled, err) + assert.ErrorIs(err, context.Canceled) } func TestNewParallelExecutorCanceled(t *testing.T) { @@ -136,18 +140,16 @@ func TestNewParallelExecutorCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - errExpected := errors.New("fake error") - - count := 0 + var count atomic.Int32 successWorkflow := NewPipelineExecutor(func(_ context.Context) error { - count++ + count.Add(1) return nil }) errorWorkflow := NewPipelineExecutor(func(_ context.Context) error { - count++ - return errExpected + count.Add(1) + return errors.New("fake error") }) err := NewParallelExecutor(3, errorWorkflow, successWorkflow, successWorkflow)(ctx) - assert.Equal(3, count) - assert.ErrorIs(errExpected, err) + assert.Equal(int32(3), count.Load()) + assert.ErrorIs(err, context.Canceled) }