diff --git a/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs index a17438de..ae804ddc 100644 --- a/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs +++ b/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs @@ -1,10 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; diff --git a/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs new file mode 100644 index 00000000..6124ac8f --- /dev/null +++ b/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.VisualStudio.Threading; + +/// +/// A variant on that tracks pending tasks and blocks disposal until all tasks have completed. +/// A cancellation token is provided so pending tasks can cooperatively cancel when disposal is requested. +/// +/// +/// +/// Cancellation of pending tasks is cooperative. +/// If a pending task does not observe , then disposal may take longer to complete, +/// or even never complete if a pending task never completes. +/// +/// +/// Creating tasks after disposal has been requested is not prevented by this class. +/// +/// +public class DisposableJoinableTaskFactory : DelegatingJoinableTaskFactory, IDisposable, System.IAsyncDisposable +{ + private readonly CancellationTokenSource disposalTokenSource = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The factory instance to be wrapped. Must have an associated collection. + public DisposableJoinableTaskFactory(JoinableTaskFactory innerFactory) + : base(innerFactory) + { + Requires.Argument(this.Collection is not null, nameof(innerFactory), "A collection must be associated with the factory."); + + // Get it now, since after the CTS is disposed, it throws when we try to access the token. + this.DisposalToken = this.disposalTokenSource.Token; + } + + /// + /// Initializes a new instance of the class. + /// + /// The used to construct the . + /// + /// This constructor creates a using . + /// + public DisposableJoinableTaskFactory(JoinableTaskContext joinableTaskContext) + : this(Requires.NotNull(joinableTaskContext).CreateFactory(Requires.NotNull(joinableTaskContext).CreateCollection())) + { + } + + /// + /// Gets a disposal token that should be used by tasks created by this factory to know when they should stop doing work. + /// + /// + /// This token is canceled when the factory is disposed. + /// + public CancellationToken DisposalToken { get; } + + /// + /// Gets the collection to which created tasks belong until they complete. + /// + protected new JoinableTaskCollection Collection => base.Collection!; + + /// + public void Dispose() + { + if (!this.disposalTokenSource.IsCancellationRequested) + { + this.disposalTokenSource.Cancel(); + this.disposalTokenSource.Dispose(); + } + + this.Context.Factory.Run(() => this.Collection.JoinTillEmptyAsync()); + } + + /// + public async ValueTask DisposeAsync() + { + if (!this.disposalTokenSource.IsCancellationRequested) + { + this.disposalTokenSource.Cancel(); + this.disposalTokenSource.Dispose(); + } + + await this.Collection.JoinTillEmptyAsync(); + } +} diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs index 3417adb3..b1bfb533 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs @@ -93,7 +93,7 @@ internal SynchronizationContext? ApplicableJobSyncContext /// /// Gets the collection to which created tasks belong until they complete. May be null. /// - internal JoinableTaskCollection? Collection + protected internal JoinableTaskCollection? Collection { get { return this.jobCollection; } } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs new file mode 100644 index 00000000..0b54c85f --- /dev/null +++ b/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading.Tasks; + +public class DisposableJoinableTaskFactoryTests : TestBase +{ + private readonly JoinableTaskContext context; + private readonly JoinableTaskCollection joinableCollection; + private readonly JoinableTaskFactory asyncPump; + + public DisposableJoinableTaskFactoryTests(ITestOutputHelper logger) + : base(logger) + { + this.context = JoinableTaskContext.CreateNoOpContext(); + this.joinableCollection = this.context.CreateCollection(); + this.asyncPump = this.context.CreateFactory(this.joinableCollection); + } + + [Fact] + public void DisposeCancelsToken() + { + DisposableJoinableTaskFactory factory = new(this.asyncPump); + + // The collection is already empty, so this completes immediately. + factory.Dispose(); + + Assert.True(factory.DisposalToken.IsCancellationRequested); + } + + [Fact] + public async Task DisposeAsyncCancelsToken() + { + DisposableJoinableTaskFactory factory = new(this.asyncPump); + + // The collection is already empty, so this completes immediately. + await factory.DisposeAsync(); + + Assert.True(factory.DisposalToken.IsCancellationRequested); + } + + [Fact] + public void MultipleDisposalsDoNotThrow() + { + using DisposableJoinableTaskFactory factory = new(this.asyncPump); + + factory.Dispose(); + factory.Dispose(); + + Assert.True(factory.DisposalToken.IsCancellationRequested); + } + + [Fact] + public async Task MultipleDisposeAsyncsDoesNotThrow() + { + DisposableJoinableTaskFactory factory = new(this.asyncPump); + + await factory.DisposeAsync(); + await factory.DisposeAsync(); + + Assert.True(factory.DisposalToken.IsCancellationRequested); + } + + [Fact] + public void ConstructorRequiresCollection() + { + // A factory created with just a context has no collection. + JoinableTaskFactory factoryWithoutCollection = this.context.Factory; + Assert.ThrowsAny(() => new DisposableJoinableTaskFactory(factoryWithoutCollection)); + } + + [Fact] + public async Task TaskObservesCancellationDuringDrain() + { + DisposableJoinableTaskFactory factory = new(this.asyncPump); + + JoinableTask jt = factory.RunAsync(() => Task.Delay(UnexpectedTimeout, factory.DisposalToken)); + + await factory.DisposeAsync().AsTask().WithTimeout(UnexpectedTimeout); + Assert.True(jt.IsCompleted); + await Assert.ThrowsAsync(async () => await jt); + } +}