Skip to content

Commit e7a5f77

Browse files
committed
build/extend_schema: add support for extensions
Replicates graphql/graphql-js@1283c84
1 parent e45316e commit e7a5f77

File tree

5 files changed

+304
-27
lines changed

5 files changed

+304
-27
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The current stable version 3.0.1 of GraphQL-core is up-to-date
1616
with GraphQL.js version 14.5.8.
1717

1818
All parts of the API are covered by an extensive test suite
19-
of currently 2044 unit tests.
19+
of currently 2050 unit tests.
2020

2121

2222
## Documentation

src/graphql/utilities/build_ast_schema.py

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
1-
from typing import Callable, Collection, Dict, List, NoReturn, Optional, Union, cast
1+
from collections import defaultdict
2+
from typing import (
3+
Callable,
4+
Collection,
5+
DefaultDict,
6+
Dict,
7+
List,
8+
NoReturn,
9+
Optional,
10+
Union,
11+
cast,
12+
)
213

314
from ..language import (
415
DirectiveDefinitionNode,
@@ -20,10 +31,12 @@
2031
ObjectTypeExtensionNode,
2132
OperationType,
2233
ScalarTypeDefinitionNode,
34+
ScalarTypeExtensionNode,
2335
SchemaDefinitionNode,
2436
SchemaExtensionNode,
2537
Source,
2638
TypeDefinitionNode,
39+
TypeExtensionNode,
2740
TypeNode,
2841
UnionTypeDefinitionNode,
2942
UnionTypeExtensionNode,
@@ -102,15 +115,23 @@ def build_ast_schema(
102115

103116
assert_valid_sdl(document_ast)
104117

118+
# Collect the definitions and extensions found in the document.
105119
schema_def: Optional[SchemaDefinitionNode] = None
120+
schema_extensions: List[SchemaExtensionNode] = []
106121
type_defs: List[TypeDefinitionNode] = []
122+
type_extensions_map: DefaultDict[str, List[TypeExtensionNode]] = defaultdict(list)
107123
directive_defs: List[DirectiveDefinitionNode] = []
108124
append_directive_def = directive_defs.append
109125
for def_ in document_ast.definitions:
110126
if isinstance(def_, SchemaDefinitionNode):
111127
schema_def = def_
128+
elif isinstance(def_, SchemaExtensionNode):
129+
schema_extensions.append(def_)
112130
elif isinstance(def_, TypeDefinitionNode):
113131
type_defs.append(def_)
132+
elif isinstance(def_, TypeExtensionNode):
133+
extended_type_name = def_.name.value
134+
type_extensions_map[extended_type_name].append(def_)
114135
elif isinstance(def_, DirectiveDefinitionNode):
115136
append_directive_def(def_)
116137

@@ -124,10 +145,10 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
124145
assume_valid=assume_valid, resolve_type=resolve_type
125146
)
126147

127-
type_map = ast_builder.build_type_map(type_defs)
148+
type_map = ast_builder.build_type_map(type_defs, type_extensions_map)
128149

129150
operation_types: Dict[OperationType, GraphQLObjectType] = (
130-
ast_builder.get_operation_types([schema_def])
151+
ast_builder.get_operation_types([schema_def, *schema_extensions])
131152
if schema_def
132153
else {
133154
# Note: While this could make early assertions to get the correctly
@@ -159,6 +180,7 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
159180
types=type_map.values(),
160181
directives=directives,
161182
ast_node=schema_def,
183+
extension_ast_nodes=schema_extensions,
162184
assume_valid=assume_valid,
163185
)
164186

@@ -359,15 +381,23 @@ def build_union_types(
359381
return types
360382

361383
def build_type_map(
362-
self, nodes: Collection[TypeDefinitionNode]
384+
self,
385+
nodes: Collection[TypeDefinitionNode],
386+
extensions_map: DefaultDict[str, List[TypeExtensionNode]],
363387
) -> Dict[str, GraphQLNamedType]:
364388
type_map: Dict[str, GraphQLNamedType] = {}
365389
for node in nodes:
366390
name = node.name.value
367-
type_map[name] = std_type_map.get(name) or self._build_type(node)
391+
type_map[name] = std_type_map.get(name) or self._build_type(
392+
node, extensions_map[name]
393+
)
368394
return type_map
369395

370-
def _build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
396+
def _build_type(
397+
self,
398+
ast_node: TypeDefinitionNode,
399+
extension_nodes: Collection[TypeExtensionNode],
400+
) -> GraphQLNamedType:
371401
try:
372402
# object_type_definition_node is built with _build_object_type etc.
373403
method = getattr(self, "_build_" + ast_node.kind[:-11])
@@ -377,62 +407,103 @@ def _build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
377407
f"Unexpected type definition node: {inspect(ast_node)}."
378408
)
379409
else:
380-
return method(ast_node)
410+
return method(ast_node, extension_nodes)
381411

382412
def _build_object_type(
383-
self, ast_node: ObjectTypeDefinitionNode
413+
self,
414+
ast_node: ObjectTypeDefinitionNode,
415+
extension_nodes: Collection[ObjectTypeExtensionNode],
384416
) -> GraphQLObjectType:
417+
all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [
418+
ast_node,
419+
*extension_nodes,
420+
]
385421
return GraphQLObjectType(
386422
name=ast_node.name.value,
387423
description=ast_node.description.value if ast_node.description else None,
388-
interfaces=lambda: self.build_interfaces([ast_node]),
389-
fields=lambda: self.build_field_map([ast_node]),
424+
interfaces=lambda: self.build_interfaces(all_nodes),
425+
fields=lambda: self.build_field_map(all_nodes),
390426
ast_node=ast_node,
427+
extension_ast_nodes=extension_nodes,
391428
)
392429

393430
def _build_interface_type(
394-
self, ast_node: InterfaceTypeDefinitionNode
431+
self,
432+
ast_node: InterfaceTypeDefinitionNode,
433+
extension_nodes: Collection[InterfaceTypeExtensionNode],
395434
) -> GraphQLInterfaceType:
435+
all_nodes: List[
436+
Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode]
437+
] = [ast_node, *extension_nodes]
396438
return GraphQLInterfaceType(
397439
name=ast_node.name.value,
398440
description=ast_node.description.value if ast_node.description else None,
399-
interfaces=lambda: self.build_interfaces([ast_node]),
400-
fields=lambda: self.build_field_map([ast_node]),
441+
interfaces=lambda: self.build_interfaces(all_nodes),
442+
fields=lambda: self.build_field_map(all_nodes),
401443
ast_node=ast_node,
444+
extension_ast_nodes=extension_nodes,
402445
)
403446

404-
def _build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
447+
def _build_enum_type(
448+
self,
449+
ast_node: EnumTypeDefinitionNode,
450+
extension_nodes: Collection[EnumTypeExtensionNode],
451+
) -> GraphQLEnumType:
452+
all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [
453+
ast_node,
454+
*extension_nodes,
455+
]
405456
return GraphQLEnumType(
406457
name=ast_node.name.value,
407458
description=ast_node.description.value if ast_node.description else None,
408-
values=self.build_enum_value_map([ast_node]),
459+
values=self.build_enum_value_map(all_nodes),
409460
ast_node=ast_node,
461+
extension_ast_nodes=extension_nodes,
410462
)
411463

412-
def _build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType:
464+
def _build_union_type(
465+
self,
466+
ast_node: UnionTypeDefinitionNode,
467+
extension_nodes: Collection[UnionTypeExtensionNode],
468+
) -> GraphQLUnionType:
469+
all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [
470+
ast_node,
471+
*extension_nodes,
472+
]
413473
return GraphQLUnionType(
414474
name=ast_node.name.value,
415475
description=ast_node.description.value if ast_node.description else None,
416-
types=lambda: self.build_union_types([ast_node]),
476+
types=lambda: self.build_union_types(all_nodes),
417477
ast_node=ast_node,
478+
extension_ast_nodes=extension_nodes,
418479
)
419480

420481
@staticmethod
421-
def _build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType:
482+
def _build_scalar_type(
483+
ast_node: ScalarTypeDefinitionNode,
484+
extension_nodes: Collection[ScalarTypeExtensionNode],
485+
) -> GraphQLScalarType:
422486
return GraphQLScalarType(
423487
name=ast_node.name.value,
424488
description=ast_node.description.value if ast_node.description else None,
425489
ast_node=ast_node,
490+
extension_ast_nodes=extension_nodes,
426491
)
427492

428493
def _build_input_object_type(
429-
self, ast_node: InputObjectTypeDefinitionNode
494+
self,
495+
ast_node: InputObjectTypeDefinitionNode,
496+
extension_nodes: Collection[InputObjectTypeExtensionNode],
430497
) -> GraphQLInputObjectType:
498+
all_nodes: List[
499+
Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
500+
] = [ast_node, *extension_nodes]
431501
return GraphQLInputObjectType(
432502
name=ast_node.name.value,
433503
description=ast_node.description.value if ast_node.description else None,
434-
fields=lambda: self.build_input_field_map([ast_node]),
504+
fields=lambda: self.build_input_field_map(all_nodes),
435505
ast_node=ast_node,
506+
extension_ast_nodes=extension_nodes,
436507
)
437508

438509

src/graphql/utilities/extend_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import Any, Dict, List, Optional, cast
2+
from typing import Any, DefaultDict, Dict, List, Optional, cast
33

44
from ..language import (
55
DirectiveDefinitionNode,
@@ -78,7 +78,7 @@ def extend_schema(
7878

7979
# Collect the type definitions and extensions found in the document.
8080
type_defs: List[TypeDefinitionNode] = []
81-
type_extensions_map: Dict[str, Any] = defaultdict(list)
81+
type_extensions_map: DefaultDict[str, Any] = defaultdict(list)
8282

8383
# New directives and types are separate because a directives and types can have the
8484
# same name. For example, a type named "skip".
@@ -311,7 +311,7 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
311311
assume_valid=assume_valid, resolve_type=resolve_type
312312
)
313313

314-
type_map = ast_builder.build_type_map(type_defs)
314+
type_map = ast_builder.build_type_map(type_defs, type_extensions_map)
315315
for existing_type_name, existing_type in schema.type_map.items():
316316
type_map[existing_type_name] = extend_named_type(existing_type)
317317

0 commit comments

Comments
 (0)