From 314703d7b5380c80fe2ee1e2ad4a3eb3d9ed7fc4 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Fri, 5 Feb 2016 17:29:43 -0800 Subject: [PATCH] Improved querying and slicing in DjangoConnectionFields and inherited. Fixed #108 --- .../contrib/django/debug/tests/test_query.py | 188 ++++++++++++++++-- graphene/contrib/django/fields.py | 2 +- graphene/contrib/django/types.py | 18 +- graphene/contrib/sqlalchemy/types.py | 18 +- graphene/core/classtypes/objecttype.py | 3 +- graphene/core/types/field.py | 6 +- 6 files changed, 197 insertions(+), 38 deletions(-) diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py index 4df26e4f..b853a791 100644 --- a/graphene/contrib/django/debug/tests/test_query.py +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -1,33 +1,39 @@ import pytest import graphene -from graphene.contrib.django import DjangoObjectType +from graphene.contrib.django import DjangoNode, DjangoConnectionField +from graphene.contrib.django.filter import DjangoFilterConnectionField from ...tests.models import Reporter from ..plugin import DjangoDebugPlugin # from examples.starwars_django.models import Character +from django.db.models import Count + pytestmark = pytest.mark.django_db -def test_should_query_well(): +def count(qs): + query = qs.query + query.add_annotation(Count('*'), alias='__count', is_summary=True) + query.select = [] + query.default_cols = False + return query + + +def test_should_query_field(): r1 = Reporter(last_name='ABA') r1.save() r2 = Reporter(last_name='Griffin') r2.save() - class ReporterType(DjangoObjectType): - + class ReporterType(DjangoNode): class Meta: model = Reporter class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - all_reporters = ReporterType.List() - - def resolve_all_reporters(self, *args, **kwargs): - return Reporter.objects.all() def resolve_reporter(self, *args, **kwargs): return Reporter.objects.first() @@ -37,9 +43,6 @@ def test_should_query_well(): reporter { lastName } - allReporters { - lastName - } __debug { sql { rawSql @@ -51,6 +54,48 @@ def test_should_query_well(): 'reporter': { 'lastName': 'ABA', }, + '__debug': { + 'sql': [{ + 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_query_list(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoNode): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + all_reporters = ReporterType.List() + + def resolve_all_reporters(self, *args, **kwargs): + return Reporter.objects.all() + + query = ''' + query ReporterQuery { + allReporters { + lastName + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { 'allReporters': [{ 'lastName': 'ABA', }, { @@ -58,8 +103,6 @@ def test_should_query_well(): }], '__debug': { 'sql': [{ - 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) - }, { 'rawSql': str(Reporter.objects.all().query) }] } @@ -68,3 +111,122 @@ def test_should_query_well(): result = schema.execute(query) assert not result.errors assert result.data == expected + + +def test_should_query_connection(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoNode): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + all_reporters_connection = DjangoConnectionField(ReporterType) + + def resolve_all_reporters_connection(self, *args, **kwargs): + return Reporter.objects.all() + + query = ''' + query ReporterQuery { + allReportersConnection(first:1) { + edges { + node { + lastName + } + } + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { + 'allReportersConnection': { + 'edges': [{ + 'node': { + 'lastName': 'ABA', + } + }] + }, + '__debug': { + 'sql': [{ + 'rawSql': str(count(Reporter.objects.all())) + }, { + 'rawSql': str(Reporter.objects.all()[:1].query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_query_connectionfilter(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoNode): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + all_reporters_connection_filter = DjangoFilterConnectionField(ReporterType) + + def resolve_all_reporters_connection_filter(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_all_reporters_connection(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_all_reporters(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_reporter(self, *args, **kwargs): + return Reporter.objects.first() + + query = ''' + query ReporterQuery { + allReportersConnectionFilter(first:1) { + edges { + node { + lastName + } + } + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { + 'allReportersConnectionFilter': { + 'edges': [{ + 'node': { + 'lastName': 'ABA', + } + }] + }, + '__debug': { + 'sql': [{ + 'rawSql': str(count(Reporter.objects.all())) + }, { + 'rawSql': str(Reporter.objects.all()[:1].query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 4b675297..d7321e21 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -27,7 +27,7 @@ class DjangoConnectionField(ConnectionField): return resolved_qs def from_list(self, connection_type, resolved, args, info): - if not resolved: + if resolved is None: resolved = self.get_manager() resolved_qs = maybe_queryset(resolved) qs = self.get_queryset(resolved_qs, args, info) diff --git a/graphene/contrib/django/types.py b/graphene/contrib/django/types.py index 53c61f4b..f8f9cb2b 100644 --- a/graphene/contrib/django/types.py +++ b/graphene/contrib/django/types.py @@ -52,15 +52,14 @@ class InstanceObjectType(ObjectType): abstract = True def __init__(self, _root=None): - if _root: - assert isinstance(_root, self._meta.model), ( - '{} received a non-compatible instance ({}) ' - 'when expecting {}'.format( - self.__class__.__name__, - _root.__class__.__name__, - self._meta.model.__name__ - )) super(InstanceObjectType, self).__init__(_root=_root) + assert not self._root or isinstance(self._root, self._meta.model), ( + '{} received a non-compatible instance ({}) ' + 'when expecting {}'.format( + self.__class__.__name__, + self._root.__class__.__name__, + self._meta.model.__name__ + )) @property def instance(self): @@ -70,9 +69,6 @@ class InstanceObjectType(ObjectType): def instance(self, value): self._root = value - def __getattr__(self, attr): - return getattr(self._root, attr) - class DjangoObjectType(six.with_metaclass( DjangoObjectTypeMeta, InstanceObjectType)): diff --git a/graphene/contrib/sqlalchemy/types.py b/graphene/contrib/sqlalchemy/types.py index 07308f39..8f70d245 100644 --- a/graphene/contrib/sqlalchemy/types.py +++ b/graphene/contrib/sqlalchemy/types.py @@ -65,15 +65,14 @@ class InstanceObjectType(ObjectType): abstract = True def __init__(self, _root=None): - if _root: - assert isinstance(_root, self._meta.model), ( - '{} received a non-compatible instance ({}) ' - 'when expecting {}'.format( - self.__class__.__name__, - _root.__class__.__name__, - self._meta.model.__name__ - )) super(InstanceObjectType, self).__init__(_root=_root) + assert not self._root or isinstance(self._root, self._meta.model), ( + '{} received a non-compatible instance ({}) ' + 'when expecting {}'.format( + self.__class__.__name__, + self._root.__class__.__name__, + self._meta.model.__name__ + )) @property def instance(self): @@ -83,9 +82,6 @@ class InstanceObjectType(ObjectType): def instance(self, value): self._root = value - def __getattr__(self, attr): - return getattr(self._root, attr) - class SQLAlchemyObjectType(six.with_metaclass( SQLAlchemyObjectTypeMeta, InstanceObjectType)): diff --git a/graphene/core/classtypes/objecttype.py b/graphene/core/classtypes/objecttype.py index d542f28c..3f94bddf 100644 --- a/graphene/core/classtypes/objecttype.py +++ b/graphene/core/classtypes/objecttype.py @@ -47,7 +47,8 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta, FieldsClassType)): abstract = True def __getattr__(self, name): - return self._root and getattr(self._root, name) + if name != '_root' and self._root: + return getattr(self._root, name) def __init__(self, *args, **kwargs): signals.pre_init.send(self.__class__, args=args, kwargs=kwargs) diff --git a/graphene/core/types/field.py b/graphene/core/types/field.py index 3109ffe0..f825f861 100644 --- a/graphene/core/types/field.py +++ b/graphene/core/types/field.py @@ -57,7 +57,8 @@ class Field(NamedType, OrderedType): @property def resolver(self): - return self.resolver_fn or self.get_resolver_fn() + resolver = self.get_resolver_fn() + return resolver @property def default(self): @@ -70,6 +71,9 @@ class Field(NamedType, OrderedType): self._default = value def get_resolver_fn(self): + if self.resolver_fn: + return self.resolver_fn + resolve_fn_name = 'resolve_%s' % self.attname if hasattr(self.object_type, resolve_fn_name): return getattr(self.object_type, resolve_fn_name)