Update Connections

This commit is contained in:
Syrus Akbary 2015-11-20 16:12:11 -08:00
parent b1e0c3b533
commit 732b1aec1b
5 changed files with 31 additions and 24 deletions

View File

@ -8,9 +8,7 @@ from .utils import get_type_for_model, lazy_map
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
pass
def wrap_resolved(self, value, instance, args, info):
return lazy_map(value, self.type)
class LazyListField(Field): class LazyListField(Field):

View File

@ -2,10 +2,10 @@ import six
from ...core.types import BaseObjectType, ObjectTypeMeta from ...core.types import BaseObjectType, ObjectTypeMeta
from ...relay.fields import GlobalIDField from ...relay.fields import GlobalIDField
from ...relay.types import BaseNode from ...relay.types import BaseNode, Connection
from .converter import convert_django_field from .converter import convert_django_field
from .options import DjangoOptions from .options import DjangoOptions
from .utils import get_reverse_fields from .utils import get_reverse_fields, lazy_map
class DjangoObjectTypeMeta(ObjectTypeMeta): class DjangoObjectTypeMeta(ObjectTypeMeta):
@ -71,6 +71,13 @@ class DjangoInterface(six.with_metaclass(
pass pass
class DjangoConnection(Connection):
@classmethod
def from_list(cls, iterable, *args, **kwargs):
iterable = lazy_map(iterable, cls.edge_type.node_type)
return super(DjangoConnection, cls).from_list(iterable, *args, **kwargs)
class DjangoNode(BaseNode, DjangoInterface): class DjangoNode(BaseNode, DjangoInterface):
id = GlobalIDField() id = GlobalIDField()
@ -81,3 +88,5 @@ class DjangoNode(BaseNode, DjangoInterface):
return cls(instance) return cls(instance)
except cls._meta.model.DoesNotExist: except cls._meta.model.DoesNotExist:
return None return None
connection_type = DjangoConnection

View File

@ -219,6 +219,10 @@ class BaseObjectType(BaseType):
return OrderedDict(fields) return OrderedDict(fields)
@classmethod
def wrap(cls, instance, args, info):
return cls(_root=instance)
class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass pass

View File

@ -1,6 +1,3 @@
from collections import Iterable
from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id
from ..core.fields import Field from ..core.fields import Field
@ -30,24 +27,11 @@ class ConnectionField(Field):
return value return value
def resolver(self, instance, args, info): def resolver(self, instance, args, info):
from graphene.relay.types import PageInfo
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
connection_type = self.get_type(schema)
resolved = super(ConnectionField, self).resolver(instance, args, info) resolved = super(ConnectionField, self).resolver(instance, args, info)
if resolved: if not isinstance(resolved, connection_type):
resolved = self.wrap_resolved(resolved, instance, args, info) return connection_type.from_list(resolved, args, info)
assert isinstance(
resolved, Iterable), 'Resolved value from the connection field have to be iterable'
type = schema.T(self.type)
node = schema.objecttype(type)
connection_type = self.get_connection_type(node)
edge_type = self.get_edge_type(node)
connection = connection_from_list(
resolved, args, connection_type=connection_type,
edge_type=edge_type, pageinfo_type=PageInfo)
connection.set_connection_data(resolved)
return connection
def get_connection_type(self, node): def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type() connection_type = self.connection_type or node.get_connection_type()

View File

@ -1,6 +1,8 @@
import inspect import inspect
import warnings import warnings
from collections import Iterable
from functools import wraps from functools import wraps
from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import to_global_id from graphql_relay.node.node import to_global_id
from ..core.types import (Boolean, Field, InputObjectType, Interface, List, from ..core.types import (Boolean, Field, InputObjectType, Interface, List,
@ -63,6 +65,16 @@ class Connection(ObjectType):
(cls,), (cls,),
{'edge_type': edge_type, 'edges': edges}) {'edge_type': edge_type, 'edges': edges})
@classmethod
def from_list(cls, iterable, args, info):
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
iterable, args, connection_type=cls,
edge_type=cls.edge_type, pageinfo_type=PageInfo)
connection.set_connection_data(iterable)
return connection
def set_connection_data(self, data): def set_connection_data(self, data):
self._connection_data = data self._connection_data = data