11from functools import partial
2-
2+ from promise import is_thenable , Promise
33from sqlalchemy .orm .query import Query
44
55from graphene .relay import ConnectionField
@@ -25,39 +25,38 @@ def get_query(cls, model, info, sort=None, **args):
2525 query = query .order_by (* (col .value for col in sort ))
2626 return query
2727
28- @property
29- def type (self ):
30- from .types import SQLAlchemyObjectType
31- _type = super (ConnectionField , self ).type
32- assert issubclass (_type , SQLAlchemyObjectType ), (
33- "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
34- )
35- assert _type ._meta .connection , "The type {} doesn't have a connection" .format (_type .__name__ )
36- return _type ._meta .connection
37-
3828 @classmethod
39- def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
40- iterable = resolver (root , info , ** args )
41- if iterable is None :
42- iterable = cls .get_query (model , info , ** args )
43- if isinstance (iterable , Query ):
44- _len = iterable .count ()
29+ def resolve_connection (cls , connection_type , model , info , args , resolved ):
30+ if resolved is None :
31+ resolved = cls .get_query (model , info , ** args )
32+ if isinstance (resolved , Query ):
33+ _len = resolved .count ()
4534 else :
46- _len = len (iterable )
35+ _len = len (resolved )
4736 connection = connection_from_list_slice (
48- iterable ,
37+ resolved ,
4938 args ,
5039 slice_start = 0 ,
5140 list_length = _len ,
5241 list_slice_length = _len ,
53- connection_type = connection ,
42+ connection_type = connection_type ,
5443 pageinfo_type = PageInfo ,
55- edge_type = connection .Edge ,
44+ edge_type = connection_type .Edge ,
5645 )
57- connection .iterable = iterable
46+ connection .iterable = resolved
5847 connection .length = _len
5948 return connection
6049
50+ @classmethod
51+ def connection_resolver (cls , resolver , connection_type , model , root , info , ** args ):
52+ resolved = resolver (root , info , ** args )
53+
54+ on_resolve = partial (cls .resolve_connection , connection_type , model , info , args )
55+ if is_thenable (resolved ):
56+ return Promise .resolve (resolved ).then (on_resolve )
57+
58+ return on_resolve (resolved )
59+
6160 def get_resolver (self , parent_resolver ):
6261 return partial (self .connection_resolver , parent_resolver , self .type , self .model )
6362
0 commit comments