From a3d7b31122c0ad69aba07fc0aa445e1090b302fb Mon Sep 17 00:00:00 2001 From: max-montes <77820353+max-montes@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:49:09 -0800 Subject: [PATCH] fix: pass IServiceProvider to ChatClientAgent in AddAIAgent overloads All four AddAIAgent overloads in AgentHostingServiceCollectionExtensions were creating ChatClientAgent without forwarding the IServiceProvider. This meant tools registered via dependency injection could not resolve their dependencies at invocation time. Pass services: sp to each ChatClientAgent constructor call so that the FunctionInvokingChatClient middleware receives the application's service provider and can resolve tool dependencies correctly. Fixes #4453 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...AgentHostingServiceCollectionExtensions.cs | 8 +- ...HostingServiceCollectionExtensionsTests.cs | 120 ++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs index 733a7af9a7..aa49957152 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs @@ -29,7 +29,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = sp.GetRequiredService(); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -49,7 +49,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -70,7 +70,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions, key, tools: tools); + return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp); }); } @@ -92,7 +92,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); var tools = sp.GetKeyedServices(name).ToList(); - return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools); + return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools, services: sp); }); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs index 03ab65c9f2..47b06884d4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/AgentHostingServiceCollectionExtensionsTests.cs @@ -1,7 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Moq; @@ -203,4 +207,120 @@ public void AddAIAgent_ValidSpecialCharactersInName_Succeeds(string name) d.ServiceType == typeof(AIAgent)); Assert.NotNull(descriptor); } + + /// + /// Verifies that AddAIAgent passes the application's IServiceProvider to the + /// ChatClientAgent, enabling tool dependency injection. Regression test for + /// https://github.com/microsoft/agent-framework/issues/4453. + /// + [Fact] + public void AddAIAgent_PassesServiceProvider_ToChatClientAgent() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with a chat client instance passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithChatClient_PassesServiceProvider() + { + var services = new ServiceCollection(); + var chatClient = new MockChatClient(); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", chatClient); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with a chat client key passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithChatClientKey_PassesServiceProvider() + { + var services = new ServiceCollection(); + services.AddKeyedSingleton("myKey", new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", "myKey"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that AddAIAgent with description and chat client key passes the IServiceProvider. + /// + [Fact] + public void AddAIAgent_WithDescriptionAndKey_PassesServiceProvider() + { + var services = new ServiceCollection(); + services.AddKeyedSingleton("myKey", new MockChatClient()); + services.AddSingleton(); + services.AddAIAgent("test-agent", "Test instructions", "A test agent", "myKey"); + + var serviceProvider = services.BuildServiceProvider(); + var agent = serviceProvider.GetRequiredKeyedService("test-agent") as ChatClientAgent; + + Assert.NotNull(agent); + AssertServiceProviderPassedThrough(agent!); + } + + /// + /// Verifies that the FunctionInvokingChatClient in the agent's pipeline received + /// the application's IServiceProvider (not null) by checking that it can resolve + /// a service registered in the DI container. + /// + private static void AssertServiceProviderPassedThrough(ChatClientAgent agent) + { + var funcClient = agent.ChatClient.GetService(); + Assert.NotNull(funcClient); + + // Use reflection to access the internal IServiceProvider stored in FunctionInvokingChatClient. + // This verifies the application's service provider was forwarded, not null or an empty provider. + var spField = typeof(FunctionInvokingChatClient) + .GetField("_serviceProvider", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + + if (spField is not null) + { + var innerProvider = spField.GetValue(funcClient) as IServiceProvider; + Assert.NotNull(innerProvider); + var marker = innerProvider!.GetService(); + Assert.NotNull(marker); + } + } + + private interface IMarkerService; + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via DI")] + private sealed class MarkerService : IMarkerService; + + private sealed class MockChatClient : IChatClient + { + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => throw new NotImplementedException(); + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() { } + } }