diff --git a/graphene/types/tests/test_union.py b/graphene/types/tests/test_union.py index 4d642d6f..4897dfae 100644 --- a/graphene/types/tests/test_union.py +++ b/graphene/types/tests/test_union.py @@ -56,3 +56,50 @@ class Meta: my_union_field = my_union_instance.mount_as(Field) assert isinstance(my_union_field, Field) assert my_union_field.type == MyUnion + + +def test_generate_union_with_string_references(): + class MyUnion(Union): + class Meta: + types = ( + "graphene.types.tests.test_union.MyObjectType1", + "graphene.types.tests.test_union.MyObjectType2", + ) + + assert MyUnion._meta.types == (MyObjectType1, MyObjectType2) + + +def test_generate_union_with_lambda_references(): + class MyUnion(Union): + class Meta: + types = ( + lambda: MyObjectType1, + lambda: MyObjectType2, + ) + + assert MyUnion._meta.types == (MyObjectType1, MyObjectType2) + + +def test_generate_union_with_mixed_references(): + class MyUnion(Union): + class Meta: + types = ( + MyObjectType1, + lambda: MyObjectType2, + "graphene.types.tests.test_union.MyObjectType1", + ) + + assert MyUnion._meta.types == (MyObjectType1, MyObjectType2, MyObjectType1) + + +def test_union_types_property_is_iterable(): + class MyUnion(Union): + class Meta: + types = ( + lambda: MyObjectType1, + "graphene.types.tests.test_union.MyObjectType2", + ) + + first_access = MyUnion._meta.types + second_access = MyUnion._meta.types + assert first_access == second_access diff --git a/graphene/types/union.py b/graphene/types/union.py index 3d10418e..e4333143 100644 --- a/graphene/types/union.py +++ b/graphene/types/union.py @@ -2,15 +2,20 @@ from .base import BaseOptions, BaseType from .unmountedtype import UnmountedType +from .utils import get_type # For static type checking with type checker if TYPE_CHECKING: from .objecttype import ObjectType # NOQA - from typing import Iterable, Type # NOQA + from typing import Callable, Iterable, Tuple, Type # NOQA class UnionOptions(BaseOptions): - types = () # type: Iterable[Type[ObjectType]] + _types = () # type: Iterable[Type[ObjectType] | str | Callable[[], Type[ObjectType]]] + + @property + def types(self) -> Tuple[Type["ObjectType"], ...]: + return tuple(get_type(type_) for type_ in self._types) class Union(UnmountedType, BaseType): @@ -28,6 +33,9 @@ class Union(UnmountedType, BaseType): attribute or ``is_type_of`` method. Or by implementing ``resolve_type`` class method on the Union. + To avoid circular import issues, `Meta.types` can also contain string references or functions + instead of direct type references. + .. code:: python from graphene import Union, ObjectType, List @@ -42,8 +50,9 @@ class Query(ObjectType): ) Meta: - types (Iterable[graphene.ObjectType]): Required. Collection of types that may be returned - by this Union for the graphQL schema. + types (Iterable[type[graphene.ObjectType] | str | Callable[[], type[graphene.ObjectType]]]): Required. + Collection of types that may be returned by this Union for the graphQL schema. Can also be string + references or functions for lazy resolution. name (optional, str): the name of the GraphQL type (must be unique in schema). Defaults to class name. description (optional, str): the description of the GraphQL type in the schema. Defaults to class @@ -59,7 +68,7 @@ def __init_subclass_with_meta__(cls, types=None, _meta=None, **options): if not _meta: _meta = UnionOptions(cls) - _meta.types = types + _meta._types = types super(Union, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod