Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
{
var chatClient = sp.GetRequiredService<IChatClient>();
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp);
});
}

Expand All @@ -49,7 +49,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
return services.AddAIAgent(name, (sp, key) =>
{
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp);
});
}

Expand All @@ -70,7 +70,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
return new ChatClientAgent(chatClient, instructions, key, tools: tools, services: sp);
});
}

Expand All @@ -92,7 +92,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools);
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools, services: sp);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Moq;
Expand Down Expand Up @@ -203,4 +207,120 @@ public void AddAIAgent_ValidSpecialCharactersInName_Succeeds(string name)
d.ServiceType == typeof(AIAgent));
Assert.NotNull(descriptor);
}

/// <summary>
/// Verifies that AddAIAgent passes the application's IServiceProvider to the
/// ChatClientAgent, enabling tool dependency injection. Regression test for
/// https://github.com/microsoft/agent-framework/issues/4453.
/// </summary>
[Fact]
public void AddAIAgent_PassesServiceProvider_ToChatClientAgent()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
services.AddSingleton<IMarkerService, MarkerService>();
services.AddAIAgent("test-agent", "Test instructions");

var serviceProvider = services.BuildServiceProvider();
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>("test-agent") as ChatClientAgent;

Assert.NotNull(agent);
AssertServiceProviderPassedThrough(agent!);
}

/// <summary>
/// Verifies that AddAIAgent with a chat client instance passes the IServiceProvider.
/// </summary>
[Fact]
public void AddAIAgent_WithChatClient_PassesServiceProvider()
{
var services = new ServiceCollection();
var chatClient = new MockChatClient();
services.AddSingleton<IMarkerService, MarkerService>();
services.AddAIAgent("test-agent", "Test instructions", chatClient);

var serviceProvider = services.BuildServiceProvider();
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>("test-agent") as ChatClientAgent;

Assert.NotNull(agent);
AssertServiceProviderPassedThrough(agent!);
}

/// <summary>
/// Verifies that AddAIAgent with a chat client key passes the IServiceProvider.
/// </summary>
[Fact]
public void AddAIAgent_WithChatClientKey_PassesServiceProvider()
{
var services = new ServiceCollection();
services.AddKeyedSingleton<IChatClient>("myKey", new MockChatClient());
services.AddSingleton<IMarkerService, MarkerService>();
services.AddAIAgent("test-agent", "Test instructions", "myKey");

var serviceProvider = services.BuildServiceProvider();
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>("test-agent") as ChatClientAgent;

Assert.NotNull(agent);
AssertServiceProviderPassedThrough(agent!);
}

/// <summary>
/// Verifies that AddAIAgent with description and chat client key passes the IServiceProvider.
/// </summary>
[Fact]
public void AddAIAgent_WithDescriptionAndKey_PassesServiceProvider()
{
var services = new ServiceCollection();
services.AddKeyedSingleton<IChatClient>("myKey", new MockChatClient());
services.AddSingleton<IMarkerService, MarkerService>();
services.AddAIAgent("test-agent", "Test instructions", "A test agent", "myKey");

var serviceProvider = services.BuildServiceProvider();
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>("test-agent") as ChatClientAgent;

Assert.NotNull(agent);
AssertServiceProviderPassedThrough(agent!);
}

/// <summary>
/// Verifies that the FunctionInvokingChatClient in the agent's pipeline received
/// the application's IServiceProvider (not null) by checking that it can resolve
/// a service registered in the DI container.
/// </summary>
private static void AssertServiceProviderPassedThrough(ChatClientAgent agent)
{
var funcClient = agent.ChatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(funcClient);

// Use reflection to access the internal IServiceProvider stored in FunctionInvokingChatClient.
// This verifies the application's service provider was forwarded, not null or an empty provider.
var spField = typeof(FunctionInvokingChatClient)
.GetField("_serviceProvider", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);

if (spField is not null)
{
var innerProvider = spField.GetValue(funcClient) as IServiceProvider;
Assert.NotNull(innerProvider);
var marker = innerProvider!.GetService<IMarkerService>();
Assert.NotNull(marker);
}
Comment on lines +299 to +306
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if (spField is not null) guard means that if FunctionInvokingChatClient renames or removes the _serviceProvider field in a future version of Microsoft.Extensions.AI, the reflection lookup silently returns null, all the critical assertions inside the block are skipped, and the test passes without actually verifying anything. This undermines the purpose of the regression test.

Consider adding Assert.NotNull(spField) (or using Assert.Fail in the else branch) to ensure the test fails loudly if the internal field disappears, making it clear that the verification approach needs to be updated.

Suggested change
if (spField is not null)
{
var innerProvider = spField.GetValue(funcClient) as IServiceProvider;
Assert.NotNull(innerProvider);
var marker = innerProvider!.GetService<IMarkerService>();
Assert.NotNull(marker);
}
Assert.NotNull(spField);
var innerProvider = spField!.GetValue(funcClient) as IServiceProvider;
Assert.NotNull(innerProvider);
var marker = innerProvider!.GetService<IMarkerService>();
Assert.NotNull(marker);

Copilot uses AI. Check for mistakes.
}

private interface IMarkerService;

[System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via DI")]
private sealed class MarkerService : IMarkerService;

private sealed class MockChatClient : IChatClient
{
public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public object? GetService(Type serviceType, object? serviceKey = null) => null;

public void Dispose() { }
}
}
Loading