diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs index 6987c6aca3..df9afd5df7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs @@ -133,7 +133,18 @@ internal sealed class ExecutorProtocol(MessageRouter router, ISet sendType public bool CanHandle(Type type) => router.CanHandle(type); - public bool CanOutput(Type type) => this._yieldTypes.Contains(new(type)); + public bool CanOutput(Type type) + { + foreach (TypeId yieldType in this._yieldTypes) + { + if (yieldType.IsMatchPolymorphic(type)) + { + return true; + } + } + + return false; + } public ProtocolDescriptor Describe() => new(this.Router.IncomingTypes, yieldTypes, sendTypes, this.Router.HasCatchAll); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PolymorphicOutputTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PolymorphicOutputTests.cs new file mode 100644 index 0000000000..040975e6a0 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PolymorphicOutputTests.cs @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +/// +/// Regression tests for polymorphic output type handling in workflows. +/// Verifies that executors can return derived types when the declared output type is a base class. +/// +/// +/// This addresses GitHub issue #4134: InvalidOperationException when returning derived type as workflow output. +/// +public partial class PolymorphicOutputTests +{ + #region Test Type Hierarchy + + /// + /// Base class used as declared output type. + /// + public class BaseOutput + { + public virtual string Name => "BaseOutput"; + } + + /// + /// Derived class returned at runtime. + /// + public class DerivedOutput : BaseOutput + { + public override string Name => "DerivedOutput"; + } + + /// + /// Second-level derived class for testing multiple inheritance levels. + /// + public class GrandchildOutput : DerivedOutput + { + public override string Name => "GrandchildOutput"; + } + + /// + /// Unrelated class that should NOT be accepted as output. + /// + public class UnrelatedOutput + { + public string Name => "UnrelatedOutput"; + } + + #endregion + + #region Test Executors + + /// + /// Executor that declares BaseOutput as yield type but returns DerivedOutput. + /// + internal sealed class DerivedOutputExecutor() : Executor(nameof(DerivedOutputExecutor)) + { + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler(this.HandleAsync)); + } + + private async ValueTask HandleAsync(string input, IWorkflowContext context, CancellationToken cancellationToken) + { + await Task.Delay(10, cancellationToken); + + // Arrange: Return a derived type where the method signature declares the base type + return new DerivedOutput(); + } + } + + /// + /// Executor that declares BaseOutput as yield type but returns GrandchildOutput (two levels deep). + /// + internal sealed class GrandchildOutputExecutor() : Executor(nameof(GrandchildOutputExecutor)) + { + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler(this.HandleAsync)); + } + + private async ValueTask HandleAsync(string input, IWorkflowContext context, CancellationToken cancellationToken) + { + await Task.Delay(10, cancellationToken); + + // Arrange: Return a grandchild type (two inheritance levels) + return new GrandchildOutput(); + } + } + + /// + /// Executor that attempts to return an unrelated type - should fail validation. + /// This executor intentionally bypasses type safety to test runtime validation. + /// + internal sealed class UnrelatedOutputExecutor() : Executor(nameof(UnrelatedOutputExecutor)) + { + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler(this.HandleAsync)); + } + + private async ValueTask HandleAsync(string input, IWorkflowContext context, CancellationToken cancellationToken) + { + // Arrange: Attempt to yield an unrelated type - should throw + UnrelatedOutput unrelated = new(); + await context.YieldOutputAsync(unrelated, cancellationToken).ConfigureAwait(false); + + // This line should not be reached + return new BaseOutput(); + } + } + + /// + /// Executor that returns the exact declared type (baseline test). + /// + internal sealed class ExactTypeExecutor() : Executor(nameof(ExactTypeExecutor)) + { + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return protocolBuilder.ConfigureRoutes(routeBuilder => + routeBuilder.AddHandler(this.HandleAsync)); + } + + private ValueTask HandleAsync(string input, IWorkflowContext context, CancellationToken cancellationToken) + { + BaseOutput result = new(); + return new ValueTask(result); + } + } + + #endregion + + #region Tests + + /// + /// Verifies that returning a derived type when the declared output type is a base class succeeds. + /// This is the main regression test for GitHub issue #4134. + /// + [Fact] + public async Task ReturningDerivedType_WhenBaseTypeIsDeclared_ShouldSucceedAsync() + { + // Arrange + DerivedOutputExecutor executor = new(); + WorkflowBuilder builder = new WorkflowBuilder(executor).WithOutputFrom(executor); + Workflow workflow = builder.Build(); + + // Act + List events = []; + await using StreamingRun run = await InProcessExecution.RunStreamingAsync(workflow, "test input"); + await foreach (WorkflowEvent evt in run.WatchStreamAsync()) + { + events.Add(evt); + } + + // Assert + events.Should().NotBeEmpty("workflow should produce events"); + + List outputEvents = events.OfType().ToList(); + outputEvents.Should().ContainSingle("workflow should produce exactly one output event"); + + WorkflowOutputEvent outputEvent = outputEvents.Single(); + outputEvent.Data.Should().BeOfType("output should be the derived type"); + ((DerivedOutput)outputEvent.Data!).Name.Should().Be("DerivedOutput"); + + // Verify no error events + List errorEvents = events.OfType().ToList(); + errorEvents.Should().BeEmpty("workflow should not produce error events"); + } + + /// + /// Verifies that returning a grandchild type (multiple inheritance levels) succeeds. + /// + [Fact] + public async Task ReturningGrandchildType_WhenBaseTypeIsDeclared_ShouldSucceedAsync() + { + // Arrange + GrandchildOutputExecutor executor = new(); + WorkflowBuilder builder = new WorkflowBuilder(executor).WithOutputFrom(executor); + Workflow workflow = builder.Build(); + + // Act + List events = []; + await using StreamingRun run = await InProcessExecution.RunStreamingAsync(workflow, "test input"); + await foreach (WorkflowEvent evt in run.WatchStreamAsync()) + { + events.Add(evt); + } + + // Assert + events.Should().NotBeEmpty("workflow should produce events"); + + List outputEvents = events.OfType().ToList(); + outputEvents.Should().ContainSingle("workflow should produce exactly one output event"); + + WorkflowOutputEvent outputEvent = outputEvents.Single(); + outputEvent.Data.Should().BeOfType("output should be the grandchild type"); + ((GrandchildOutput)outputEvent.Data!).Name.Should().Be("GrandchildOutput"); + + // Verify no error events + List errorEvents = events.OfType().ToList(); + errorEvents.Should().BeEmpty("workflow should not produce error events"); + } + + /// + /// Verifies that returning an unrelated type still throws InvalidOperationException. + /// This ensures the fix doesn't break the existing validation for truly incompatible types. + /// + [Fact] + public async Task ReturningUnrelatedType_WhenBaseTypeIsDeclared_ShouldFailAsync() + { + // Arrange + UnrelatedOutputExecutor executor = new(); + WorkflowBuilder builder = new WorkflowBuilder(executor).WithOutputFrom(executor); + Workflow workflow = builder.Build(); + + // Act + List events = []; + await using StreamingRun run = await InProcessExecution.RunStreamingAsync(workflow, "test input"); + await foreach (WorkflowEvent evt in run.WatchStreamAsync()) + { + events.Add(evt); + } + + // Assert: Should have an error event with InvalidOperationException message + List errorEvents = events.OfType().ToList(); + errorEvents.Should().ContainSingle("workflow should produce exactly one error event"); + + WorkflowErrorEvent errorEvent = errorEvents.Single(); + string errorMessage = errorEvent.Data?.ToString() ?? string.Empty; + errorMessage.Should().Contain("Cannot output object of type UnrelatedOutput"); + errorMessage.Should().Contain("BaseOutput"); + } + + /// + /// Verifies that returning the exact declared type still works (baseline test). + /// + [Fact] + public async Task ReturningExactType_WhenSameTypeIsDeclared_ShouldSucceedAsync() + { + // Arrange: Create an executor that returns the exact declared type + ExactTypeExecutor executor = new(); + WorkflowBuilder builder = new WorkflowBuilder(executor).WithOutputFrom(executor); + Workflow workflow = builder.Build(); + + // Act + List events = []; + await using StreamingRun run = await InProcessExecution.RunStreamingAsync(workflow, "test input"); + await foreach (WorkflowEvent evt in run.WatchStreamAsync()) + { + events.Add(evt); + } + + // Assert + events.Should().NotBeEmpty("workflow should produce events"); + + List outputEvents = events.OfType().ToList(); + outputEvents.Should().ContainSingle("workflow should produce exactly one output event"); + + WorkflowOutputEvent outputEvent = outputEvents.Single(); + outputEvent.Data.Should().BeOfType("output should be the exact base type"); + + // Verify no error events + List errorEvents = events.OfType().ToList(); + errorEvents.Should().BeEmpty("workflow should not produce error events"); + } + + #endregion +}