diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index c64b9547..0b719586 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -8,9 +8,7 @@ from .utils import get_type_for_model, lazy_map class DjangoConnectionField(ConnectionField): - - def wrap_resolved(self, value, instance, args, info): - return lazy_map(value, self.type) + pass class LazyListField(Field): diff --git a/graphene/contrib/django/types.py b/graphene/contrib/django/types.py index f17893f0..8ef2a6a3 100644 --- a/graphene/contrib/django/types.py +++ b/graphene/contrib/django/types.py @@ -2,10 +2,10 @@ import six from ...core.types import BaseObjectType, ObjectTypeMeta from ...relay.fields import GlobalIDField -from ...relay.types import BaseNode +from ...relay.types import BaseNode, Connection from .converter import convert_django_field from .options import DjangoOptions -from .utils import get_reverse_fields +from .utils import get_reverse_fields, lazy_map class DjangoObjectTypeMeta(ObjectTypeMeta): @@ -71,6 +71,13 @@ class DjangoInterface(six.with_metaclass( pass +class DjangoConnection(Connection): + @classmethod + def from_list(cls, iterable, *args, **kwargs): + iterable = lazy_map(iterable, cls.edge_type.node_type) + return super(DjangoConnection, cls).from_list(iterable, *args, **kwargs) + + class DjangoNode(BaseNode, DjangoInterface): id = GlobalIDField() @@ -81,3 +88,5 @@ class DjangoNode(BaseNode, DjangoInterface): return cls(instance) except cls._meta.model.DoesNotExist: return None + + connection_type = DjangoConnection diff --git a/graphene/core/types/objecttype.py b/graphene/core/types/objecttype.py index e18518e9..14283af5 100644 --- a/graphene/core/types/objecttype.py +++ b/graphene/core/types/objecttype.py @@ -219,6 +219,10 @@ class BaseObjectType(BaseType): return OrderedDict(fields) + @classmethod + def wrap(cls, instance, args, info): + return cls(_root=instance) + class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): pass diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 58e0e8bd..a54140ee 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -1,6 +1,3 @@ -from collections import Iterable - -from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.node.node import from_global_id from ..core.fields import Field @@ -30,24 +27,11 @@ class ConnectionField(Field): return value def resolver(self, instance, args, info): - from graphene.relay.types import PageInfo schema = info.schema.graphene_schema - + connection_type = self.get_type(schema) resolved = super(ConnectionField, self).resolver(instance, args, info) - if resolved: - resolved = self.wrap_resolved(resolved, instance, args, info) - assert isinstance( - resolved, Iterable), 'Resolved value from the connection field have to be iterable' - type = schema.T(self.type) - node = schema.objecttype(type) - connection_type = self.get_connection_type(node) - edge_type = self.get_edge_type(node) - - connection = connection_from_list( - resolved, args, connection_type=connection_type, - edge_type=edge_type, pageinfo_type=PageInfo) - connection.set_connection_data(resolved) - return connection + if not isinstance(resolved, connection_type): + return connection_type.from_list(resolved, args, info) def get_connection_type(self, node): connection_type = self.connection_type or node.get_connection_type() diff --git a/graphene/relay/types.py b/graphene/relay/types.py index a803d2f3..3969a09e 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -1,6 +1,8 @@ import inspect import warnings +from collections import Iterable from functools import wraps +from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.node.node import to_global_id from ..core.types import (Boolean, Field, InputObjectType, Interface, List, @@ -63,6 +65,16 @@ class Connection(ObjectType): (cls,), {'edge_type': edge_type, 'edges': edges}) + @classmethod + def from_list(cls, iterable, args, info): + assert isinstance( + iterable, Iterable), 'Resolved value from the connection field have to be iterable' + connection = connection_from_list( + iterable, args, connection_type=cls, + edge_type=cls.edge_type, pageinfo_type=PageInfo) + connection.set_connection_data(iterable) + return connection + def set_connection_data(self, data): self._connection_data = data