diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index f219c0c28f..d2bb2ac598 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -6,6 +6,7 @@ import inspect import logging import types +import typing from collections.abc import Awaitable, Callable from typing import Any, TypeVar, overload @@ -722,20 +723,30 @@ def _validate_handler_signature( if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter") + # Resolve string annotations from `from __future__ import annotations`. + # Fall back to raw annotations if resolution fails (e.g. unresolvable forward refs, + # AttributeError, or RecursionError), so registration failures are easier to diagnose. + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {p.name: p.annotation for p in params} + # Validate ctx parameter is WorkflowContext and extract type args ctx_param = params[2] - if skip_message_annotation and ctx_param.annotation == inspect.Parameter.empty: + ctx_annotation = type_hints.get(ctx_param.name, ctx_param.annotation) + if skip_message_annotation and ctx_annotation == inspect.Parameter.empty: # When explicit types are provided via @handler(input=..., output=...), # the ctx parameter doesn't need a type annotation - types come from the decorator. output_types: list[type[Any] | types.UnionType] = [] workflow_output_types: list[type[Any] | types.UnionType] = [] else: output_types, workflow_output_types = validate_workflow_context_annotation( - ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler" + ctx_annotation, f"parameter '{ctx_param.name}'", "Handler" ) - message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None - ctx_annotation = ctx_param.annotation + message_type = type_hints.get(message_param.name, message_param.annotation) + if message_type == inspect.Parameter.empty: + message_type = None return message_type, ctx_annotation, output_types, workflow_output_types diff --git a/python/packages/core/tests/workflow/test_executor_future.py b/python/packages/core/tests/workflow/test_executor_future.py new file mode 100644 index 0000000000..c0916b9cf7 --- /dev/null +++ b/python/packages/core/tests/workflow/test_executor_future.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +from typing import Any + +import pytest +from pydantic import BaseModel + +from agent_framework import Executor, WorkflowContext, handler + + +class MyTypeA(BaseModel): + pass + + +class MyTypeB(BaseModel): + pass + + +class MyTypeC(BaseModel): + pass + + +class TestExecutorFutureAnnotations: + """Test suite for Executor with from __future__ import annotations.""" + + def test_handler_decorator_future_annotations(self): + """Test @handler decorator works with stringified annotations (issue #3898).""" + + class MyExecutor(Executor): + @handler + async def example(self, input: str, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: + pass + + exec_instance = MyExecutor(id="test") + assert str in exec_instance._handlers + spec = exec_instance._handler_specs[0] + assert spec["message_type"] is str + assert spec["output_types"] == [MyTypeA] + assert spec["workflow_output_types"] == [MyTypeB] + + def test_handler_decorator_future_annotations_single_type_arg(self): + """Test @handler with single type argument and future annotations.""" + + class MyExecutor(Executor): + @handler + async def example(self, input: int, ctx: WorkflowContext[MyTypeA]) -> None: + pass + + exec_instance = MyExecutor(id="test") + assert int in exec_instance._handlers + spec = exec_instance._handler_specs[0] + assert spec["message_type"] is int + assert spec["output_types"] == [MyTypeA] + + def test_handler_decorator_future_annotations_complex(self): + """Test @handler with complex type annotations and future annotations.""" + + class MyExecutor(Executor): + @handler + async def example(self, data: dict[str, Any], ctx: WorkflowContext[list[str]]) -> None: + pass + + exec_instance = MyExecutor(id="test") + spec = exec_instance._handler_specs[0] + assert spec["message_type"] == dict[str, Any] + assert spec["output_types"] == [list[str]] + + def test_handler_decorator_future_annotations_bare_context(self): + """Test @handler with bare WorkflowContext and future annotations.""" + + class MyExecutor(Executor): + @handler + async def example(self, input: str, ctx: WorkflowContext) -> None: + pass + + exec_instance = MyExecutor(id="test") + assert str in exec_instance._handlers + spec = exec_instance._handler_specs[0] + assert spec["output_types"] == [] + assert spec["workflow_output_types"] == [] + + def test_handler_decorator_future_annotations_explicit_types(self): + """Test @handler with explicit type parameters under future annotations.""" + + class MyExecutor(Executor): + @handler(input=str, output=MyTypeA) + async def example(self, input, ctx) -> None: + pass + + exec_instance = MyExecutor(id="test") + assert str in exec_instance._handlers + spec = exec_instance._handler_specs[0] + assert spec["message_type"] is str + assert spec["output_types"] == [MyTypeA] + + def test_handler_decorator_future_annotations_union_context(self): + """Test @handler with union type context annotations and future annotations.""" + + class MyExecutor(Executor): + @handler + async def example(self, input: str, ctx: WorkflowContext[MyTypeA | MyTypeB, MyTypeC]) -> None: + pass + + exec_instance = MyExecutor(id="test") + assert str in exec_instance._handlers + spec = exec_instance._handler_specs[0] + assert spec["output_types"] == [MyTypeA, MyTypeB] + assert spec["workflow_output_types"] == [MyTypeC] + + def test_handler_unresolvable_annotation_raises(self): + """Test that an unresolvable forward-reference annotation raises ValueError. + + When get_type_hints fails (e.g. NameError for NonExistentType), the code falls back + to raw string annotations. The ctx parameter's raw string annotation is then not + recognised as a valid WorkflowContext type, so a ValueError is still raised. + """ + with pytest.raises(ValueError): + + class Bad(Executor): + @handler + async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821 + pass