diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index cc43c192f89..0bf5f07e8f6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -295,10 +295,10 @@ public override async Task GetResponseAsync( // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if // the inner client had returned them directly. - (responseMessages, var notInvokedApprovals) = ProcessFunctionApprovalResponses( + (responseMessages, var notInvokedApprovals, var approvalRequestIndices) = ProcessFunctionApprovalResponses( originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId: null, functionCallContentFallbackMessageId: null); (IList? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = - await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: false, cancellationToken); + await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, approvalRequestIndices, isStreaming: false, cancellationToken); if (invokedApprovedFunctionApprovalResponses is not null) { @@ -381,7 +381,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: false, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -447,7 +447,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if // the inner client had returned them directly. - var (preDownstreamCallHistory, notInvokedApprovals) = ProcessFunctionApprovalResponses( + var (preDownstreamCallHistory, notInvokedApprovals, approvalRequestIndices) = ProcessFunctionApprovalResponses( originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId, functionCallContentFallbackMessageId); if (preDownstreamCallHistory is not null) { @@ -460,7 +460,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // Invoke approved approval responses, which generates some additional FRC wrapped in ChatMessage. (IList? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = - await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: true, cancellationToken); + await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, approvalRequestIndices, isStreaming: true, cancellationToken); if (invokedApprovedFunctionApprovalResponses is not null) { @@ -604,7 +604,7 @@ public override async IAsyncEnumerable GetStreamingResponseA FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: true, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -880,13 +880,14 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(ListThe function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. + /// The index at which to insert the function result messages, or -1 to append to the end. /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( List messages, ChatOptions? options, Dictionary? toolMap, List functionCallContents, int iteration, int consecutiveErrorCount, - bool isStreaming, CancellationToken cancellationToken) + int insertionIndex, bool isStreaming, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. @@ -905,7 +906,16 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List addedMessages = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(addedMessages); UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); - messages.AddRange(addedMessages); + + // Insert at the specified position or append if no valid insertion index + if (insertionIndex >= 0 && insertionIndex <= messages.Count) + { + messages.InsertRange(insertionIndex, addedMessages); + } + else + { + messages.AddRange(addedMessages); + } return (result.Terminate, consecutiveErrorCount, addedMessages); } @@ -950,7 +960,16 @@ select ProcessFunctionCallAsync( IList addedMessages = CreateResponseMessages(results.ToArray()); ThrowIfNoFunctionResultsAdded(addedMessages); UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); - messages.AddRange(addedMessages); + + // Insert at the specified position or append if no valid insertion index + if (insertionIndex >= 0 && insertionIndex <= messages.Count) + { + messages.InsertRange(insertionIndex, addedMessages); + } + else + { + messages.AddRange(addedMessages); + } return (shouldTerminate, consecutiveErrorCount, addedMessages); } @@ -1248,46 +1267,90 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul /// 3. Genreate failed for any rejected . /// 4. add all the new content items to and return them as the pre-invocation history. /// - private static (List? preDownstreamCallHistory, List? approvals) ProcessFunctionApprovalResponses( + private static (List? preDownstreamCallHistory, List? approvals, Dictionary? approvalRequestIndices) ProcessFunctionApprovalResponses( List originalMessages, bool hasConversationId, string? toolMessageId, string? functionCallContentFallbackMessageId) { // Extract any approval responses where we need to execute or reject the function calls. // The original messages are also modified to remove all approval requests and responses. - var notInvokedResponses = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages); + var (notInvokedApprovalsResult, notInvokedRejectionsResult, approvalRequestIndices) = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages); + var notInvokedResponses = (approvals: notInvokedApprovalsResult, rejections: notInvokedRejectionsResult); - // Wrap the function call content in message(s). - ICollection? allPreDownstreamCallMessages = ConvertToFunctionCallContentMessages( - [.. notInvokedResponses.rejections ?? Enumerable.Empty(), .. notInvokedResponses.approvals ?? Enumerable.Empty()], - functionCallContentFallbackMessageId); + // Group approvals and rejections by their original index for proper insertion + var allResults = new List(); + if (notInvokedResponses.rejections is not null) + { + allResults.AddRange(notInvokedResponses.rejections); + } - // Generate failed function result contents for any rejected requests and wrap it in a message. - List? rejectedFunctionCallResults = GenerateRejectedFunctionResults(notInvokedResponses.rejections); - ChatMessage? rejectedPreDownstreamCallResultsMessage = rejectedFunctionCallResults is not null ? - new ChatMessage(ChatRole.Tool, rejectedFunctionCallResults) { MessageId = toolMessageId } : - null; + if (notInvokedResponses.approvals is not null) + { + allResults.AddRange(notInvokedResponses.approvals); + } + + // Sort by index in descending order so we can insert from end to start without index shifting issues + var sortedResults = allResults + .Where(r => approvalRequestIndices?.ContainsKey(r.Response.FunctionCall.CallId) == true) + .OrderByDescending(r => approvalRequestIndices![r.Response.FunctionCall.CallId]) + .ToList(); - // Add all the FCC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. - // Also, if we are not dealing with a service thread (i.e. we don't have a conversation ID), add them - // into the original messages list so that they are passed to the inner client and can be used to generate a result. List? preDownstreamCallHistory = null; - if (allPreDownstreamCallMessages is not null) + + // Process each approval/rejection and insert at its original position + foreach (var result in sortedResults) { - preDownstreamCallHistory = [.. allPreDownstreamCallMessages]; - if (!hasConversationId) + string callId = result.Response.FunctionCall.CallId; + int insertionIndex = approvalRequestIndices![callId]; + + // Convert this specific result to FunctionCallContent message + var fccMessages = ConvertToFunctionCallContentMessages([result], functionCallContentFallbackMessageId); + if (fccMessages is not null) { - originalMessages.AddRange(preDownstreamCallHistory); + // Add to history + if (preDownstreamCallHistory is null) + { + preDownstreamCallHistory = [.. fccMessages]; + } + else + { + preDownstreamCallHistory.InsertRange(0, fccMessages); + } + + // Insert into original messages if not using conversation ID + if (!hasConversationId && insertionIndex >= 0 && insertionIndex <= originalMessages.Count) + { + originalMessages.InsertRange(insertionIndex, fccMessages); + } } - } - // Add all the FRC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. - // Also, add them into the original messages list so that they are passed to the inner client and can be used to generate a result. - if (rejectedPreDownstreamCallResultsMessage is not null) - { - (preDownstreamCallHistory ??= []).Add(rejectedPreDownstreamCallResultsMessage); - originalMessages.Add(rejectedPreDownstreamCallResultsMessage); + // For rejections, also insert the rejection result + if (!result.Response.Approved) + { + var rejectedContent = GenerateRejectedFunctionResults([result]); + if (rejectedContent is not null) + { + var rejectedMessage = new ChatMessage(ChatRole.Tool, rejectedContent) { MessageId = toolMessageId }; + + // Add to history + if (preDownstreamCallHistory is null) + { + preDownstreamCallHistory = [rejectedMessage]; + } + else + { + preDownstreamCallHistory.Insert(fccMessages?.Count ?? 0, rejectedMessage); + } + + // Insert rejection result right after the FCC messages + int rejectedInsertionIndex = insertionIndex + (fccMessages?.Count ?? 0); + if (rejectedInsertionIndex >= 0 && rejectedInsertionIndex <= originalMessages.Count) + { + originalMessages.Insert(rejectedInsertionIndex, rejectedMessage); + } + } + } } - return (preDownstreamCallHistory, notInvokedResponses.approvals); + return (preDownstreamCallHistory, notInvokedResponses.approvals, approvalRequestIndices); } /// @@ -1299,19 +1362,21 @@ private static (List? preDownstreamCallHistory, List - private static (List? approvals, List? rejections) ExtractAndRemoveApprovalRequestsAndResponses( + private static (List? approvals, List? rejections, Dictionary? approvalRequestIndices) ExtractAndRemoveApprovalRequestsAndResponses( List messages) { Dictionary? allApprovalRequestsMessages = null; List? allApprovalResponses = null; HashSet? approvalRequestCallIds = null; HashSet? functionResultCallIds = null; + Dictionary? approvalRequestIndices = null; // 1st iteration, over all messages and content: // - Build a list of all function call ids that are already executed. // - Build a list of all function approval requests and responses. // - Build a list of the content we want to keep (everything except approval requests and responses) and create a new list of messages for those. // - Validate that we have an approval response for each approval request. + // - Track the original index of each approval request by call ID bool anyRemoved = false; int i = 0; for (; i < messages.Count; i++) @@ -1330,6 +1395,14 @@ private static (List? approvals, List? approvals, List m is null); + + // Adjust all approval request indices + if (approvalRequestIndices is not null) + { + List callIds = [.. approvalRequestIndices.Keys]; + foreach (var callId in callIds) + { + int originalIndex = approvalRequestIndices[callId]; + approvalRequestIndices[callId] = originalIndex - removedBeforeIndex[originalIndex]; + } + } } // Validation: If we got an approval for each request, we should have no call ids left. @@ -1388,6 +1485,7 @@ private static (List? approvals, List? approvedFunctionCalls = null, rejectedFunctionCalls = null; + bool hasAlreadyExecutedApprovals = false; if (allApprovalResponses is { Count: > 0 }) { foreach (var approvalResponse in allApprovalResponses) @@ -1395,6 +1493,7 @@ private static (List? approvals, List? approvals, List callIds = [.. approvalRequestIndices.Keys]; + foreach (var callId in callIds) + { + approvalRequestIndices[callId] = messages.Count; + } + } + + return (approvedFunctionCalls, rejectedFunctionCalls, approvalRequestIndices); } /// @@ -1658,15 +1786,28 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => List originalMessages, ChatOptions? options, int consecutiveErrorCount, + Dictionary? approvalRequestIndices, bool isStreaming, CancellationToken cancellationToken) { // Check if there are any function calls to do for any approved functions and execute them. if (notInvokedApprovals is { Count: > 0 }) { + // For now, use the first approval's index, or -1 if not found + // Future enhancement: Process each approval individually at its correct position + int insertionIndex = -1; + if (approvalRequestIndices is not null && notInvokedApprovals.Count > 0) + { + string firstCallId = notInvokedApprovals[0].Response.FunctionCall.CallId; + if (approvalRequestIndices.TryGetValue(firstCallId, out int index)) + { + insertionIndex = index; + } + } + // The FRC that is generated here is already added to originalMessages by ProcessFunctionCallsAsync. var modeAndMessages = await ProcessFunctionCallsAsync( - originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, isStreaming, cancellationToken); + originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, insertionIndex, isStreaming, cancellationToken); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; return (modeAndMessages.MessagesAdded, modeAndMessages.ShouldTerminate, consecutiveErrorCount); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs index 01f2e111447..9d9483a4cfa 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs @@ -1113,6 +1113,198 @@ async IAsyncEnumerable YieldInnerClientUpdates( } } + [Fact] + public async Task RejectionWithUserMessageAfterApprovalResponsePreservesOrderingAsync() + { + // This test verifies that when a user adds a message after the approval response, + // the message ordering is preserved. The reconstructed FunctionCallContent and + // FunctionResultContent should be inserted at the position where the approval + // request was originally located, not at the end. + var options = new ChatOptions + { + Tools = + [ + new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1")), + ] + }; + + List input = + [ + new ChatMessage(ChatRole.User, "1st message"), + new ChatMessage(ChatRole.Assistant, + [ + new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1")) + ]) { MessageId = "resp1" }, + new ChatMessage(ChatRole.User, + [ + new FunctionApprovalResponseContent("callId1", false, new FunctionCallContent("callId1", "Func1")) + ]), + new ChatMessage(ChatRole.User, "2nd message"), // This should stay at the end + ]; + + // The expected input to downstream client should have messages in this order: + // 1. User "1st message" + // 2. Assistant with FunctionCallContent + // 3. Tool with rejection result + // 4. User "2nd message" (preserved at the end) + List expectedDownstreamClientInput = + [ + new ChatMessage(ChatRole.User, "1st message"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Tool call invocation rejected.")]), + new ChatMessage(ChatRole.User, "2nd message"), + ]; + + List downstreamClientOutput = + [ + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + List output = + [ + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Tool call invocation rejected.")]), + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + await InvokeAndAssertAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + + await InvokeAndAssertStreamingAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + } + + [Fact] + public async Task ApprovalWithUserMessageAfterApprovalResponsePreservesOrderingAsync() + { + // This test verifies that when a user approves and adds a message after the approval response, + // the message ordering is preserved. + var options = new ChatOptions + { + Tools = + [ + new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1")), + ] + }; + + List input = + [ + new ChatMessage(ChatRole.User, "1st message"), + new ChatMessage(ChatRole.Assistant, + [ + new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1")) + ]) { MessageId = "resp1" }, + new ChatMessage(ChatRole.User, + [ + new FunctionApprovalResponseContent("callId1", true, new FunctionCallContent("callId1", "Func1")) + ]), + new ChatMessage(ChatRole.User, "2nd message"), // This should stay at the end + ]; + + // The expected input to downstream client should have messages in this order: + // 1. User "1st message" + // 2. Assistant with FunctionCallContent + // 3. Tool with function result + // 4. User "2nd message" (preserved at the end) + List expectedDownstreamClientInput = + [ + new ChatMessage(ChatRole.User, "1st message"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.User, "2nd message"), + ]; + + List downstreamClientOutput = + [ + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + List output = + [ + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + await InvokeAndAssertAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + + await InvokeAndAssertStreamingAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + } + + [Fact] + public async Task MultipleApprovalRequestResponsePairsWithInterleavedUserMessagesPreservesOrderingAsync() + { + // This test verifies that when there are multiple approval request/response pairs + // in a single call with user messages interleaved between them, the message ordering + // is preserved correctly. All approvals are processed in one invocation. + var options = new ChatOptions + { + Tools = + [ + new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1")), + new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 2", "Func2")), + ] + }; + + List input = + [ + new ChatMessage(ChatRole.User, "1st user message"), + new ChatMessage(ChatRole.Assistant, + [ + new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1")) + ]) { MessageId = "resp1" }, + new ChatMessage(ChatRole.User, + [ + new FunctionApprovalResponseContent("callId1", true, new FunctionCallContent("callId1", "Func1")) + ]), + new ChatMessage(ChatRole.User, "2nd user message"), + new ChatMessage(ChatRole.Assistant, + [ + new FunctionApprovalRequestContent("callId2", new FunctionCallContent("callId2", "Func2")) + ]) { MessageId = "resp2" }, + new ChatMessage(ChatRole.User, + [ + new FunctionApprovalResponseContent("callId2", true, new FunctionCallContent("callId2", "Func2")) + ]), + new ChatMessage(ChatRole.User, "3rd user message"), + ]; + + // The expected input to downstream client should preserve all message ordering: + // 1. User "1st user message" - should remain in place + // 2. Assistant with FunctionCallContent(callId1) - recreated from approval + // 3. Tool with FunctionResultContent(callId1) - from executing approved function + // 4. User "2nd user message" - should remain in place + // 5. Assistant with FunctionCallContent(callId2) - recreated from approval + // 6. Tool with FunctionResultContent(callId2) - from executing approved function + // 7. User "3rd user message" - should remain at the end + List expectedDownstreamClientInput = + [ + new ChatMessage(ChatRole.User, "1st user message"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.User, "2nd user message"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2")]), + new ChatMessage(ChatRole.User, "3rd user message"), + ]; + + List downstreamClientOutput = + [ + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + List output = + [ + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2")]), + new ChatMessage(ChatRole.Assistant, "Final response"), + ]; + + await InvokeAndAssertAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + + await InvokeAndAssertStreamingAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); + } + private static Task> InvokeAndAssertAsync( ChatOptions? options, List input,