88 FragmentDefinitionNode ,
99 FragmentSpreadNode ,
1010 InlineFragmentNode ,
11+ OperationDefinitionNode ,
12+ OperationType ,
1113 SelectionSetNode ,
1214)
1315from ..type import (
@@ -43,7 +45,7 @@ def collect_fields(
4345 fragments : Dict [str , FragmentDefinitionNode ],
4446 variable_values : Dict [str , Any ],
4547 runtime_type : GraphQLObjectType ,
46- selection_set : SelectionSetNode ,
48+ operation : OperationDefinitionNode ,
4749) -> FieldsAndPatches :
4850 """Collect fields.
4951
@@ -61,8 +63,9 @@ def collect_fields(
6163 schema ,
6264 fragments ,
6365 variable_values ,
66+ operation ,
6467 runtime_type ,
65- selection_set ,
68+ operation . selection_set ,
6669 fields ,
6770 patches ,
6871 set (),
@@ -74,6 +77,7 @@ def collect_subfields(
7477 schema : GraphQLSchema ,
7578 fragments : Dict [str , FragmentDefinitionNode ],
7679 variable_values : Dict [str , Any ],
80+ operation : OperationDefinitionNode ,
7781 return_type : GraphQLObjectType ,
7882 field_nodes : List [FieldNode ],
7983) -> FieldsAndPatches :
@@ -100,6 +104,7 @@ def collect_subfields(
100104 schema ,
101105 fragments ,
102106 variable_values ,
107+ operation ,
103108 return_type ,
104109 node .selection_set ,
105110 sub_field_nodes ,
@@ -113,6 +118,7 @@ def collect_fields_impl(
113118 schema : GraphQLSchema ,
114119 fragments : Dict [str , FragmentDefinitionNode ],
115120 variable_values : Dict [str , Any ],
121+ operation : OperationDefinitionNode ,
116122 runtime_type : GraphQLObjectType ,
117123 selection_set : SelectionSetNode ,
118124 fields : Dict [str , List [FieldNode ]],
@@ -133,13 +139,14 @@ def collect_fields_impl(
133139 ) or not does_fragment_condition_match (schema , selection , runtime_type ):
134140 continue
135141
136- defer = get_defer_values (variable_values , selection )
142+ defer = get_defer_values (operation , variable_values , selection )
137143 if defer :
138144 patch_fields = defaultdict (list )
139145 collect_fields_impl (
140146 schema ,
141147 fragments ,
142148 variable_values ,
149+ operation ,
143150 runtime_type ,
144151 selection .selection_set ,
145152 patch_fields ,
@@ -152,6 +159,7 @@ def collect_fields_impl(
152159 schema ,
153160 fragments ,
154161 variable_values ,
162+ operation ,
155163 runtime_type ,
156164 selection .selection_set ,
157165 fields ,
@@ -164,7 +172,7 @@ def collect_fields_impl(
164172 if not should_include_node (variable_values , selection ):
165173 continue
166174
167- defer = get_defer_values (variable_values , selection )
175+ defer = get_defer_values (operation , variable_values , selection )
168176 if frag_name in visited_fragment_names and not defer :
169177 continue
170178
@@ -183,6 +191,7 @@ def collect_fields_impl(
183191 schema ,
184192 fragments ,
185193 variable_values ,
194+ operation ,
186195 runtime_type ,
187196 fragment .selection_set ,
188197 patch_fields ,
@@ -195,6 +204,7 @@ def collect_fields_impl(
195204 schema ,
196205 fragments ,
197206 variable_values ,
207+ operation ,
198208 runtime_type ,
199209 fragment .selection_set ,
200210 fields ,
@@ -210,7 +220,9 @@ class DeferValues(NamedTuple):
210220
211221
212222def get_defer_values (
213- variable_values : Dict [str , Any ], node : Union [FragmentSpreadNode , InlineFragmentNode ]
223+ operation : OperationDefinitionNode ,
224+ variable_values : Dict [str , Any ],
225+ node : Union [FragmentSpreadNode , InlineFragmentNode ],
214226) -> Optional [DeferValues ]:
215227 """Get values of defer directive if active.
216228
@@ -223,6 +235,13 @@ def get_defer_values(
223235 if not defer or defer .get ("if" ) is False :
224236 return None
225237
238+ if operation .operation == OperationType .SUBSCRIPTION :
239+ msg = (
240+ "`@defer` directive not supported on subscription operations."
241+ " Disable `@defer` by setting the `if` argument to `false`."
242+ )
243+ raise TypeError (msg )
244+
226245 return DeferValues (defer .get ("label" ))
227246
228247
0 commit comments