diff --git a/internal/pkg/platform/activity_test.go b/internal/pkg/platform/activity_test.go index b1a58695..aee89b4b 100644 --- a/internal/pkg/platform/activity_test.go +++ b/internal/pkg/platform/activity_test.go @@ -16,6 +16,7 @@ package platform import ( "context" + "sync/atomic" "testing" "time" @@ -212,17 +213,19 @@ func TestPlatformActivity_StreamingLogs(t *testing.T) { PollingIntervalMS: 20, // poll activity every 20 ms }, Setup: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) context.Context { - cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything).Return(api.ActivityResult{}, nil) + var calls atomic.Int32 ctx, cancel := context.WithCancel(ctx) - go func() { - time.Sleep(time.Millisecond * 50) // cancel activity in 50 ms - cancel() - }() + cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + if calls.Add(1) >= 2 { + cancel() + } + }). + Return(api.ActivityResult{}, nil) return ctx }, ExpectedError: nil, ExpectedAsserts: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) { - // with the above polling/canceling setup, expectation is activity called two times. cm.API.AssertNumberOfCalls(t, "Activity", 2) }, }, @@ -232,17 +235,16 @@ func TestPlatformActivity_StreamingLogs(t *testing.T) { PollingIntervalMS: 20, // poll activity every 20 ms }, Setup: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) context.Context { - cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything).Return(api.ActivityResult{}, nil) ctx, cancel := context.WithCancel(ctx) - go func() { - time.Sleep(time.Millisecond * 10) // cancel activity in 10 ms - cancel() - }() + cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + cancel() + }). + Return(api.ActivityResult{}, nil) return ctx }, ExpectedError: nil, ExpectedAsserts: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) { - // with the above polling/canceling setup, expectation is activity called only once. cm.API.AssertNumberOfCalls(t, "Activity", 1) }, }, @@ -253,17 +255,19 @@ func TestPlatformActivity_StreamingLogs(t *testing.T) { PollingIntervalMS: 20, // poll activity every 20 ms }, Setup: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) context.Context { - cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything).Return(api.ActivityResult{}, slackerror.New("mock_broken_logs")) + var calls atomic.Int32 ctx, cancel := context.WithCancel(ctx) - go func() { - time.Sleep(time.Millisecond * 50) // cancel activity in 50 ms - cancel() - }() + cm.API.On("Activity", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + if calls.Add(1) >= 3 { + cancel() + } + }). + Return(api.ActivityResult{}, slackerror.New("mock_broken_logs")) return ctx }, ExpectedError: nil, ExpectedAsserts: func(t *testing.T, ctx context.Context, cm *shared.ClientsMock) { - // with the above polling/canceling setup, expectation is activity called three times. cm.API.AssertNumberOfCalls(t, "Activity", 3) }, },