Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ public override async Task<ChatResponse> GetResponseAsync(
// approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message
// list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if
// the inner client had returned them directly.
(responseMessages, var notInvokedApprovals) = ProcessFunctionApprovalResponses(
(responseMessages, var notInvokedApprovals, var resultInsertionIndex) = ProcessFunctionApprovalResponses(
originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId: null, functionCallContentFallbackMessageId: null);
(IList<ChatMessage>? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) =
await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: false, cancellationToken);
await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, resultInsertionIndex, isStreaming: false, cancellationToken);

if (invokedApprovedFunctionApprovalResponses is not null)
{
Expand Down Expand Up @@ -381,7 +381,7 @@ public override async Task<ChatResponse> GetResponseAsync(

// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken);
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: false, cancellationToken);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

Expand Down Expand Up @@ -447,7 +447,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
// approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message
// list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if
// the inner client had returned them directly.
var (preDownstreamCallHistory, notInvokedApprovals) = ProcessFunctionApprovalResponses(
var (preDownstreamCallHistory, notInvokedApprovals, resultInsertionIndex) = ProcessFunctionApprovalResponses(
originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId, functionCallContentFallbackMessageId);
if (preDownstreamCallHistory is not null)
{
Expand All @@ -460,7 +460,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA

// Invoke approved approval responses, which generates some additional FRC wrapped in ChatMessage.
(IList<ChatMessage>? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) =
await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: true, cancellationToken);
await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, resultInsertionIndex, isStreaming: true, cancellationToken);

if (invokedApprovedFunctionApprovalResponses is not null)
{
Expand Down Expand Up @@ -604,7 +604,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId);

// Process all of the functions, adding their results into the history.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken);
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: true, cancellationToken);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

Expand Down Expand Up @@ -880,13 +880,14 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont
/// <param name="functionCallContents">The function call contents representing the functions to be invoked.</param>
/// <param name="iteration">The iteration number of how many roundtrips have been made to the inner client.</param>
/// <param name="consecutiveErrorCount">The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors.</param>
/// <param name="insertionIndex">The index at which to insert the function result messages, or -1 to append to the end.</param>
/// <param name="isStreaming">Whether the function calls are being processed in a streaming context.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>A value indicating how the caller should proceed.</returns>
private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList<ChatMessage> MessagesAdded)> ProcessFunctionCallsAsync(
List<ChatMessage> messages, ChatOptions? options,
Dictionary<string, AITool>? toolMap, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount,
bool isStreaming, CancellationToken cancellationToken)
int insertionIndex, bool isStreaming, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
Expand All @@ -905,7 +906,16 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont
IList<ChatMessage> addedMessages = CreateResponseMessages([result]);
ThrowIfNoFunctionResultsAdded(addedMessages);
UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount);
messages.AddRange(addedMessages);

// Insert at the specified position or append if no valid insertion index
if (insertionIndex >= 0 && insertionIndex <= messages.Count)
{
messages.InsertRange(insertionIndex, addedMessages);
}
else
{
messages.AddRange(addedMessages);
}

return (result.Terminate, consecutiveErrorCount, addedMessages);
}
Expand Down Expand Up @@ -950,7 +960,16 @@ select ProcessFunctionCallAsync(
IList<ChatMessage> addedMessages = CreateResponseMessages(results.ToArray());
ThrowIfNoFunctionResultsAdded(addedMessages);
UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount);
messages.AddRange(addedMessages);

// Insert at the specified position or append if no valid insertion index
if (insertionIndex >= 0 && insertionIndex <= messages.Count)
{
messages.InsertRange(insertionIndex, addedMessages);
}
else
{
messages.AddRange(addedMessages);
}

