diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index cc0c15eda5..dd294040c4 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -62,7 +62,7 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages) { ResponseMessages = responseMessages }; @@ -94,7 +94,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages) { ResponseMessages = responseMessages }; diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 509b79e53f..edd9248ff9 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -104,7 +104,7 @@ public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonS public UserInfo UserInfo { get; set; } - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { // Try and extract the user name and age from the message if we don't have it already and it's a user message. if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) @@ -122,7 +122,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } } - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { StringBuilder instructions = new(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 81a2beb3da..33176d8fdf 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -89,7 +89,7 @@ public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serialized public string? SessionDbKey { get; private set; } - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); @@ -107,7 +107,7 @@ public override async ValueTask> InvokingAsync(Invoking return messages; } - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { // Don't store messages if the request failed. if (context.InvokeException is not null) @@ -122,7 +122,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio // Add both request and response messages to the store // Optionally messages produced by the AIContextProvider can also be persisted (not shown). - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() { diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 28b9780d17..055172fa82 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -92,7 +92,7 @@ public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions? } } - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { StringBuilder outputMessageBuilder = new(); outputMessageBuilder.AppendLine("Your todo list contains the following items:"); @@ -132,7 +132,7 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio /// internal sealed class CalendarSearchAIContextProvider(Func> loadNextThreeCalendarEvents) : AIContextProvider { - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var events = await loadNextThreeCalendarEvents(); @@ -179,7 +179,7 @@ public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonEle .ToList(); } - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { // Invoke all the sub providers. var tasks = this._providers.Select(provider => provider.InvokingAsync(context, cancellationToken).AsTask()); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f79b0a851d..a4b606e6a1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -31,6 +32,25 @@ namespace Microsoft.Agents.AI; /// public abstract class AIContextProvider { + private readonly string _sourceName; + + /// + /// Initializes a new instance of the class. + /// + protected AIContextProvider() + { + this._sourceName = this.GetType().FullName!; + } + + /// + /// Initializes a new instance of the class with the specified source name. + /// + /// The source name to stamp on for each messages produced by the . + protected AIContextProvider(string sourceName) + { + this._sourceName = sourceName; + } + /// /// Called at the start of agent invocation to provide additional context. /// @@ -48,7 +68,81 @@ public abstract class AIContextProvider /// /// /// - public abstract ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); + public async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var aiContext = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); + if (aiContext.Messages is null) + { + return aiContext; + } + + aiContext.Messages = aiContext.Messages.Select(message => + { + if (message.AdditionalProperties != null + // Check if the message was already tagged with this provider's source type + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) + && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType + && typedMessageSourceType == AgentRequestMessageSourceType.AIContextProvider + // Check if the message was already tagged with this provider's source + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) + && messageSource is string typedMessageSource + && typedMessageSource == this._sourceName) + { + return message; + } + + message = message.Clone(); + message.AdditionalProperties ??= new(); + message.AdditionalProperties[AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.AIContextProvider; + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = this._sourceName; + return message; + }).ToList(); + + return aiContext; + } + + /// + /// Called at the start of agent invocation to provide additional context. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the with additional context to be used by the agent during this invocation. + /// + /// + /// Implementers can load any additional context required at this time, such as: + /// + /// Retrieving relevant information from knowledge bases + /// Adding system instructions or prompts + /// Providing function tools for the current invocation + /// Injecting contextual messages from conversation history + /// + /// + /// + protected abstract ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default); + + /// + /// Called at the end of the agent invocation to process the invocation results. + /// + /// Contains the invocation context including request messages, response messages, and any exception that occurred. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + /// + /// + /// Implementers can use the request and response messages in the provided to: + /// + /// Update internal state based on conversation outcomes + /// Extract and store memories or preferences from user messages + /// Log or audit conversation details + /// Perform cleanup or finalization tasks + /// + /// + /// + /// This method is called regardless of whether the invocation succeeded or failed. + /// To check if the invocation was successful, inspect the property. + /// + /// + public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + => this.InvokedCoreAsync(context, cancellationToken); /// /// Called at the end of the agent invocation to process the invocation results. @@ -71,7 +165,7 @@ public abstract class AIContextProvider /// To check if the invocation was successful, inspect the property. /// /// - public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; /// @@ -117,7 +211,7 @@ public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOption => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about the invocation before the underlying AI model is invoked, including the messages @@ -163,7 +257,7 @@ public InvokingContext( } /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about a completed agent invocation, including both the @@ -178,18 +272,15 @@ public sealed class InvokedContext /// The agent being invoked. /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. - /// The messages provided by the for this invocation, if any. /// is . public InvokedContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages, - IEnumerable? aiContextProviderMessages) + IEnumerable requestMessages) { this.Agent = Throw.IfNull(agent); this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); - this.AIContextProviderMessages = aiContextProviderMessages; } /// @@ -211,15 +302,6 @@ public InvokedContext( /// public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } - /// - /// Gets the messages provided by the for this invocation, if any. - /// - /// - /// A collection of instances that were provided by the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? AIContextProviderMessages { get; set; } - /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs new file mode 100644 index 0000000000..127f1c1b8d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSource.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Provides a constant for the key used to store the source of the agent request message. +/// +public static class AgentRequestMessageSource +{ + /// + /// Provides the key used in to store the source of the agent request message. + /// + public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSource"; +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs new file mode 100644 index 0000000000..1cca747906 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Represents the source of an agent request message. +/// +/// +/// Input messages for a specific agent run can originate from various sources. +/// This type helps to identify whether a message came from outside the agent pipeline, +/// whether it was produced by middleware, or came from chat history. +/// +public sealed class AgentRequestMessageSourceType : IEquatable +{ + /// + /// Provides the key used in to store the source type of the agent request message. + /// + public static readonly string AdditionalPropertiesKey = "Agent.RequestMessageSourceType"; + + /// + /// Initializes a new instance of the class. + /// + /// The string value representing the source of the agent request message. + public AgentRequestMessageSourceType(string value) => this.Value = Throw.IfNullOrWhitespace(value); + + /// + /// Get the string value representing the source of the agent request message. + /// + public string Value { get; } + + /// + /// The message came from outside the agent pipeline (e.g., user input). + /// + public static AgentRequestMessageSourceType External { get; } = new AgentRequestMessageSourceType(nameof(External)); + + /// + /// The message was produced by middleware. + /// + public static AgentRequestMessageSourceType AIContextProvider { get; } = new AgentRequestMessageSourceType(nameof(AIContextProvider)); + + /// + /// The message came from chat history. + /// + public static AgentRequestMessageSourceType ChatHistory { get; } = new AgentRequestMessageSourceType(nameof(ChatHistory)); + + /// + /// Determines whether this instance and another specified object have the same value. + /// + /// The to compare to this instance. + /// if the value of the parameter is the same as the value of this instance; otherwise, . + public bool Equals(AgentRequestMessageSourceType? other) + { + if (other is null) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return string.Equals(this.Value, other.Value, StringComparison.Ordinal); + } + + /// + /// Determines whether this instance and a specified object have the same value. + /// + /// The object to compare to this instance. + /// if is a and its value is the same as this instance; otherwise, . + public override bool Equals(object? obj) => this.Equals(obj as AgentRequestMessageSourceType); + + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() => this.Value?.GetHashCode() ?? 0; + + /// + /// Determines whether two specified objects have the same value. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is the same as the value of ; otherwise, . + public static bool operator ==(AgentRequestMessageSourceType? left, AgentRequestMessageSourceType? right) + { + if (left is null) + { + return right is null; + } + + return left.Equals(right); + } + + /// + /// Determines whether two specified objects have different values. + /// + /// The first to compare. + /// The second to compare. + /// if the value of is different from the value of ; otherwise, . + public static bool operator !=(AgentRequestMessageSourceType? left, AgentRequestMessageSourceType? right) => !(left == right); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 352bae3355..f49c5d46a7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -36,6 +37,81 @@ namespace Microsoft.Agents.AI; /// public abstract class ChatHistoryProvider { + private readonly string _sourceName; + + /// + /// Initializes a new instance of the class. + /// + protected ChatHistoryProvider() + { + this._sourceName = this.GetType().FullName!; + } + + /// + /// Initializes a new instance of the class with the specified source name. + /// + /// The source name to stamp on for each messages produced by the . + protected ChatHistoryProvider(string sourceName) + { + this._sourceName = sourceName; + } + + /// + /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. + /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. + /// The to monitor for cancellation requests. The default is . + /// + /// A task that represents the asynchronous operation. The task result contains a collection of + /// instances in ascending chronological order (oldest first). + /// + /// + /// + /// Messages are returned in chronological order to maintain proper conversation flow and context for the agent. + /// The oldest messages appear first in the collection, followed by more recent messages. + /// + /// + /// If the total message history becomes very large, implementations should apply appropriate strategies to manage + /// storage constraints, such as: + /// + /// Truncating older messages while preserving recent context + /// Summarizing message groups to maintain essential context + /// Implementing sliding window approaches for message retention + /// Archiving old messages while keeping active conversation context + /// + /// + /// + /// Each instance should be associated with a single to ensure proper message isolation + /// and context management. + /// + /// + public async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var messages = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); + + return messages.Select(message => + { + if (message.AdditionalProperties != null + // Check if the message was already tagged with this provider's source type + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var messageSourceType) + && messageSourceType is AgentRequestMessageSourceType typedMessageSourceType + && typedMessageSourceType == AgentRequestMessageSourceType.ChatHistory + // Check if the message was already tagged with this provider's source + && message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out var messageSource) + && messageSource is string typedMessageSource + && typedMessageSource == this._sourceName) + { + return message; + } + + message = message.Clone(); + message.AdditionalProperties ??= new(); + message.AdditionalProperties[AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.ChatHistory; + message.AdditionalProperties[AgentRequestMessageSource.AdditionalPropertiesKey] = this._sourceName; + return message; + }); + } + /// /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. /// @@ -65,7 +141,7 @@ public abstract class ChatHistoryProvider /// and context management. /// /// - public abstract ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); + protected abstract ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default); /// /// Called at the end of the agent invocation to add new messages to the chat history. @@ -77,7 +153,7 @@ public abstract class ChatHistoryProvider /// /// Messages should be added in the order they were generated to maintain proper chronological sequence. /// The is responsible for preserving message ordering and ensuring that subsequent calls to - /// return messages in the correct chronological order. + /// return messages in the correct chronological order. /// /// /// Implementations may perform additional processing during message addition, such as: @@ -92,7 +168,35 @@ public abstract class ChatHistoryProvider /// To check if the invocation was successful, inspect the property. /// /// - public abstract ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default); + public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) => + this.InvokedCoreAsync(context, cancellationToken); + + /// + /// Called at the end of the agent invocation to add new messages to the chat history. + /// + /// Contains the invocation context including request messages, response messages, and any exception that occurred. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous add operation. + /// + /// + /// Messages should be added in the order they were generated to maintain proper chronological sequence. + /// The is responsible for preserving message ordering and ensuring that subsequent calls to + /// return messages in the correct chronological order. + /// + /// + /// Implementations may perform additional processing during message addition, such as: + /// + /// Validating message content and metadata + /// Applying storage optimizations or compression + /// Triggering background maintenance operations + /// + /// + /// + /// This method is called regardless of whether the invocation succeeded or failed. + /// To check if the invocation was successful, inspect the property. + /// + /// + protected abstract ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default); /// /// Serializes the current object's state to a using the specified serialization options. @@ -131,7 +235,7 @@ public abstract class ChatHistoryProvider => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about the invocation including the new messages that will be used. @@ -177,7 +281,7 @@ public InvokingContext( } /// - /// Contains the context information provided to . + /// Contains the context information provided to . /// /// /// This class provides context about a completed agent invocation, including both the @@ -192,18 +296,15 @@ public sealed class InvokedContext /// The agent being invoked. /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. - /// The messages retrieved from the for this invocation. /// is . public InvokedContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages, - IEnumerable? chatHistoryProviderMessages) + IEnumerable requestMessages) { this.Agent = Throw.IfNull(agent); this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); - this.ChatHistoryProviderMessages = chatHistoryProviderMessages; } /// @@ -225,24 +326,6 @@ public InvokedContext( /// public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } - /// - /// Gets the messages retrieved from the for this invocation, if any. - /// - /// - /// A collection of instances that were retrieved from the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? ChatHistoryProviderMessages { get; set; } - - /// - /// Gets or sets the messages provided by the for this invocation, if any. - /// - /// - /// A collection of instances that were provided by the , - /// and were used by the agent as part of the invocation. - /// - public IEnumerable? AIContextProviderMessages { get; set; } - /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs index 0f5d9524cb..c2ff8bf3e5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -33,8 +34,8 @@ public static ChatHistoryProvider WithMessageFilters( } /// - /// Decorates the provided chat message so that it does not add - /// messages produced by any to chat history. + /// Decorates the provided so that it does not add + /// messages with to chat history. /// /// The to add the message filter to. /// A new instance that filters out messages so they do not get added. @@ -44,7 +45,7 @@ public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatH innerProvider: provider, invokedMessagesFilter: (ctx) => { - ctx.AIContextProviderMessages = null; + ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() != AgentRequestMessageSourceType.AIContextProvider); return ctx; }); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs index df7b536ea2..6cee80986b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs @@ -49,14 +49,14 @@ public ChatHistoryProviderMessageFilter( } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { var messages = await this._innerProvider.InvokingAsync(context, cancellationToken).ConfigureAwait(false); return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages; } /// - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (this._invokedMessagesFilter != null) { diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs new file mode 100644 index 0000000000..01edcb4eff --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Contains extension methods for +/// +public static class ChatMessageExtensions +{ + /// + /// Gets the source of the provided in the context of messages passed into an agent run. + /// + /// The for which we need the source. + /// An value indicating the source of the . Defaults to if no explicit source is defined. + public static AgentRequestMessageSourceType GetAgentRequestMessageSource(this ChatMessage message) + { + if (message.AdditionalProperties?.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out var source) is true && source is AgentRequestMessageSourceType typedSource) + { + return typedSource; + } + + return AgentRequestMessageSourceType.External; + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index ab408c6a5e..001b3a3bcc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -133,7 +133,7 @@ public ChatMessage this[int index] } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -146,7 +146,7 @@ public override async ValueTask> InvokingAsync(Invoking } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -155,8 +155,8 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; } - // Add request, AI context provider, and response messages to the provider - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + // Add request and response messages to the provider + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); this._messages.AddRange(allNewMessages); if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) @@ -229,7 +229,7 @@ public enum ChatReducerTriggerEvent { /// /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. + /// will only complete when reducer processing is done. /// AfterMessageAdded, diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index 41c9a211dc..85c5865f07 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -287,7 +287,7 @@ public static CosmosChatHistoryProvider CreateFromSerializedState(CosmosClient c } /// - public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -347,7 +347,7 @@ public override async ValueTask> InvokingAsync(Invoking } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { Throw.IfNull(context); @@ -364,7 +364,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } #pragma warning restore CA1513 - var messageList = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []).ToList(); + var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 0e9b4288b1..8a0c016f07 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -131,13 +131,16 @@ public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSeri } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { Throw.IfNull(context); string queryText = string.Join( Environment.NewLine, - context.RequestMessages.Where(m => !string.IsNullOrWhiteSpace(m.Text)).Select(m => m.Text)); + context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) + .Where(m => !string.IsNullOrWhiteSpace(m.Text)) + .Select(m => m.Text)); try { @@ -202,7 +205,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (context.InvokeException is not null) { @@ -212,7 +215,11 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio try { // Persist request and response messages after invocation. - await this.PersistMessagesAsync(context.RequestMessages.Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); + await this.PersistMessagesAsync( + context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) + .Concat(context.ResponseMessages ?? []), + cancellationToken).ConfigureAwait(false); } catch (Exception ex) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index afe6706553..f631de8e8a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -46,17 +46,17 @@ internal sealed class StoreState internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) => new(this._chatMessages.AsReadOnly()); - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { if (context.InvokeException is not null) { return default; } - var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); this._chatMessages.AddRange(allNewMessages); return default; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 1567e2bcc1..5878d877b2 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -206,9 +206,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA (ChatClientAgentSession safeSession, ChatOptions? chatOptions, + List inputMessagesForProviders, List inputMessagesForChatClient, - IList? aiContextProviderMessages, - IList? chatHistoryProviderMessages, ChatClientAgentContinuationToken? continuationToken) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -227,12 +226,12 @@ protected override async IAsyncEnumerable RunCoreStreamingA try { // Using the enumerator to ensure we consider the case where no updates are returned for notification. - responseUpdatesEnumerator = chatClient.GetStreamingResponseAsync(inputMessagesForChatClient, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); + responseUpdatesEnumerator = chatClient.GetStreamingResponseAsync(inputMessagesForProviders, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -246,8 +245,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -273,8 +272,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); throw; } } @@ -286,10 +285,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -433,9 +432,8 @@ private async Task RunCoreAsync inputMessagesForProviders, List inputMessagesForChatClient, - IList? aiContextProviderMessages, - IList? chatHistoryProviderMessages, ChatClientAgentContinuationToken? _) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -455,8 +453,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync inputMessages, - IList? aiContextProviderMessages, IEnumerable responseMessages, CancellationToken cancellationToken) { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -509,12 +506,11 @@ private async Task NotifyAIContextProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable inputMessages, - IList? aiContextProviderMessages, CancellationToken cancellationToken) { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -683,9 +679,8 @@ private async Task <( ChatClientAgentSession AgentSession, ChatOptions? ChatOptions, + List inputMessagesForProviders, List InputMessagesForChatClient, - IList? AIContextProviderMessages, - IList? ChatHistoryProviderMessages, ChatClientAgentContinuationToken? ContinuationToken )> PrepareSessionAndMessagesAsync( AgentSession? session, @@ -714,9 +709,8 @@ private async Task throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token."); } + List inputMessagesForProviders = []; List inputMessagesForChatClient = []; - IList? aiContextProviderMessages = null; - IList? chatHistoryProviderMessages = null; // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) @@ -729,10 +723,10 @@ private async Task var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages); var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); inputMessagesForChatClient.AddRange(providerMessages); - chatHistoryProviderMessages = providerMessages as IList ?? providerMessages.ToList(); } // Add the input messages before getting context from AIContextProvider. + inputMessagesForProviders.AddRange(inputMessages); inputMessagesForChatClient.AddRange(inputMessages); // If we have an AIContextProvider, we should get context from it, and update our @@ -743,8 +737,8 @@ private async Task var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { + inputMessagesForProviders.AddRange(aiContext.Messages); inputMessagesForChatClient.AddRange(aiContext.Messages); - aiContextProviderMessages = aiContext.Messages; } if (aiContext.Tools is { Count: > 0 }) @@ -783,7 +777,7 @@ private async Task chatOptions.ConversationId = typedSession.ConversationId; } - return (typedSession, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatHistoryProviderMessages, continuationToken); + return (typedSession, chatOptions, inputMessagesForProviders, inputMessagesForChatClient, continuationToken); } private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) @@ -816,8 +810,6 @@ private Task NotifyChatHistoryProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable requestMessages, - IEnumerable? chatHistoryProviderMessages, - IEnumerable? aiContextProviderMessages, ChatOptions? chatOptions, CancellationToken cancellationToken) { @@ -827,9 +819,8 @@ private Task NotifyChatHistoryProviderOfFailureAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, InvokeException = ex }; @@ -842,8 +833,6 @@ private Task NotifyChatHistoryProviderOfFailureAsync( private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatClientAgentSession session, IEnumerable requestMessages, - IEnumerable? chatHistoryProviderMessages, - IEnumerable? aiContextProviderMessages, IEnumerable responseMessages, ChatOptions? chatOptions, CancellationToken cancellationToken) @@ -854,9 +843,8 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; return provider.InvokedAsync(invokedContext, cancellationToken).AsTask(); diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 87adc9fd7a..c63e8ac682 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -25,8 +25,8 @@ namespace Microsoft.Agents.AI; /// abstractions to work with any compatible vector store implementation. /// /// -/// Messages are stored during the method and retrieved during the -/// method using semantic similarity search. +/// Messages are stored during the method and retrieved during the +/// method using semantic similarity search. /// /// /// Behavior is configurable through . When @@ -175,7 +175,7 @@ private ChatHistoryMemoryProvider( } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -189,6 +189,7 @@ public override async ValueTask InvokingAsync(InvokingContext context { // Get the text from the current request messages var requestText = string.Join("\n", context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); @@ -228,7 +229,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); @@ -244,6 +245,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); List> itemsToStore = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []) .Select(message => new Dictionary { diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index be9eba1365..ee87d4f00c 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -107,7 +107,7 @@ public TextSearchProvider( } /// - public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke) { @@ -117,7 +117,9 @@ public override async ValueTask InvokingAsync(InvokingContext context // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); - var requestMessagesText = context.RequestMessages.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); + var requestMessagesText = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) + .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) { if (sbInput.Length > 0) @@ -166,7 +168,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } /// - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { int limit = this._recentMessageMemoryLimit; if (limit <= 0) @@ -180,6 +182,7 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken } var messagesText = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []) .Where(m => this._recentMessageRolesIncluded.Contains(m.Role) && diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 94aa73858a..44d1be2e74 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -15,35 +16,155 @@ public class AIContextProviderTests private static readonly AIAgent s_mockAgent = new Mock().Object; private static readonly AgentSession s_mockSession = new Mock().Object; + #region InvokingAsync Message Stamping Tests + + [Fact] + public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() + { + // Arrange + var provider = new TestAIContextProviderWithMessages(); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSourceAsync() + { + // Arrange + const string CustomSourceName = "CustomContextSource"; + var provider = new TestAIContextProviderWithCustomSource(CustomSourceName); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(CustomSourceName, source); + } + + [Fact] + public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() + { + // Arrange + var provider = new TestAIContextProviderWithPreStampedMessages(); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + ChatMessage message = aiContext.Messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithPreStampedMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_StampsMultipleMessagesAsync() + { + // Arrange + var provider = new TestAIContextProviderWithMultipleMessages(); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(aiContext.Messages); + List messageList = aiContext.Messages.ToList(); + Assert.Equal(3, messageList.Count); + + foreach (ChatMessage message in messageList) + { + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestAIContextProviderWithMultipleMessages).FullName, source); + } + } + + [Fact] + public async Task InvokingAsync_WithNullMessages_ReturnsContextWithoutStampingAsync() + { + // Arrange + var provider = new TestAIContextProvider(); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + AIContext aiContext = await provider.InvokingAsync(context); + + // Assert + Assert.Null(aiContext.Messages); + } + + #endregion + + #region Basic Tests + [Fact] public async Task InvokedAsync_ReturnsCompletedTaskAsync() { + // Arrange var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); + + // Act + ValueTask task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); + + // Assert Assert.Equal(default, task); } [Fact] public void Serialize_ReturnsEmptyElement() { + // Arrange var provider = new TestAIContextProvider(); + + // Act var actual = provider.Serialize(); + + // Assert Assert.Equal(default, actual); } [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { + // Act & Assert Assert.Throws(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null)); + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!)); } + #endregion + #region GetService Method Tests /// @@ -246,7 +367,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -258,7 +379,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -267,28 +388,13 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } - [Fact] - public void InvokedContext_AIContextProviderMessages_Roundtrips() - { - // Arrange - var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); - - // Act - context.AIContextProviderMessages = aiContextMessages; - - // Assert - Assert.Same(aiContextMessages, context.AIContextProviderMessages); - } - [Fact] public void InvokedContext_ResponseMessages_Roundtrips() { // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.ResponseMessages = responseMessages; @@ -303,7 +409,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.InvokeException = exception; @@ -319,7 +425,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -332,7 +438,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockSession, context.Session); @@ -345,7 +451,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages); // Assert Assert.Null(context.Session); @@ -358,16 +464,66 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act & Assert - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages)); } #endregion private sealed class TestAIContextProvider : AIContextProvider { - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext()); + } + + private sealed class TestAIContextProviderWithMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [new ChatMessage(ChatRole.System, "Context Message")] + }); + } + + private sealed class TestAIContextProviderWithCustomSource : AIContextProvider + { + public TestAIContextProviderWithCustomSource(string sourceName) : base(sourceName) { - return default; } + + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [new ChatMessage(ChatRole.System, "Context Message")] + }); + } + + private sealed class TestAIContextProviderWithPreStampedMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var message = new ChatMessage(ChatRole.System, "Pre-stamped Message"); + message.AdditionalProperties = new AdditionalPropertiesDictionary + { + [AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.AIContextProvider, + [AgentRequestMessageSource.AdditionalPropertiesKey] = this.GetType().FullName! + }; + return new(new AIContext + { + Messages = [message] + }); + } + } + + private sealed class TestAIContextProviderWithMultipleMessages : AIContextProvider + { + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(new AIContext + { + Messages = [ + new ChatMessage(ChatRole.System, "Message 1"), + new ChatMessage(ChatRole.User, "Message 2"), + new ChatMessage(ChatRole.Assistant, "Message 3") + ] + }); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs new file mode 100644 index 0000000000..f6149092f3 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs @@ -0,0 +1,489 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentRequestMessageSourceTypeTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_WithValue_SetsValueProperty() + { + // Arrange + const string ExpectedValue = "CustomSource"; + + // Act + AgentRequestMessageSourceType source = new(ExpectedValue); + + // Assert + Assert.Equal(ExpectedValue, source.Value); + } + + [Fact] + public void Constructor_WithNullValue_Throws() + { + // Act & Assert + Assert.Throws(() => new AgentRequestMessageSourceType(null!)); + } + + [Fact] + public void Constructor_WithEmptyValue_Throws() + { + // Act & Assert + Assert.Throws(() => new AgentRequestMessageSourceType(string.Empty)); + } + + #endregion + + #region Static Properties Tests + + [Fact] + public void External_ReturnsInstanceWithExternalValue() + { + // Arrange & Act + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.External; + + // Assert + Assert.NotNull(source); + Assert.Equal("External", source.Value); + } + + [Fact] + public void AIContextProvider_ReturnsInstanceWithAIContextProviderValue() + { + // Arrange & Act + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.AIContextProvider; + + // Assert + Assert.NotNull(source); + Assert.Equal("AIContextProvider", source.Value); + } + + [Fact] + public void ChatHistory_ReturnsInstanceWithChatHistoryValue() + { + // Arrange & Act + AgentRequestMessageSourceType source = AgentRequestMessageSourceType.ChatHistory; + + // Assert + Assert.NotNull(source); + Assert.Equal("ChatHistory", source.Value); + } + + [Fact] + public void AdditionalPropertiesKey_ReturnsExpectedValue() + { + // Arrange & Act + string key = AgentRequestMessageSourceType.AdditionalPropertiesKey; + + // Assert + Assert.Equal("Agent.RequestMessageSourceType", key); + } + + [Fact] + public void StaticProperties_ReturnSameInstanceOnMultipleCalls() + { + // Arrange & Act + AgentRequestMessageSourceType external1 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType external2 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType aiContextProvider1 = AgentRequestMessageSourceType.AIContextProvider; + AgentRequestMessageSourceType aiContextProvider2 = AgentRequestMessageSourceType.AIContextProvider; + AgentRequestMessageSourceType chatHistory1 = AgentRequestMessageSourceType.ChatHistory; + AgentRequestMessageSourceType chatHistory2 = AgentRequestMessageSourceType.ChatHistory; + + // Assert + Assert.Same(external1, external2); + Assert.Same(aiContextProvider1, aiContextProvider2); + Assert.Same(chatHistory1, chatHistory2); + } + + #endregion + + #region Equals Tests + + [Fact] + public void Equals_WithSameInstance_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source = new("Test"); + + // Act + bool result = source.Equals(source); + + // Assert + Assert.True(result); + } + + [Fact] + public void Equals_WithEqualValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.True(result); + } + + [Fact] + public void Equals_WithDifferentValue_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_WithNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source = new("Test"); + + // Act + bool result = source.Equals(null); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_WithDifferentCase_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.False(result); + } + + [Fact] + public void Equals_StaticExternalWithNewInstanceHavingSameValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType newExternal = new("External"); + + // Act + bool result = external.Equals(newExternal); + + // Assert + Assert.True(result); + } + + #endregion + + #region Object.Equals Tests + + [Fact] + public void ObjectEquals_WithEqualAgentRequestMessageSource_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + object source2 = new AgentRequestMessageSourceType("Test"); + + // Act + bool result = source1.Equals(source2); + + // Assert + Assert.True(result); + } + + [Fact] + public void ObjectEquals_WithDifferentType_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source = new("Test"); + object other = "Test"; + + // Act + bool result = source.Equals(other); + + // Assert + Assert.False(result); + } + + [Fact] + public void ObjectEquals_WithNullObject_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source = new("Test"); + object? other = null; + + // Act + bool result = source.Equals(other); + + // Assert + Assert.False(result); + } + + #endregion + + #region GetHashCode Tests + + [Fact] + public void GetHashCode_WithSameValue_ReturnsSameHashCode() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + int hashCode1 = source1.GetHashCode(); + int hashCode2 = source2.GetHashCode(); + + // Assert + Assert.Equal(hashCode1, hashCode2); + } + + [Fact] + public void GetHashCode_WithDifferentValue_ReturnsDifferentHashCode() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); + + // Act + int hashCode1 = source1.GetHashCode(); + int hashCode2 = source2.GetHashCode(); + + // Assert + Assert.NotEqual(hashCode1, hashCode2); + } + + [Fact] + public void GetHashCode_ConsistentWithEquals() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); + + // Act & Assert + // If two objects are equal, they must have the same hash code + Assert.True(source1.Equals(source2)); + Assert.Equal(source1.GetHashCode(), source2.GetHashCode()); + } + + #endregion + + #region Equality Operator Tests + + [Fact] + public void EqualityOperator_WithEqualValues_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_WithDifferentValues_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithBothNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType? source2 = null; + + // Act + bool result = source1 == source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_WithLeftNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithRightNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType? source2 = null; + + // Act + bool result = source1 == source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void EqualityOperator_WithStaticInstances_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType external1 = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType external2 = AgentRequestMessageSourceType.External; + + // Act + bool result = external1 == external2; + + // Assert + Assert.True(result); + } + + [Fact] + public void EqualityOperator_StaticWithNewInstanceHavingSameValue_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType newExternal = new("External"); + + // Act + bool result = external == newExternal; + + // Assert + Assert.True(result); + } + + #endregion + + #region Inequality Operator Tests + + [Fact] + public void InequalityOperator_WithEqualValues_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void InequalityOperator_WithDifferentValues_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test1"); + AgentRequestMessageSourceType source2 = new("Test2"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_WithBothNull_ReturnsFalse() + { + // Arrange + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType? source2 = null; + + // Act + bool result = source1 != source2; + + // Assert + Assert.False(result); + } + + [Fact] + public void InequalityOperator_WithLeftNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType? source1 = null; + AgentRequestMessageSourceType source2 = new("Test"); + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_WithRightNull_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType source1 = new("Test"); + AgentRequestMessageSourceType? source2 = null; + + // Act + bool result = source1 != source2; + + // Assert + Assert.True(result); + } + + [Fact] + public void InequalityOperator_DifferentStaticInstances_ReturnsTrue() + { + // Arrange + AgentRequestMessageSourceType external = AgentRequestMessageSourceType.External; + AgentRequestMessageSourceType chatHistory = AgentRequestMessageSourceType.ChatHistory; + + // Act + bool result = external != chatHistory; + + // Assert + Assert.True(result); + } + + #endregion + + #region IEquatable Tests + + [Fact] + public void IEquatable_ImplementedCorrectly() + { + // Arrange + AgentRequestMessageSourceType source = new("Test"); + + // Act & Assert + Assert.IsAssignableFrom>(source); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index a74906c801..1244209a97 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -41,7 +42,8 @@ public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); providerMock - .Setup(p => p.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters( @@ -60,16 +62,20 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() { // Arrange Mock providerMock = new(); - List requestMessages = [new(ChatRole.User, "Hello")]; - List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, + new(ChatRole.User, "Hello") + ]; + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] }; ChatHistoryProvider.InvokedContext? capturedContext = null; providerMock - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, _) => capturedContext = ctx) .Returns(default(ValueTask)); @@ -106,17 +112,18 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe { // Arrange Mock providerMock = new(); - List requestMessages = [new(ChatRole.User, "Hello")]; - List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - List aiContextProviderMessages = [new(ChatRole.System, "Context")]; - ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) - { - AIContextProviderMessages = aiContextProviderMessages - }; + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, + new(ChatRole.User, "Hello"), + new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } } } + ]; + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages); ChatHistoryProvider.InvokedContext? capturedContext = null; providerMock - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, _) => capturedContext = ctx) .Returns(default(ValueTask)); @@ -127,6 +134,8 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe // Assert Assert.NotNull(capturedContext); - Assert.Null(capturedContext.AIContextProviderMessages); + Assert.Equal(2, capturedContext.RequestMessages.Count()); + Assert.Contains("System", capturedContext.RequestMessages.Select(x => x.Text)); + Assert.Contains("Hello", capturedContext.RequestMessages.Select(x => x.Text)); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 4b955a43c0..5b48d025be 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -65,7 +66,8 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(expectedMessages); var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); @@ -77,7 +79,9 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn Assert.Equal(2, result.Count); Assert.Equal("Hello", result[0].Text); Assert.Equal("Hi there!", result[1].Text); - innerProviderMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -94,7 +98,8 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); // Filter to only user messages @@ -108,7 +113,9 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() // Assert Assert.Equal(2, result.Count); Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role)); - innerProviderMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -124,7 +131,8 @@ public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock - .Setup(s => s.InvokingAsync(context, It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(innerMessages); // Filter that transforms messages @@ -147,28 +155,31 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() { // Arrange var innerProviderMock = new Mock(); - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var chatHistoryProviderMessages = new List { new(ChatRole.System, "System") }; + List requestMessages = + [ + new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, + new(ChatRole.User, "Hello"), + ]; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }; ChatHistoryProvider.InvokedContext? capturedContext = null; innerProviderMock - .Setup(s => s.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedContext = ctx) .Returns(default(ValueTask)); // Filter that modifies the context ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { - var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages) + var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); + return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages) { ResponseMessages = ctx.ResponseMessages, - AIContextProviderMessages = ctx.AIContextProviderMessages, InvokeException = ctx.InvokeException }; } @@ -182,7 +193,9 @@ ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedCont Assert.NotNull(capturedContext); Assert.Single(capturedContext.RequestMessages); Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text); - innerProviderMock.Verify(s => s.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + innerProviderMock + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 5e0fbe9817..e158b159ca 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -18,6 +19,92 @@ public class ChatHistoryProviderTests private static readonly AIAgent s_mockAgent = new Mock().Object; private static readonly AgentSession s_mockSession = new Mock().Object; + #region InvokingAsync Message Stamping Tests + + [Fact] + public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceAsync() + { + // Arrange + var provider = new TestChatHistoryProvider(); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProvider).FullName, source); + } + + [Fact] + public async Task InvokingAsync_WithCustomSourceName_StampsMessagesWithCustomSourceAsync() + { + // Arrange + const string CustomSourceName = "CustomHistorySource"; + var provider = new TestChatHistoryProviderWithCustomSource(CustomSourceName); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(CustomSourceName, source); + } + + [Fact] + public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() + { + // Arrange + var provider = new TestChatHistoryProviderWithPreStampedMessages(); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + ChatMessage message = messages.Single(); + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProviderWithPreStampedMessages).FullName, source); + } + + [Fact] + public async Task InvokingAsync_StampsMultipleMessagesAsync() + { + // Arrange + var provider = new TestChatHistoryProviderWithMultipleMessages(); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); + + // Act + IEnumerable messages = await provider.InvokingAsync(context); + + // Assert + List messageList = messages.ToList(); + Assert.Equal(3, messageList.Count); + + foreach (ChatMessage message in messageList) + { + Assert.NotNull(message.AdditionalProperties); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceType.AdditionalPropertiesKey, out object? sourceType)); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, sourceType); + Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSource.AdditionalPropertiesKey, out object? source)); + Assert.Equal(typeof(TestChatHistoryProviderWithMultipleMessages).FullName, source); + } + } + + #endregion + #region GetService Method Tests [Fact] @@ -172,7 +259,7 @@ public void InvokingContext_Constructor_ThrowsForNullAgent() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!)); } [Fact] @@ -180,7 +267,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -192,7 +279,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -201,43 +288,13 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } - [Fact] - public void InvokedContext_ChatHistoryProviderMessages_SetterRoundtrips() - { - // Arrange - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var newProviderMessages = new List { new(ChatRole.System, "System message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); - - // Act - context.ChatHistoryProviderMessages = newProviderMessages; - - // Assert - Assert.Same(newProviderMessages, context.ChatHistoryProviderMessages); - } - - [Fact] - public void InvokedContext_AIContextProviderMessages_Roundtrips() - { - // Arrange - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); - - // Act - context.AIContextProviderMessages = aiContextMessages; - - // Assert - Assert.Same(aiContextMessages, context.AIContextProviderMessages); - } - [Fact] public void InvokedContext_ResponseMessages_Roundtrips() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.ResponseMessages = responseMessages; @@ -252,7 +309,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act context.InvokeException = exception; @@ -268,7 +325,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -281,7 +338,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Assert Assert.Same(s_mockSession, context.Session); @@ -294,7 +351,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages); // Assert Assert.Null(context.Session); @@ -307,17 +364,69 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages)); } #endregion private sealed class TestChatHistoryProvider : ChatHistoryProvider { - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(Array.Empty()); + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new([new ChatMessage(ChatRole.User, "Test Message")]); - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithCustomSource : ChatHistoryProvider + { + public TestChatHistoryProviderWithCustomSource(string sourceName) : base(sourceName) + { + } + + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new([new ChatMessage(ChatRole.User, "Test Message")]); + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithPreStampedMessages : ChatHistoryProvider + { + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var message = new ChatMessage(ChatRole.User, "Pre-stamped Message"); + message.AdditionalProperties = new AdditionalPropertiesDictionary + { + [AgentRequestMessageSourceType.AdditionalPropertiesKey] = AgentRequestMessageSourceType.ChatHistory, + [AgentRequestMessageSource.AdditionalPropertiesKey] = this.GetType().FullName! + }; + return new([message]); + } + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + => default; + } + + private sealed class TestChatHistoryProviderWithMultipleMessages : ChatHistoryProvider + { + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new([ + new ChatMessage(ChatRole.User, "Message 1"), + new ChatMessage(ChatRole.Assistant, "Message 2"), + new ChatMessage(ChatRole.User, "Message 3") + ]); + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs new file mode 100644 index 0000000000..f389c567d2 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class ChatMessageExtensionsTests +{ + #region GetAgentRequestMessageSource Tests + + [Fact] + public void GetAgentRequestMessageSource_WithNoAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello"); + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithNullAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = null + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithEmptyAdditionalProperties_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary() + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithExternalSource_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.External } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithAIContextProviderSource_ReturnsAIContextProvider() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithChatHistorySource_ReturnsChatHistory() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithCustomSource_ReturnsCustomSource() + { + // Arrange + AgentRequestMessageSourceType customSource = new("CustomSource"); + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, customSource } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(customSource, result); + Assert.Equal("CustomSource", result.Value); + } + + [Fact] + public void GetAgentRequestMessageSource_WithWrongKeyType_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, "NotAnAgentRequestMessageSource" } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithNullValue_ReturnsExternal() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { AgentRequestMessageSourceType.AdditionalPropertiesKey, null! } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.External, result); + } + + [Fact] + public void GetAgentRequestMessageSource_WithMultipleProperties_ReturnsCorrectSource() + { + // Arrange + ChatMessage message = new(ChatRole.User, "Hello") + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { "OtherProperty", "SomeValue" }, + { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory }, + { "AnotherProperty", 123 } + } + }; + + // Act + AgentRequestMessageSourceType result = message.GetAgentRequestMessageSource(); + + // Assert + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index bf8ff998b9..75232073a6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -54,7 +54,8 @@ public async Task InvokedAsyncAddsMessagesAsync() { var requestMessages = new List { - new(ChatRole.User, "Hello") + new(ChatRole.User, "Hello"), + new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.ChatHistory } } }, }; var responseMessages = new List { @@ -64,16 +65,11 @@ public async Task InvokedAsyncAddsMessagesAsync() { new(ChatRole.System, "original instructions") }; - var aiContextProviderMessages = new List() - { - new(ChatRole.System, "additional context") - }; var provider = new InMemoryChatHistoryProvider(); provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; await provider.InvokedAsync(context, CancellationToken.None); @@ -90,7 +86,7 @@ public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -186,7 +182,7 @@ public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() var provider = new InMemoryChatHistoryProvider(); var messages = new List(); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -523,7 +519,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -582,7 +578,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -630,7 +626,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index 447c195c83..e2cbb16b1b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -3140,7 +3140,7 @@ private sealed class TestSchema /// private sealed class TestAIContextProvider : AIContextProvider { - public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { return new ValueTask(new AIContext()); } @@ -3151,12 +3151,12 @@ public override ValueTask InvokingAsync(InvokingContext context, Canc /// private sealed class TestChatHistoryProvider : ChatHistoryProvider { - public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { return new ValueTask>(Array.Empty()); } - public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index f6589ff9e3..e1d3c612c8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -217,7 +217,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]) { ResponseMessages = [] }; @@ -285,20 +285,16 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { new ChatMessage(ChatRole.User, "First message"), new ChatMessage(ChatRole.Assistant, "Second message"), - new ChatMessage(ChatRole.User, "Third message") - }; - var aiContextProviderMessages = new[] - { - new ChatMessage(ChatRole.System, "System context message") + new ChatMessage(ChatRole.User, "Third message"), + new ChatMessage(ChatRole.System, "System context message") { AdditionalProperties = new() { { AgentRequestMessageSourceType.AdditionalPropertiesKey, AgentRequestMessageSourceType.AIContextProvider } } } }; var responseMessages = new[] { new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { - AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; @@ -349,8 +345,8 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")]); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -394,7 +390,7 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added @@ -548,7 +544,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]); // Act await provider.InvokedAsync(context); @@ -605,7 +601,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); // Act await provider.InvokedAsync(context); @@ -640,8 +636,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")]); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -678,7 +674,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")]); await originalStore.InvokedAsync(context); // Act - Serialize the provider state @@ -720,8 +716,8 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")]); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -763,7 +759,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -801,7 +797,7 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); await provider.InvokedAsync(context); // Wait for eventual consistency diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index 81ca4eb588..a10f1246aa 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -56,7 +56,7 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input])); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); @@ -80,7 +80,7 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); @@ -108,7 +108,7 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index b886784af9..53c87b09ba 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -218,7 +218,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -245,7 +245,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -271,7 +271,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -321,7 +321,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index e1ff5f8cbd..41fb29bfed 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.UnitTests; @@ -342,7 +343,8 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, @@ -350,7 +352,8 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -378,12 +381,15 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Equal("context provider message", chatHistoryProvider[1].Text); Assert.Equal("response", chatHistoryProvider[2].Text); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == responseMessages && - x.InvokeException == null), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == responseMessages && + x.InvokeException == null), ItExpr.IsAny()); } /// @@ -394,7 +400,6 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() { // Arrange ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; - ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")]; ChatMessage[] aiContextProviderMessages = [new(ChatRole.System, "context provider message")]; Mock mockService = new(); mockService @@ -406,13 +411,15 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -421,12 +428,15 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); // Assert - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == null && - x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == null && + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); } /// @@ -458,7 +468,8 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext()); ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); @@ -474,7 +485,9 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA Assert.Equal(ChatRole.User, capturedMessages[0].Role); Assert.Single(capturedTools); Assert.Contains(capturedTools, t => t.Name == "base function"); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } #endregion @@ -1371,7 +1384,8 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, @@ -1379,7 +1393,8 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new( @@ -1414,13 +1429,16 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Equal("context provider message", chatHistoryProvider[1].Text); Assert.Equal("response", chatHistoryProvider[2].Text); - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages!.Count() == 1 && - x.ResponseMessages!.ElementAt(0).Text == "response" && - x.InvokeException == null), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages!.Count() == 1 && + x.ResponseMessages!.ElementAt(0).Text == "response" && + x.InvokeException == null), ItExpr.IsAny()); } /// @@ -1442,13 +1460,15 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA var mockProvider = new Mock(); mockProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = aiContextProviderMessages, }); mockProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); ChatClientAgent agent = new( @@ -1467,12 +1487,15 @@ await Assert.ThrowsAsync(async () => }); // Assert - mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); - mockProvider.Verify(p => p.InvokedAsync(It.Is(x => - x.RequestMessages == requestMessages && - x.AIContextProviderMessages == aiContextProviderMessages && - x.ResponseMessages == null && - x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once); + mockProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.RequestMessages.Count() == requestMessages.Length + aiContextProviderMessages.Length && + x.ResponseMessages == null && + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 87be3fb96e..2eed890292 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.UnitTests; @@ -339,13 +340,15 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync([new(ChatRole.User, "Message from chat history provider")]); // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = [new(ChatRole.System, "Message from AI context")], @@ -385,14 +388,14 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu Assert.Empty(capturedMessages); // Verify that chat history provider was never called due to continuation token - mockChatHistoryProvider.Verify( - ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); // Verify that AI context provider was never called due to continuation token - mockContextProvider.Verify( - p => p.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockContextProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -404,13 +407,15 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync([new(ChatRole.User, "Message from chat history provider")]); // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext { Messages = [new(ChatRole.System, "Message from AI context")], @@ -449,14 +454,14 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe Assert.Empty(capturedMessages); // Verify that chat history provider was never called due to continuation token - mockChatHistoryProvider.Verify( - ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); // Verify that AI context provider was never called due to continuation token - mockContextProvider.Verify( - p => p.InvokingAsync(It.IsAny(), It.IsAny()), - Times.Never); + mockContextProvider + .Protected() + .Verify>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); } [Fact] @@ -633,14 +638,16 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedMessagesAddedToProvider.AddRange(ctx.ResponseMessages ?? [])) .Returns(new ValueTask()); AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); @@ -662,11 +669,15 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial await agent.RunStreamingAsync(session, options: runOptions).ToListAsync(); // Assert - mockChatHistoryProvider.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.Single(capturedMessagesAddedToProvider); Assert.Contains("once upon a time", capturedMessagesAddedToProvider[0].Text); - mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockContextProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.NotNull(capturedInvokedContext?.ResponseMessages); Assert.Single(capturedInvokedContext.ResponseMessages); Assert.Contains("once upon a time", capturedInvokedContext.ResponseMessages.ElementAt(0).Text); @@ -689,14 +700,16 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider - .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((ctx, ct) => capturedMessagesAddedToProvider.AddRange(ctx.RequestMessages)) .Returns(new ValueTask()); AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); @@ -718,11 +731,15 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI await agent.RunStreamingAsync(session, options: runOptions).ToListAsync(); // Assert - mockChatHistoryProvider.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.Single(capturedMessagesAddedToProvider); Assert.Contains("Tell me a story", capturedMessagesAddedToProvider[0].Text); - mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + mockContextProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); Assert.NotNull(capturedInvokedContext?.RequestMessages); Assert.Single(capturedInvokedContext.RequestMessages); Assert.Contains("Tell me a story", capturedInvokedContext.RequestMessages.ElementAt(0).Text); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index a854a76622..4de8f01f8e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; using Xunit.Sdk; namespace Microsoft.Agents.AI.UnitTests; @@ -183,12 +184,14 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); Mock mockChatHistoryProvider = new(); - mockChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); - mockChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Returns(new ValueTask()); + mockChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + mockChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); Mock>> mockFactory = new(); mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); @@ -211,14 +214,16 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny(), It.IsAny()), Times.Once); - mockChatHistoryProvider.Verify(s => s.InvokingAsync( - It.Is(x => x.RequestMessages.Count() == 1), - It.IsAny()), - Times.Once); - mockChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), - It.IsAny()), - Times.Once); + mockChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1), + ItExpr.IsAny()); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), + ItExpr.IsAny()); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -253,10 +258,11 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() // Assert Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); - mockChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), - It.IsAny()), - Times.Once); + mockChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), + ItExpr.IsAny()); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -308,22 +314,26 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi // Arrange a chat history provider to override the factory provided one. Mock mockOverrideChatHistoryProvider = new(); - mockOverrideChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); - mockOverrideChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Returns(new ValueTask()); + mockOverrideChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + mockOverrideChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); // Arrange a chat history provider to provide to the agent via a factory at construction time. // This one shouldn't be used since it is being overridden. Mock mockFactoryChatHistoryProvider = new(); - mockFactoryChatHistoryProvider.Setup(s => s.InvokingAsync( - It.IsAny(), - It.IsAny())).ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - mockFactoryChatHistoryProvider.Setup(s => s.InvokedAsync( - It.IsAny(), - It.IsAny())).Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); + mockFactoryChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); + mockFactoryChatHistoryProvider + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); Mock>> mockFactory = new(); mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); @@ -348,23 +358,27 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny(), It.IsAny()), Times.Once); - mockOverrideChatHistoryProvider.Verify(s => s.InvokingAsync( - It.Is(x => x.RequestMessages.Count() == 1), - It.IsAny()), - Times.Once); - mockOverrideChatHistoryProvider.Verify(s => s.InvokedAsync( - It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), - It.IsAny()), - Times.Once); - - mockFactoryChatHistoryProvider.Verify(s => s.InvokingAsync( - It.IsAny(), - It.IsAny()), - Times.Never); - mockFactoryChatHistoryProvider.Verify(s => s.InvokedAsync( - It.IsAny(), - It.IsAny()), - Times.Never); + mockOverrideChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1), + ItExpr.IsAny()); + mockOverrideChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), + ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), + ItExpr.IsAny()); + + mockFactoryChatHistoryProvider + .Protected() + .Verify>>("InvokingCoreAsync", Times.Never(), + ItExpr.IsAny(), + ItExpr.IsAny()); + mockFactoryChatHistoryProvider + .Protected() + .Verify("InvokedCoreAsync", Times.Never(), + ItExpr.IsAny(), + ItExpr.IsAny()); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 360c3071ae..ec8dda3c45 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -345,7 +345,7 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, @@ -387,7 +387,7 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, @@ -423,22 +423,22 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc // First memory update (A,B) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ])); // Second memory update (C,D,E) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ])); var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); @@ -475,7 +475,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, @@ -533,7 +533,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory. + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -562,7 +562,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Act var state = provider.Serialize(); @@ -603,7 +603,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); + await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); var state = initialProvider.Serialize(); string? capturedInput = null; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index 8d3cad85ae..0a11e74528 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -119,7 +119,7 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls]) { ResponseMessages = [responseMsg] }; @@ -177,7 +177,7 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" }); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -206,7 +206,7 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() new ChatHistoryMemoryProviderScope() { UserId = "UID" }, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -257,7 +257,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None);