2016-08-15 01:42:27 +03:00
|
|
|
from functools import partial
|
2016-07-23 06:18:23 +03:00
|
|
|
from sqlalchemy.orm.query import Query
|
2016-07-18 18:28:12 +03:00
|
|
|
|
2016-07-23 06:18:23 +03:00
|
|
|
from graphene.relay import ConnectionField
|
2016-08-20 04:51:12 +03:00
|
|
|
from graphene.relay.connection import PageInfo
|
2016-07-23 06:18:23 +03:00
|
|
|
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
|
|
|
from .utils import get_query
|
2016-07-18 18:28:12 +03:00
|
|
|
|
|
|
|
|
|
|
|
class SQLAlchemyConnectionField(ConnectionField):
    """Relay ``ConnectionField`` whose nodes come from a model-backed node type.

    When the wrapped resolver returns ``None``, the field falls back to the
    query returned by ``get_query`` for the node's model. The result is then
    turned into a connection with ``connection_from_list_slice``.
    """

    @property
    def model(self):
        """Return the model attached to this connection's node type.

        Read from ``self.type._meta.node._meta.model``. This is presumably the
        SQLAlchemy mapped class of the node; confirm against the node's Meta.
        """
        return self.type._meta.node._meta.model

    @staticmethod
    def connection_resolver(resolver, connection, model, root, args, context, info):
        """Resolve the field into an instance of ``connection``.

        :param resolver: the wrapped (parent) resolver, called as
            ``resolver(root, args, context, info)``.
        :param connection: the connection type to build; its ``Edge``
            attribute is used as the edge type.
        :param model: model handed to ``get_query`` when ``resolver``
            returns ``None``.
        :param root, args, context, info: the standard resolver arguments.
            ``args`` is also forwarded to ``connection_from_list_slice``
            (presumably carrying first/last/before/after; verify).
        :returns: the object built by ``connection_from_list_slice`` using
            ``connection``, ``PageInfo`` and ``connection.Edge``.
        """
        iterable = resolver(root, args, context, info)
        if iterable is None:
            # No explicit result from the resolver: use the fallback query
            # that ``get_query`` produces for this model.
            iterable = get_query(model, context)

        if isinstance(iterable, Query):
            # Ask the query for its row count instead of loading every row.
            length = iterable.count()
        else:
            if not hasattr(iterable, '__len__'):
                # Generators and other one-shot iterables have no len() and
                # cannot be sliced, so materialise them into a list first.
                iterable = list(iterable)
            length = len(iterable)

        # The entire result set is passed as a single slice starting at
        # offset 0 (slice length == total length), so the requested page is
        # taken from ``iterable`` itself.
        return connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=length,
            list_slice_length=length,
            connection_type=connection,
            pageinfo_type=PageInfo,
            edge_type=connection.Edge,
        )

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` so that it produces this field's connection.

        The returned callable is ``connection_resolver`` with
        ``parent_resolver``, ``self.type`` and ``self.model`` pre-bound, so it
        only still needs ``(root, args, context, info)``.
        """
        return partial(self.connection_resolver, parent_resolver, self.type, self.model)
|