diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index d87734bb..a6d54f62 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -7,23 +7,32 @@ from graphene.core.fields import Field, LazyField from graphene.utils import cached_property from graphene.env import get_global_schema +from django.db.models.query import QuerySet def get_type_for_model(schema, model): schema = schema or get_global_schema() types = schema.types.values() for _type in types: - type_model = getattr(_type._meta, 'model', None) + type_model = hasattr(_type,'_meta') and getattr(_type._meta, 'model', None) if model == type_model: return _type +class DjangoConnectionField(relay.ConnectionField): + def wrap_resolved(self, value, instance, args, info): + if isinstance(value, QuerySet): + cls = instance.__class__ + value = [cls(s) for s in value] + return value + + class ConnectionOrListField(LazyField): def get_field(self): schema = self.schema model_field = self.field_type field_object_type = model_field.get_object_type() if field_object_type and issubclass(field_object_type, schema.Node): - field = relay.ConnectionField(model_field) + field = DjangoConnectionField(model_field) else: field = ListField(model_field) field.contribute_to_class(self.object_type, self.field_name) diff --git a/graphene/relay/connections.py b/graphene/relay/connections.py index ba2053ec..5cda4705 100644 --- a/graphene/relay/connections.py +++ b/graphene/relay/connections.py @@ -12,7 +12,7 @@ from graphene.core.fields import NativeField @signals.class_prepared.connect def object_type_created(object_type): schema = object_type._meta.schema - if issubclass(object_type, schema.Node) and object_type != schema.Node: + if hasattr(schema, 'Node') and issubclass(object_type, schema.Node) and object_type != schema.Node: if object_type._meta.proxy: return type_name = object_type._meta.type_name diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 2a35763a..a888919d 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -15,10 +15,14 @@ class ConnectionField(Field): super(ConnectionField, self).__init__(field_type, resolve=resolve, args=connectionArgs, description=description) + def wrap_resolved(self, value, instance, args, info): + return value + def resolve(self, instance, args, info): resolved = super(ConnectionField, self).resolve(instance, args, info) if resolved: assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable' + resolved = self.wrap_resolved(resolved, instance, args, info) return connectionFromArray(resolved, args) @cached_property