Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions graphene/types/tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,50 @@ class Meta:
my_union_field = my_union_instance.mount_as(Field)
assert isinstance(my_union_field, Field)
assert my_union_field.type == MyUnion


def test_generate_union_with_string_references():
class MyUnion(Union):
class Meta:
types = (
"graphene.types.tests.test_union.MyObjectType1",
"graphene.types.tests.test_union.MyObjectType2",
)

assert MyUnion._meta.types == (MyObjectType1, MyObjectType2)


def test_generate_union_with_lambda_references():
class MyUnion(Union):
class Meta:
types = (
lambda: MyObjectType1,
lambda: MyObjectType2,
)

assert MyUnion._meta.types == (MyObjectType1, MyObjectType2)


def test_generate_union_with_mixed_references():
class MyUnion(Union):
class Meta:
types = (
MyObjectType1,
lambda: MyObjectType2,
"graphene.types.tests.test_union.MyObjectType1",
)

assert MyUnion._meta.types == (MyObjectType1, MyObjectType2, MyObjectType1)


def test_union_types_property_is_iterable():
class MyUnion(Union):
class Meta:
types = (
lambda: MyObjectType1,
"graphene.types.tests.test_union.MyObjectType2",
)

first_access = MyUnion._meta.types
second_access = MyUnion._meta.types
assert first_access == second_access
19 changes: 14 additions & 5 deletions graphene/types/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
from .utils import get_type

# For static type checking with type checker
if TYPE_CHECKING:
from .objecttype import ObjectType # NOQA
from typing import Iterable, Type # NOQA
from typing import Callable, Iterable, Tuple, Type # NOQA


class UnionOptions(BaseOptions):
types = () # type: Iterable[Type[ObjectType]]
_types = () # type: Iterable[Type[ObjectType] | str | Callable[[], Type[ObjectType]]]

@property
def types(self) -> Tuple[Type["ObjectType"], ...]:
return tuple(get_type(type_) for type_ in self._types)


class Union(UnmountedType, BaseType):
Expand All @@ -28,6 +33,9 @@ class Union(UnmountedType, BaseType):
attribute or ``is_type_of`` method. Or by implementing ``resolve_type`` class method on the
Union.

To avoid circular import issues, `Meta.types` can also contain string references or functions
instead of direct type references.

.. code:: python

from graphene import Union, ObjectType, List
Expand All @@ -42,8 +50,9 @@ class Query(ObjectType):
)

Meta:
types (Iterable[graphene.ObjectType]): Required. Collection of types that may be returned
by this Union for the graphQL schema.
types (Iterable[type[graphene.ObjectType] | str | Callable[[], type[graphene.ObjectType]]]): Required.
Collection of types that may be returned by this Union for the graphQL schema. Can also be string
references or functions for lazy resolution.
name (optional, str): the name of the GraphQL type (must be unique in schema). Defaults to class
name.
description (optional, str): the description of the GraphQL type in the schema. Defaults to class
Expand All @@ -59,7 +68,7 @@ def __init_subclass_with_meta__(cls, types=None, _meta=None, **options):
if not _meta:
_meta = UnionOptions(cls)

_meta.types = types
_meta._types = types
super(Union, cls).__init_subclass_with_meta__(_meta=_meta, **options)

@classmethod
Expand Down