diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs index 733a7af9a7..aa49957152 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs @@ -29,7 +29,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = sp.GetRequiredService(); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -49,7 +49,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -70,7 +70,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -92,7 +92,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools); + return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools, services: sp); }); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs index 03ab65c9f2..47b06884d4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs @@ -1,7 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Moq; @@ -203,4 +207,120 @@ public void AddAIAgent_ValidSpecialCharactersInName_Succeeds(string name) d.ServiceType == typeof(AIAgent)); Assert.NotNull(descriptor); } + + /// + /// Verifies that AddAIAgent passes the application's IServiceProvider to the + /// ChatClientAgent, enabling tool dependency injection. Regression test for + /// https://github.com/microsoft/agent-framework/issues/4453. + /// + [Fact] + public void AddAIAgent_PassesServiceProvider_ToChatClientAgent() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with a chat client instance passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithChatClient_PassesServiceProvider() + { + var services = new ServiceCollection(); + var chatClient = new MockChatClient(); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", chatClient); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with a chat client key passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithChatClientKey_PassesServiceProvider() + { + var services = new ServiceCollection(); + services.AddKeyedSingleton("myKey", new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", "myKey"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with description and chat client key passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithDescriptionAndKey_PassesServiceProvider() + { + var services = new ServiceCollection(); + services.AddKeyedSingleton("myKey", new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", "A test agent", "myKey"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that the FunctionInvokingChatClient in the agent's pipeline received + /// the application's IServiceProvider (not null) by checking that it can resolve + /// a service registered in the DI container. + /// + private static void AssertServiceProviderPassedThrough(ChatClientAgent agent) + { + var funcClient = agent.ChatClient.GetService(); + Assert.NotNull(funcClient); + + // Use reflection to access the internal IServiceProvider stored in FunctionInvokingChatClient. + // This verifies the application's service provider was forwarded, not null or an empty provider. + var spField = typeof(FunctionInvokingChatClient) + .GetField("_serviceProvider", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + + if (spField is not null) + { + var innerProvider = spField.GetValue(funcClient) as IServiceProvider; + Assert.NotNull(innerProvider); + var marker = innerProvider!.GetService(); + Assert.NotNull(marker); + } + } + + private interface IMarkerService; + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via DI")] + private sealed class MarkerService : IMarkerService; + + private sealed class MockChatClient : IChatClient + { + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() { } + } }