Improved querying and slicing in DjangoConnectionFields and inherited. Fixed #108

This commit is contained in:
Syrus Akbary 2016-02-05 17:29:43 -08:00
parent c5b15cec2f
commit 314703d7b5
6 changed files with 197 additions and 38 deletions

View File

@ -1,33 +1,39 @@
import pytest
import graphene
from graphene.contrib.django import DjangoObjectType
from graphene.contrib.django import DjangoNode, DjangoConnectionField
from graphene.contrib.django.filter import DjangoFilterConnectionField
from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character
from django.db.models import Count
pytestmark = pytest.mark.django_db
def test_should_query_well():
def count(qs):
query = qs.query
query.add_annotation(Count('*'), alias='__count', is_summary=True)
query.select = []
query.default_cols = False
return query
def test_should_query_field():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoObjectType):
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
@ -37,9 +43,6 @@ def test_should_query_well():
reporter {
lastName
}
allReporters {
lastName
}
__debug {
sql {
rawSql
@ -51,6 +54,48 @@ def test_should_query_well():
'reporter': {
'lastName': 'ABA',
},
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_list():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
query = '''
query ReporterQuery {
allReporters {
lastName
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReporters': [{
'lastName': 'ABA',
}, {
@ -58,8 +103,6 @@ def test_should_query_well():
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}, {
'rawSql': str(Reporter.objects.all().query)
}]
}
@ -68,3 +111,122 @@ def test_should_query_well():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_connection():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters_connection = DjangoConnectionField(ReporterType)
def resolve_all_reporters_connection(self, *args, **kwargs):
return Reporter.objects.all()
query = '''
query ReporterQuery {
allReportersConnection(first:1) {
edges {
node {
lastName
}
}
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReportersConnection': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
'__debug': {
'sql': [{
'rawSql': str(count(Reporter.objects.all()))
}, {
'rawSql': str(Reporter.objects.all()[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_connectionfilter():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters_connection_filter = DjangoFilterConnectionField(ReporterType)
def resolve_all_reporters_connection_filter(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_all_reporters_connection(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
allReportersConnectionFilter(first:1) {
edges {
node {
lastName
}
}
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReportersConnectionFilter': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
'__debug': {
'sql': [{
'rawSql': str(count(Reporter.objects.all()))
}, {
'rawSql': str(Reporter.objects.all()[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -27,7 +27,7 @@ class DjangoConnectionField(ConnectionField):
return resolved_qs
def from_list(self, connection_type, resolved, args, info):
if not resolved:
if resolved is None:
resolved = self.get_manager()
resolved_qs = maybe_queryset(resolved)
qs = self.get_queryset(resolved_qs, args, info)

View File

@ -52,15 +52,14 @@ class InstanceObjectType(ObjectType):
abstract = True
def __init__(self, _root=None):
if _root:
assert isinstance(_root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
_root.__class__.__name__,
self._meta.model.__name__
))
super(InstanceObjectType, self).__init__(_root=_root)
assert not self._root or isinstance(self._root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
self._root.__class__.__name__,
self._meta.model.__name__
))
@property
def instance(self):
@ -70,9 +69,6 @@ class InstanceObjectType(ObjectType):
def instance(self, value):
self._root = value
def __getattr__(self, attr):
return getattr(self._root, attr)
class DjangoObjectType(six.with_metaclass(
DjangoObjectTypeMeta, InstanceObjectType)):

View File

@ -65,15 +65,14 @@ class InstanceObjectType(ObjectType):
abstract = True
def __init__(self, _root=None):
if _root:
assert isinstance(_root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
_root.__class__.__name__,
self._meta.model.__name__
))
super(InstanceObjectType, self).__init__(_root=_root)
assert not self._root or isinstance(self._root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
self._root.__class__.__name__,
self._meta.model.__name__
))
@property
def instance(self):
@ -83,9 +82,6 @@ class InstanceObjectType(ObjectType):
def instance(self, value):
self._root = value
def __getattr__(self, attr):
return getattr(self._root, attr)
class SQLAlchemyObjectType(six.with_metaclass(
SQLAlchemyObjectTypeMeta, InstanceObjectType)):

View File

@ -47,7 +47,8 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta, FieldsClassType)):
abstract = True
def __getattr__(self, name):
return self._root and getattr(self._root, name)
if name != '_root' and self._root:
return getattr(self._root, name)
def __init__(self, *args, **kwargs):
signals.pre_init.send(self.__class__, args=args, kwargs=kwargs)

View File

@ -57,7 +57,8 @@ class Field(NamedType, OrderedType):
@property
def resolver(self):
return self.resolver_fn or self.get_resolver_fn()
resolver = self.get_resolver_fn()
return resolver
@property
def default(self):
@ -70,6 +71,9 @@ class Field(NamedType, OrderedType):
self._default = value
def get_resolver_fn(self):
if self.resolver_fn:
return self.resolver_fn
resolve_fn_name = 'resolve_%s' % self.attname
if hasattr(self.object_type, resolve_fn_name):
return getattr(self.object_type, resolve_fn_name)