return (shouldTerminate, consecutiveErrorCount, addedMessages);
}
Expand Down Expand Up @@ -1248,12 +1267,13 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
/// 3. Generate failed <see cref="FunctionResultContent"/> for any rejected <see cref="FunctionApprovalResponseContent"/>.
/// 4. Add all the new content items to <paramref name="originalMessages"/> and return them as the pre-invocation history.
/// </summary>
private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals) ProcessFunctionApprovalResponses(
private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals, int insertionIndex) ProcessFunctionApprovalResponses(
List<ChatMessage> originalMessages, bool hasConversationId, string? toolMessageId, string? functionCallContentFallbackMessageId)
{
// Extract any approval responses where we need to execute or reject the function calls.
// The original messages are also modified to remove all approval requests and responses.
var notInvokedResponses = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages);
var (notInvokedApprovalsResult, notInvokedRejectionsResult, insertionIndex) = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages);
var notInvokedResponses = (approvals: notInvokedApprovalsResult, rejections: notInvokedRejectionsResult);

// Wrap the function call content in message(s).
ICollection<ChatMessage>? allPreDownstreamCallMessages = ConvertToFunctionCallContentMessages(
Expand All @@ -1269,25 +1289,54 @@ private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResult
// Add all the FCC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response.
// Also, if we are not dealing with a service thread (i.e. we don't have a conversation ID), add them
// into the original messages list so that they are passed to the inner client and can be used to generate a result.
// Insert at the position where the approval request was originally located to preserve message ordering.
List<ChatMessage>? preDownstreamCallHistory = null;
if (allPreDownstreamCallMessages is not null)
{
preDownstreamCallHistory = [.. allPreDownstreamCallMessages];
if (!hasConversationId)
{
originalMessages.AddRange(preDownstreamCallHistory);
// If we have a valid insertion index, insert at that position. Otherwise, append to the end.
if (insertionIndex >= 0 && insertionIndex <= originalMessages.Count)
{
originalMessages.InsertRange(insertionIndex, preDownstreamCallHistory);
}
else
{
originalMessages.AddRange(preDownstreamCallHistory);
}
}
}

// Add all the FRC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response.
// Also, add them into the original messages list so that they are passed to the inner client and can be used to generate a result.
// Insert immediately after the FCC messages to preserve message ordering.
if (rejectedPreDownstreamCallResultsMessage is not null)
{
(preDownstreamCallHistory ??= []).Add(rejectedPreDownstreamCallResultsMessage);
originalMessages.Add(rejectedPreDownstreamCallResultsMessage);

// Calculate the insertion position: right after the FCC messages we just inserted
// Only add the FCC count if they were actually inserted (!hasConversationId)
int rejectedInsertionIndex = insertionIndex >= 0 && insertionIndex <= originalMessages.Count
? insertionIndex + (!hasConversationId ? (allPreDownstreamCallMessages?.Count ?? 0) : 0)
: originalMessages.Count;

if (rejectedInsertionIndex >= 0 && rejectedInsertionIndex <= originalMessages.Count)
{
originalMessages.Insert(rejectedInsertionIndex, rejectedPreDownstreamCallResultsMessage);
}
else
{
originalMessages.Add(rejectedPreDownstreamCallResultsMessage);
}
}

return (preDownstreamCallHistory, notInvokedResponses.approvals);
// Calculate the insertion index for function result content (after the FCC messages and rejected FRC messages)
int resultInsertionIndex = insertionIndex >= 0 && insertionIndex <= originalMessages.Count && !hasConversationId
? insertionIndex + (allPreDownstreamCallMessages?.Count ?? 0) + (rejectedPreDownstreamCallResultsMessage is not null ? 1 : 0)
: -1;

return (preDownstreamCallHistory, notInvokedResponses.approvals, resultInsertionIndex);
}

