diff --git a/src/SosThreadingTools/DumpAsyncCommand.cs b/src/SosThreadingTools/DumpAsyncCommand.cs index de3686f8d..88c73a216 100644 --- a/src/SosThreadingTools/DumpAsyncCommand.cs +++ b/src/SosThreadingTools/DumpAsyncCommand.cs @@ -164,6 +164,12 @@ private static void GetAllStateMachines(ClrHeap heap, List al } } } + + ClrObject currentObject = stateMachine.TryGetObjectField("<>4__this"); + if (!currentObject.IsNull && currentObject.Type is not null && string.Equals(currentObject.Type.Name, "Microsoft.VisualStudio.Threading.JoinableTaskCollection", StringComparison.Ordinal)) + { + asyncState.WaitingJoinableTasks = GetJoinableTasksFromCollection(currentObject); + } } } #pragma warning disable CA1031 // Do not catch general exception types @@ -183,6 +189,44 @@ private static void GetAllStateMachines(ClrHeap heap, List al } } + private static List GetJoinableTasksFromCollection(ClrObject joinableTaskCollection) + { + var joinableTasks = new List(); + ClrValueType? dependentData = joinableTaskCollection.TryGetValueClassField("dependentData"); + if (dependentData is not null) + { + ClrObject childDependentNodes = dependentData.TryGetObjectField("childDependentNodes"); + if (!childDependentNodes.IsNull && + childDependentNodes.TryReadField("count", out int count) && + childDependentNodes.TryReadField("freeCount", out int freeCount)) + { + count -= freeCount; + if (count > 0) + { + ClrObject entries = childDependentNodes.TryGetObjectField("entries"); + if (!entries.IsNull && entries.IsArray && entries.AsArray() is ClrArray entriesArray) + { + for (int i = 0; i < entriesArray.Length; i++) + { + ClrValueType? value = entriesArray.GetStructValue(i); + ClrObject key = value.TryGetObjectField("key"); + if (!key.IsNull) + { + joinableTasks.Add(key); + if (--count == 0) + { + break; + } + } + } + } + } + } + } + + return joinableTasks; + } + private static void ChainStateMachinesBasedOnTaskContinuations(Dictionary knownStateMachines) { foreach (AsyncStateMachine? stateMachine in knownStateMachines.Values) @@ -285,16 +329,13 @@ private static void ChainStateMachinesBasedOnJointableTasks(List4__this"); - ClrObject wrappedTask = joinableTask.TryGetObjectField("wrappedTask"); - if (!wrappedTask.IsNull) + FindWaitingTaskFromJoinableTask(joinableTask, stateMachine); + + if (stateMachine.WaitingJoinableTasks is List joinableTasks) { - AsyncStateMachine? previousStateMachine = allStateMachines - .FirstOrDefault(s => s.Task.Address == wrappedTask.Address); - if (previousStateMachine is object && stateMachine != previousStateMachine) + foreach (ClrObject waitingJoinableTask in joinableTasks) { - stateMachine.Previous = previousStateMachine; - previousStateMachine.Next = stateMachine; - previousStateMachine.DependentCount++; + FindWaitingTaskFromJoinableTask(waitingJoinableTask, stateMachine); } } } @@ -306,6 +347,30 @@ private static void ChainStateMachinesBasedOnJointableTasks(List s.Task.Address == wrappedTask.Address); + if (previousStateMachine is object && currentStateMachine != previousStateMachine) + { + if (currentStateMachine.Previous is null) + { + currentStateMachine.Previous = previousStateMachine; + previousStateMachine.Next = currentStateMachine; + } + else + { + previousStateMachine.Next ??= currentStateMachine; + } + + previousStateMachine.DependentCount++; + } + } + } } private static void MarkUIThreadDependingTasks(List allStateMachines) @@ -388,7 +453,7 @@ private void MarkThreadingBlockTasks(ClrHeap heap, List allSt } } - break; + continue; } } } @@ -436,12 +501,9 @@ private void PrintOutStateMachines(List allStateMachines) .OrderByDescending(m => m.Depth) .ThenByDescending(m => m.SwitchToMainThreadTask.Address)) { - bool multipleLineBlock = this.PrintAsyncStateMachineChain(node, printedMachines); + this.PrintAsyncStateMachineChain(node, printedMachines); - if (multipleLineBlock) - { - Console.WriteLine(string.Empty); - } + Console.WriteLine(string.Empty); } // Print nodes which we didn't print because of loops. @@ -459,10 +521,9 @@ private void PrintOutStateMachines(List allStateMachines) } } - private bool PrintAsyncStateMachineChain(AsyncStateMachine node, HashSet printedMachines) + private void PrintAsyncStateMachineChain(AsyncStateMachine node, HashSet printedMachines) { int nLevel = 0; - bool multipleLineBlock = false; var loopDetection = new HashSet(); for (AsyncStateMachine? p = node; p is object; p = p.Next) @@ -472,7 +533,6 @@ private bool PrintAsyncStateMachineChain(AsyncStateMachine node, HashSet 0) { this.WriteString(".."); - multipleLineBlock = true; } else if (p.AlterPrevious is object) { @@ -481,14 +541,12 @@ private bool PrintAsyncStateMachineChain(AsyncStateMachine node, HashSet allStateMachines) @@ -573,6 +627,8 @@ public AsyncStateMachine(int state, ClrObject stateMachine, ClrObject task) public AsyncStateMachine? AlterPrevious { get; set; } + public List? WaitingJoinableTasks { get; set; } + public ulong CodeAddress { get; set; } public override string ToString()