From d5e0b7b8fdf765ca1b703ccb5e9da42e89137153 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Thu, 17 Jul 2025 15:13:09 +0100 Subject: [PATCH 1/2] - implement internal cancellation for SCAN via WithCancellation --- docs/AsyncTimeouts.md | 16 ++++- src/StackExchange.Redis/CursorEnumerable.cs | 3 +- src/StackExchange.Redis/TaskExtensions.cs | 38 ++++++++++++ .../CancellationTests.cs | 58 +++++++++++++------ 4 files changed, 94 insertions(+), 21 deletions(-) diff --git a/docs/AsyncTimeouts.md b/docs/AsyncTimeouts.md index 5ba4fd3f1..04892d59a 100644 --- a/docs/AsyncTimeouts.md +++ b/docs/AsyncTimeouts.md @@ -62,4 +62,18 @@ using var cts = CancellationTokenSource.CreateLinkedTokenSource(token); // or mu cts.CancelAfter(timeout); await database.StringSetAsync("key", "value").WaitAsync(cts.Token); var value = await database.StringGetAsync("key").WaitAsync(cts.Token); -`````` \ No newline at end of file +``` + +### Cancelling keys enumeration + +Keys being enumerated (via `SCAN`) can *also* be cancelled, using the inbuilt `.WithCancellation(...)` method: + +```csharp +CancellationToken token = ...; // for example, from HttpContext.RequestAborted +await foreach (var key in server.KeysAsync(pattern: "*foo*").WithCancellation(token)) +{ + ... +} +``` + +To use a timeout instead, you can use the `CancellationTokenSource` approach shown above. \ No newline at end of file diff --git a/src/StackExchange.Redis/CursorEnumerable.cs b/src/StackExchange.Redis/CursorEnumerable.cs index 55d93d6a6..e526eceaa 100644 --- a/src/StackExchange.Redis/CursorEnumerable.cs +++ b/src/StackExchange.Redis/CursorEnumerable.cs @@ -141,6 +141,7 @@ private bool SimpleNext() { if (_pageOffset + 1 < _pageCount) { + cancellationToken.ThrowIfCancellationRequested(); _pageOffset++; return true; } @@ -274,7 +275,7 @@ private async ValueTask AwaitedNextAsync(bool isInitial) ScanResult scanResult; try { - scanResult = await pending.ForAwait(); + scanResult = await pending.WaitAsync(cancellationToken).ForAwait(); } catch (Exception ex) { diff --git a/src/StackExchange.Redis/TaskExtensions.cs b/src/StackExchange.Redis/TaskExtensions.cs index 081a691ec..921e320f8 100644 --- a/src/StackExchange.Redis/TaskExtensions.cs +++ b/src/StackExchange.Redis/TaskExtensions.cs @@ -25,6 +25,44 @@ internal static Task ObserveErrors(this Task task) return task; } +#if !NET6_0_OR_GREATER + // suboptimal polyfill version of the .NET 6+ API, but reasonable for light use + internal static Task WaitAsync(this Task task, CancellationToken cancellationToken) + { + if (task.IsCompleted || !cancellationToken.CanBeCanceled) return task; + return Wrap(task, cancellationToken); + + static async Task Wrap(Task task, CancellationToken cancellationToken) + { + var tcs = new TaskSourceWithToken(cancellationToken); + using var reg = cancellationToken.Register( + static state => ((TaskSourceWithToken)state!).Cancel(), tcs); + _ = task.ContinueWith( + static (t, state) => + { + var tcs = (TaskSourceWithToken)state!; + if (t.IsCanceled) tcs.TrySetCanceled(); + else if (t.IsFaulted) tcs.TrySetException(t.Exception!); + else tcs.TrySetResult(t.Result); + }, + tcs); + return await tcs.Task; + } + } + + // the point of this type is to combine TCS and CT so that we can use a static + // registration via Register + private sealed class TaskSourceWithToken : TaskCompletionSource + { + public TaskSourceWithToken(CancellationToken cancellationToken) + => _cancellationToken = cancellationToken; + + private readonly CancellationToken _cancellationToken; + + public void Cancel() => TrySetCanceled(_cancellationToken); + } +#endif + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static ConfiguredTaskAwaitable ForAwait(this Task task) => task.ConfigureAwait(false); [MethodImpl(MethodImplOptions.AggressiveInlining)] diff --git a/tests/StackExchange.Redis.Tests/CancellationTests.cs b/tests/StackExchange.Redis.Tests/CancellationTests.cs index 58e33e3de..b04ae8a92 100644 --- a/tests/StackExchange.Redis.Tests/CancellationTests.cs +++ b/tests/StackExchange.Redis.Tests/CancellationTests.cs @@ -12,25 +12,6 @@ internal static class TaskExtensions { // suboptimal polyfill version of the .NET 6+ API; I'm not recommending this for production use, // but it's good enough for tests - public static Task WaitAsync(this Task task, CancellationToken cancellationToken) - { - if (task.IsCompleted || !cancellationToken.CanBeCanceled) return task; - return Wrap(task, cancellationToken); - - static async Task Wrap(Task task, CancellationToken cancellationToken) - { - var tcs = new TaskCompletionSource(); - using var reg = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken)); - _ = task.ContinueWith(t => - { - if (t.IsCanceled) tcs.TrySetCanceled(); - else if (t.IsFaulted) tcs.TrySetException(t.Exception!); - else tcs.TrySetResult(t.Result); - }); - return await tcs.Task; - } - } - public static Task WaitAsync(this Task task, TimeSpan timeout) { if (task.IsCompleted) return task; @@ -92,6 +73,11 @@ private void Pause(IDatabase db) db.Execute("client", new object[] { "pause", ConnectionPauseMilliseconds }, CommandFlags.FireAndForget); } + private void Pause(IServer server) + { + server.Execute("client", new object[] { "pause", ConnectionPauseMilliseconds }, CommandFlags.FireAndForget); + } + [Fact] public async Task WithTimeout_ShortTimeout_Async_ThrowsOperationCanceledException() { @@ -195,4 +181,38 @@ public async Task CancellationDuringOperation_Async_CancelsGracefully(CancelStra Assert.Equal(cts.Token, oce.CancellationToken); } } + + [Fact] + public async Task ScanCancellable() + { + using var conn = Create(); + var db = conn.GetDatabase(); + var server = conn.GetServer(conn.GetEndPoints()[0]); + + using var cts = new CancellationTokenSource(); + + var watch = Stopwatch.StartNew(); + Pause(server); + try + { + db.StringSet(Me(), "value", TimeSpan.FromMinutes(5), flags: CommandFlags.FireAndForget); + await using var iter = server.KeysAsync(pageSize: 1000).WithCancellation(cts.Token).GetAsyncEnumerator(); + var pending = iter.MoveNextAsync(); + Assert.False(cts.Token.IsCancellationRequested); + cts.CancelAfter(ShortDelayMilliseconds); // start this *after* we've got past the initial check + while (await pending) + { + pending = iter.MoveNextAsync(); + } + Assert.Fail($"{ExpectedCancel}: {watch.ElapsedMilliseconds}ms"); + } + catch (OperationCanceledException oce) + { + var taken = watch.ElapsedMilliseconds; + // Expected if cancellation happens during operation + Log($"Cancelled after {taken}ms"); + Assert.True(taken < ConnectionPauseMilliseconds / 2, "Should have cancelled much sooner"); + Assert.Equal(cts.Token, oce.CancellationToken); + } + } } From 872f1613f66c3691fc8e2a627493de417d12e870 Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Thu, 17 Jul 2025 15:15:36 +0100 Subject: [PATCH 2/2] release notes --- docs/ReleaseNotes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index b3a788286..2269dcc6c 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -11,6 +11,7 @@ Current package versions: - Add support for new `BITOP` operations in CE 8.2 ([#2900 by atakavci](https://github.com/StackExchange/StackExchange.Redis/pull/2900)) - Package updates ([#2906 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2906)) - Fix handshake error with `CLIENT ID` ([#2909 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2909)) +- Support async cancellation of `SCAN` enumeration ([#2911 by mgravell](https://github.com/StackExchange/StackExchange.Redis/pull/2911)) ## 2.8.41