/// <summary>
Expand All @@ -1299,13 +1348,14 @@ private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResult
/// We can then use the metadata from these messages when we re-create the FunctionCallContent messages/updates to return to the caller. This way, when we finally do return
/// the FunctionCallContent to users, it's part of a message/update that contains the same metadata as originally returned to the downstream service.
/// </remarks>
private static (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections) ExtractAndRemoveApprovalRequestsAndResponses(
private static (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections, int insertionIndex) ExtractAndRemoveApprovalRequestsAndResponses(
List<ChatMessage> messages)
{
Dictionary<string, ChatMessage>? allApprovalRequestsMessages = null;
List<FunctionApprovalResponseContent>? allApprovalResponses = null;
HashSet<string>? approvalRequestCallIds = null;
HashSet<string>? functionResultCallIds = null;
int firstApprovalRequestIndex = -1;

// 1st iteration, over all messages and content:
// - Build a list of all function call ids that are already executed.
Expand All @@ -1330,6 +1380,13 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval
// Validation: Capture each call id for each approval request to ensure later we have a matching response.
_ = (approvalRequestCallIds ??= []).Add(farc.FunctionCall.CallId);
(allApprovalRequestsMessages ??= []).Add(farc.Id, message);

// Track the first approval request index for later insertion
if (firstApprovalRequestIndex == -1)
{
firstApprovalRequestIndex = i;
}

break;

case FunctionApprovalResponseContent farc:
Expand Down Expand Up @@ -1371,9 +1428,53 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval
}

// Clean up any messages that were marked for removal during the iteration.
// Also adjust the insertion index to account for removed messages.
int insertionIndex = firstApprovalRequestIndex;
if (anyRemoved)
{
// Count how many messages before the first approval request were removed
int removedBeforeInsertionIndex = 0;
if (firstApprovalRequestIndex >= 0)
{
for (int idx = 0; idx < firstApprovalRequestIndex; idx++)
{
if (messages[idx] is null)
{
removedBeforeInsertionIndex++;
}
}
}

_ = messages.RemoveAll(static m => m is null);

// Adjust the insertion index
if (insertionIndex >= 0)
{
insertionIndex -= removedBeforeInsertionIndex;
}
}

// If there are already-executed function results, insert new function calls at the end instead of at the insertion index
// to preserve the ordering of already-present function calls and results. This handles scenarios where:
// 1. Previous approval responses have been processed and their function calls/results are present in the message list
// 2. New approval responses are being processed
// In this case, we want the new function calls to come AFTER the existing ones, not at the position
// where the first (already-processed) approval request was originally located.
//
// Example:
// Before extraction (original user input with approval messages):
// [User, FunctionApprovalRequest(A), FunctionApprovalResponse(A), FunctionResult(A), FunctionApprovalRequest(B), FunctionApprovalResponse(B)]
// After extraction of approval requests/responses (state of 'messages' at this point):
// [User, FunctionResult(A)]
// After processing approval for B, if we inserted at the original index where B's approval request was,
// we'd incorrectly interleave new calls with existing results:
// [User, FunctionCall(B), FunctionResult(B), FunctionResult(A)] // Wrong order
// But if there are already function results present (e.g., for A), we instead append new function calls/results
// for B at the end to preserve chronological ordering:
// [User, FunctionResult(A), FunctionCall(B), FunctionResult(B)] // Correct order
if (functionResultCallIds is { Count: > 0 } && insertionIndex >= 0)
{
insertionIndex = messages.Count;
}

// Validation: If we got an approval for each request, we should have no call ids left.
Expand Down Expand Up @@ -1408,7 +1509,7 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval
}
}

return (approvedFunctionCalls, rejectedFunctionCalls);
return (approvedFunctionCalls, rejectedFunctionCalls, insertionIndex);
}

/// <summary>
Expand Down Expand Up @@ -1649,6 +1750,7 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
List<ChatMessage> originalMessages,
ChatOptions? options,
int consecutiveErrorCount,
int insertionIndex,
bool isStreaming,
CancellationToken cancellationToken)
{
Expand All @@ -1657,7 +1759,7 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
{
// The FRC that is generated here is already added to originalMessages by ProcessFunctionCallsAsync.
var modeAndMessages = await ProcessFunctionCallsAsync(
originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, isStreaming, cancellationToken);
originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, insertionIndex, isStreaming, cancellationToken);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

return (modeAndMessages.MessagesAdded, modeAndMessages.ShouldTerminate, consecutiveErrorCount);
Expand Down
Loading
Loading