diff --git a/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs
index a17438de..ae804ddc 100644
--- a/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs
+++ b/src/Microsoft.VisualStudio.Threading/DelegatingJoinableTaskFactory.cs
@@ -1,10 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
using System.Threading;
using System.Threading.Tasks;
diff --git a/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs
new file mode 100644
index 00000000..6124ac8f
--- /dev/null
+++ b/src/Microsoft.VisualStudio.Threading/DisposableJoinableTaskFactory.cs
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.VisualStudio.Threading;
+
+///
+/// A variant on that tracks pending tasks and blocks disposal until all tasks have completed.
+/// A cancellation token is provided so pending tasks can cooperatively cancel when disposal is requested.
+///
+///
+///
+/// Cancellation of pending tasks is cooperative.
+/// If a pending task does not observe , then disposal may take longer to complete,
+/// or even never complete if a pending task never completes.
+///
+///
+/// Creating tasks after disposal has been requested is not prevented by this class.
+///
+///
+public class DisposableJoinableTaskFactory : DelegatingJoinableTaskFactory, IDisposable, System.IAsyncDisposable
+{
+ private readonly CancellationTokenSource disposalTokenSource = new();
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The factory instance to be wrapped. Must have an associated collection.
+ public DisposableJoinableTaskFactory(JoinableTaskFactory innerFactory)
+ : base(innerFactory)
+ {
+ Requires.Argument(this.Collection is not null, nameof(innerFactory), "A collection must be associated with the factory.");
+
+ // Get it now, since after the CTS is disposed, it throws when we try to access the token.
+ this.DisposalToken = this.disposalTokenSource.Token;
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The used to construct the .
+ ///
+ /// This constructor creates a using .
+ ///
+ public DisposableJoinableTaskFactory(JoinableTaskContext joinableTaskContext)
+ : this(Requires.NotNull(joinableTaskContext).CreateFactory(Requires.NotNull(joinableTaskContext).CreateCollection()))
+ {
+ }
+
+ ///
+ /// Gets a disposal token that should be used by tasks created by this factory to know when they should stop doing work.
+ ///
+ ///
+ /// This token is canceled when the factory is disposed.
+ ///
+ public CancellationToken DisposalToken { get; }
+
+ ///
+ /// Gets the collection to which created tasks belong until they complete.
+ ///
+ protected new JoinableTaskCollection Collection => base.Collection!;
+
+ ///
+ public void Dispose()
+ {
+ if (!this.disposalTokenSource.IsCancellationRequested)
+ {
+ this.disposalTokenSource.Cancel();
+ this.disposalTokenSource.Dispose();
+ }
+
+ this.Context.Factory.Run(() => this.Collection.JoinTillEmptyAsync());
+ }
+
+ ///
+ public async ValueTask DisposeAsync()
+ {
+ if (!this.disposalTokenSource.IsCancellationRequested)
+ {
+ this.disposalTokenSource.Cancel();
+ this.disposalTokenSource.Dispose();
+ }
+
+ await this.Collection.JoinTillEmptyAsync();
+ }
+}
diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs
index 3417adb3..b1bfb533 100644
--- a/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs
+++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskFactory.cs
@@ -93,7 +93,7 @@ internal SynchronizationContext? ApplicableJobSyncContext
///
/// Gets the collection to which created tasks belong until they complete. May be null.
///
- internal JoinableTaskCollection? Collection
+ protected internal JoinableTaskCollection? Collection
{
get { return this.jobCollection; }
}
diff --git a/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs
new file mode 100644
index 00000000..0b54c85f
--- /dev/null
+++ b/test/Microsoft.VisualStudio.Threading.Tests/DisposableJoinableTaskFactoryTests.cs
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Threading.Tasks;
+
+public class DisposableJoinableTaskFactoryTests : TestBase
+{
+ private readonly JoinableTaskContext context;
+ private readonly JoinableTaskCollection joinableCollection;
+ private readonly JoinableTaskFactory asyncPump;
+
+ public DisposableJoinableTaskFactoryTests(ITestOutputHelper logger)
+ : base(logger)
+ {
+ this.context = JoinableTaskContext.CreateNoOpContext();
+ this.joinableCollection = this.context.CreateCollection();
+ this.asyncPump = this.context.CreateFactory(this.joinableCollection);
+ }
+
+ [Fact]
+ public void DisposeCancelsToken()
+ {
+ DisposableJoinableTaskFactory factory = new(this.asyncPump);
+
+ // The collection is already empty, so this completes immediately.
+ factory.Dispose();
+
+ Assert.True(factory.DisposalToken.IsCancellationRequested);
+ }
+
+ [Fact]
+ public async Task DisposeAsyncCancelsToken()
+ {
+ DisposableJoinableTaskFactory factory = new(this.asyncPump);
+
+ // The collection is already empty, so this completes immediately.
+ await factory.DisposeAsync();
+
+ Assert.True(factory.DisposalToken.IsCancellationRequested);
+ }
+
+ [Fact]
+ public void MultipleDisposalsDoNotThrow()
+ {
+ using DisposableJoinableTaskFactory factory = new(this.asyncPump);
+
+ factory.Dispose();
+ factory.Dispose();
+
+ Assert.True(factory.DisposalToken.IsCancellationRequested);
+ }
+
+ [Fact]
+ public async Task MultipleDisposeAsyncsDoesNotThrow()
+ {
+ DisposableJoinableTaskFactory factory = new(this.asyncPump);
+
+ await factory.DisposeAsync();
+ await factory.DisposeAsync();
+
+ Assert.True(factory.DisposalToken.IsCancellationRequested);
+ }
+
+ [Fact]
+ public void ConstructorRequiresCollection()
+ {
+ // A factory created with just a context has no collection.
+ JoinableTaskFactory factoryWithoutCollection = this.context.Factory;
+ Assert.ThrowsAny(() => new DisposableJoinableTaskFactory(factoryWithoutCollection));
+ }
+
+ [Fact]
+ public async Task TaskObservesCancellationDuringDrain()
+ {
+ DisposableJoinableTaskFactory factory = new(this.asyncPump);
+
+ JoinableTask jt = factory.RunAsync(() => Task.Delay(UnexpectedTimeout, factory.DisposalToken));
+
+ await factory.DisposeAsync().AsTask().WithTimeout(UnexpectedTimeout);
+ Assert.True(jt.IsCompleted);
+ await Assert.ThrowsAsync(async () => await jt);
+ }
+}