1- from typing import Callable , Collection , Dict , List , NoReturn , Optional , Union , cast
1+ from collections import defaultdict
2+ from typing import (
3+ Callable ,
4+ Collection ,
5+ DefaultDict ,
6+ Dict ,
7+ List ,
8+ NoReturn ,
9+ Optional ,
10+ Union ,
11+ cast ,
12+ )
213
314from ..language import (
415 DirectiveDefinitionNode ,
2031 ObjectTypeExtensionNode ,
2132 OperationType ,
2233 ScalarTypeDefinitionNode ,
34+ ScalarTypeExtensionNode ,
2335 SchemaDefinitionNode ,
2436 SchemaExtensionNode ,
2537 Source ,
2638 TypeDefinitionNode ,
39+ TypeExtensionNode ,
2740 TypeNode ,
2841 UnionTypeDefinitionNode ,
2942 UnionTypeExtensionNode ,
@@ -102,15 +115,23 @@ def build_ast_schema(
102115
103116 assert_valid_sdl (document_ast )
104117
118+ # Collect the definitions and extensions found in the document.
105119 schema_def : Optional [SchemaDefinitionNode ] = None
120+ schema_extensions : List [SchemaExtensionNode ] = []
106121 type_defs : List [TypeDefinitionNode ] = []
122+ type_extensions_map : DefaultDict [str , List [TypeExtensionNode ]] = defaultdict (list )
107123 directive_defs : List [DirectiveDefinitionNode ] = []
108124 append_directive_def = directive_defs .append
109125 for def_ in document_ast .definitions :
110126 if isinstance (def_ , SchemaDefinitionNode ):
111127 schema_def = def_
128+ elif isinstance (def_ , SchemaExtensionNode ):
129+ schema_extensions .append (def_ )
112130 elif isinstance (def_ , TypeDefinitionNode ):
113131 type_defs .append (def_ )
132+ elif isinstance (def_ , TypeExtensionNode ):
133+ extended_type_name = def_ .name .value
134+ type_extensions_map [extended_type_name ].append (def_ )
114135 elif isinstance (def_ , DirectiveDefinitionNode ):
115136 append_directive_def (def_ )
116137
@@ -124,10 +145,10 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
124145 assume_valid = assume_valid , resolve_type = resolve_type
125146 )
126147
127- type_map = ast_builder .build_type_map (type_defs )
148+ type_map = ast_builder .build_type_map (type_defs , type_extensions_map )
128149
129150 operation_types : Dict [OperationType , GraphQLObjectType ] = (
130- ast_builder .get_operation_types ([schema_def ])
151+ ast_builder .get_operation_types ([schema_def , * schema_extensions ])
131152 if schema_def
132153 else {
133154 # Note: While this could make early assertions to get the correctly
@@ -159,6 +180,7 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
159180 types = type_map .values (),
160181 directives = directives ,
161182 ast_node = schema_def ,
183+ extension_ast_nodes = schema_extensions ,
162184 assume_valid = assume_valid ,
163185 )
164186
@@ -359,15 +381,23 @@ def build_union_types(
359381 return types
360382
361383 def build_type_map (
362- self , nodes : Collection [TypeDefinitionNode ]
384+ self ,
385+ nodes : Collection [TypeDefinitionNode ],
386+ extensions_map : DefaultDict [str , List [TypeExtensionNode ]],
363387 ) -> Dict [str , GraphQLNamedType ]:
364388 type_map : Dict [str , GraphQLNamedType ] = {}
365389 for node in nodes :
366390 name = node .name .value
367- type_map [name ] = std_type_map .get (name ) or self ._build_type (node )
391+ type_map [name ] = std_type_map .get (name ) or self ._build_type (
392+ node , extensions_map [name ]
393+ )
368394 return type_map
369395
370- def _build_type (self , ast_node : TypeDefinitionNode ) -> GraphQLNamedType :
396+ def _build_type (
397+ self ,
398+ ast_node : TypeDefinitionNode ,
399+ extension_nodes : Collection [TypeExtensionNode ],
400+ ) -> GraphQLNamedType :
371401 try :
372402 # object_type_definition_node is built with _build_object_type etc.
373403 method = getattr (self , "_build_" + ast_node .kind [:- 11 ])
@@ -377,62 +407,103 @@ def _build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
377407 f"Unexpected type definition node: { inspect (ast_node )} ."
378408 )
379409 else :
380- return method (ast_node )
410+ return method (ast_node , extension_nodes )
381411
382412 def _build_object_type (
383- self , ast_node : ObjectTypeDefinitionNode
413+ self ,
414+ ast_node : ObjectTypeDefinitionNode ,
415+ extension_nodes : Collection [ObjectTypeExtensionNode ],
384416 ) -> GraphQLObjectType :
417+ all_nodes : List [Union [ObjectTypeDefinitionNode , ObjectTypeExtensionNode ]] = [
418+ ast_node ,
419+ * extension_nodes ,
420+ ]
385421 return GraphQLObjectType (
386422 name = ast_node .name .value ,
387423 description = ast_node .description .value if ast_node .description else None ,
388- interfaces = lambda : self .build_interfaces ([ ast_node ] ),
389- fields = lambda : self .build_field_map ([ ast_node ] ),
424+ interfaces = lambda : self .build_interfaces (all_nodes ),
425+ fields = lambda : self .build_field_map (all_nodes ),
390426 ast_node = ast_node ,
427+ extension_ast_nodes = extension_nodes ,
391428 )
392429
393430 def _build_interface_type (
394- self , ast_node : InterfaceTypeDefinitionNode
431+ self ,
432+ ast_node : InterfaceTypeDefinitionNode ,
433+ extension_nodes : Collection [InterfaceTypeExtensionNode ],
395434 ) -> GraphQLInterfaceType :
435+ all_nodes : List [
436+ Union [InterfaceTypeDefinitionNode , InterfaceTypeExtensionNode ]
437+ ] = [ast_node , * extension_nodes ]
396438 return GraphQLInterfaceType (
397439 name = ast_node .name .value ,
398440 description = ast_node .description .value if ast_node .description else None ,
399- interfaces = lambda : self .build_interfaces ([ ast_node ] ),
400- fields = lambda : self .build_field_map ([ ast_node ] ),
441+ interfaces = lambda : self .build_interfaces (all_nodes ),
442+ fields = lambda : self .build_field_map (all_nodes ),
401443 ast_node = ast_node ,
444+ extension_ast_nodes = extension_nodes ,
402445 )
403446
404- def _build_enum_type (self , ast_node : EnumTypeDefinitionNode ) -> GraphQLEnumType :
447+ def _build_enum_type (
448+ self ,
449+ ast_node : EnumTypeDefinitionNode ,
450+ extension_nodes : Collection [EnumTypeExtensionNode ],
451+ ) -> GraphQLEnumType :
452+ all_nodes : List [Union [EnumTypeDefinitionNode , EnumTypeExtensionNode ]] = [
453+ ast_node ,
454+ * extension_nodes ,
455+ ]
405456 return GraphQLEnumType (
406457 name = ast_node .name .value ,
407458 description = ast_node .description .value if ast_node .description else None ,
408- values = self .build_enum_value_map ([ ast_node ] ),
459+ values = self .build_enum_value_map (all_nodes ),
409460 ast_node = ast_node ,
461+ extension_ast_nodes = extension_nodes ,
410462 )
411463
412- def _build_union_type (self , ast_node : UnionTypeDefinitionNode ) -> GraphQLUnionType :
464+ def _build_union_type (
465+ self ,
466+ ast_node : UnionTypeDefinitionNode ,
467+ extension_nodes : Collection [UnionTypeExtensionNode ],
468+ ) -> GraphQLUnionType :
469+ all_nodes : List [Union [UnionTypeDefinitionNode , UnionTypeExtensionNode ]] = [
470+ ast_node ,
471+ * extension_nodes ,
472+ ]
413473 return GraphQLUnionType (
414474 name = ast_node .name .value ,
415475 description = ast_node .description .value if ast_node .description else None ,
416- types = lambda : self .build_union_types ([ ast_node ] ),
476+ types = lambda : self .build_union_types (all_nodes ),
417477 ast_node = ast_node ,
478+ extension_ast_nodes = extension_nodes ,
418479 )
419480
420481 @staticmethod
421- def _build_scalar_type (ast_node : ScalarTypeDefinitionNode ) -> GraphQLScalarType :
482+ def _build_scalar_type (
483+ ast_node : ScalarTypeDefinitionNode ,
484+ extension_nodes : Collection [ScalarTypeExtensionNode ],
485+ ) -> GraphQLScalarType :
422486 return GraphQLScalarType (
423487 name = ast_node .name .value ,
424488 description = ast_node .description .value if ast_node .description else None ,
425489 ast_node = ast_node ,
490+ extension_ast_nodes = extension_nodes ,
426491 )
427492
428493 def _build_input_object_type (
429- self , ast_node : InputObjectTypeDefinitionNode
494+ self ,
495+ ast_node : InputObjectTypeDefinitionNode ,
496+ extension_nodes : Collection [InputObjectTypeExtensionNode ],
430497 ) -> GraphQLInputObjectType :
498+ all_nodes : List [
499+ Union [InputObjectTypeDefinitionNode , InputObjectTypeExtensionNode ]
500+ ] = [ast_node , * extension_nodes ]
431501 return GraphQLInputObjectType (
432502 name = ast_node .name .value ,
433503 description = ast_node .description .value if ast_node .description else None ,
434- fields = lambda : self .build_input_field_map ([ ast_node ] ),
504+ fields = lambda : self .build_input_field_map (all_nodes ),
435505 ast_node = ast_node ,
506+ extension_ast_nodes = extension_nodes ,
436507 )
437508
438509
0 commit comments