Improved querying and slicing in DjangoConnectionFields and inherited. Fixed #108

This commit is contained in:
Syrus Akbary 2016-02-05 17:29:43 -08:00
parent c5b15cec2f
commit 314703d7b5
6 changed files with 197 additions and 38 deletions

View File

@ -1,33 +1,39 @@
import pytest import pytest
import graphene import graphene
from graphene.contrib.django import DjangoObjectType from graphene.contrib.django import DjangoNode, DjangoConnectionField
from graphene.contrib.django.filter import DjangoFilterConnectionField
from ...tests.models import Reporter from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin from ..plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character # from examples.starwars_django.models import Character
from django.db.models import Count
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def test_should_query_well(): def count(qs):
query = qs.query
query.add_annotation(Count('*'), alias='__count', is_summary=True)
query.select = []
query.default_cols = False
return query
def test_should_query_field():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name='ABA')
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name='Griffin')
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoNode):
class Meta: class Meta:
model = Reporter model = Reporter
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first() return Reporter.objects.first()
@ -37,9 +43,6 @@ def test_should_query_well():
reporter { reporter {
lastName lastName
} }
allReporters {
lastName
}
__debug { __debug {
sql { sql {
rawSql rawSql
@ -51,6 +54,48 @@ def test_should_query_well():
'reporter': { 'reporter': {
'lastName': 'ABA', 'lastName': 'ABA',
}, },
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_list():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters = ReporterType.List()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
query = '''
query ReporterQuery {
allReporters {
lastName
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReporters': [{ 'allReporters': [{
'lastName': 'ABA', 'lastName': 'ABA',
}, { }, {
@ -58,8 +103,6 @@ def test_should_query_well():
}], }],
'__debug': { '__debug': {
'sql': [{ 'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}, {
'rawSql': str(Reporter.objects.all().query) 'rawSql': str(Reporter.objects.all().query)
}] }]
} }
@ -68,3 +111,122 @@ def test_should_query_well():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_connection():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters_connection = DjangoConnectionField(ReporterType)
def resolve_all_reporters_connection(self, *args, **kwargs):
return Reporter.objects.all()
query = '''
query ReporterQuery {
allReportersConnection(first:1) {
edges {
node {
lastName
}
}
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReportersConnection': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
'__debug': {
'sql': [{
'rawSql': str(count(Reporter.objects.all()))
}, {
'rawSql': str(Reporter.objects.all()[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_connectionfilter():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoNode):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
all_reporters_connection_filter = DjangoFilterConnectionField(ReporterType)
def resolve_all_reporters_connection_filter(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_all_reporters_connection(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all()
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
allReportersConnectionFilter(first:1) {
edges {
node {
lastName
}
}
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReportersConnectionFilter': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
'__debug': {
'sql': [{
'rawSql': str(count(Reporter.objects.all()))
}, {
'rawSql': str(Reporter.objects.all()[:1].query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -27,7 +27,7 @@ class DjangoConnectionField(ConnectionField):
return resolved_qs return resolved_qs
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, info):
if not resolved: if resolved is None:
resolved = self.get_manager() resolved = self.get_manager()
resolved_qs = maybe_queryset(resolved) resolved_qs = maybe_queryset(resolved)
qs = self.get_queryset(resolved_qs, args, info) qs = self.get_queryset(resolved_qs, args, info)

View File

@ -52,15 +52,14 @@ class InstanceObjectType(ObjectType):
abstract = True abstract = True
def __init__(self, _root=None): def __init__(self, _root=None):
if _root:
assert isinstance(_root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
_root.__class__.__name__,
self._meta.model.__name__
))
super(InstanceObjectType, self).__init__(_root=_root) super(InstanceObjectType, self).__init__(_root=_root)
assert not self._root or isinstance(self._root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
self._root.__class__.__name__,
self._meta.model.__name__
))
@property @property
def instance(self): def instance(self):
@ -70,9 +69,6 @@ class InstanceObjectType(ObjectType):
def instance(self, value): def instance(self, value):
self._root = value self._root = value
def __getattr__(self, attr):
return getattr(self._root, attr)
class DjangoObjectType(six.with_metaclass( class DjangoObjectType(six.with_metaclass(
DjangoObjectTypeMeta, InstanceObjectType)): DjangoObjectTypeMeta, InstanceObjectType)):

View File

@ -65,15 +65,14 @@ class InstanceObjectType(ObjectType):
abstract = True abstract = True
def __init__(self, _root=None): def __init__(self, _root=None):
if _root:
assert isinstance(_root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
_root.__class__.__name__,
self._meta.model.__name__
))
super(InstanceObjectType, self).__init__(_root=_root) super(InstanceObjectType, self).__init__(_root=_root)
assert not self._root or isinstance(self._root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
self._root.__class__.__name__,
self._meta.model.__name__
))
@property @property
def instance(self): def instance(self):
@ -83,9 +82,6 @@ class InstanceObjectType(ObjectType):
def instance(self, value): def instance(self, value):
self._root = value self._root = value
def __getattr__(self, attr):
return getattr(self._root, attr)
class SQLAlchemyObjectType(six.with_metaclass( class SQLAlchemyObjectType(six.with_metaclass(
SQLAlchemyObjectTypeMeta, InstanceObjectType)): SQLAlchemyObjectTypeMeta, InstanceObjectType)):

View File

@ -47,7 +47,8 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta, FieldsClassType)):
abstract = True abstract = True
def __getattr__(self, name): def __getattr__(self, name):
return self._root and getattr(self._root, name) if name != '_root' and self._root:
return getattr(self._root, name)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
signals.pre_init.send(self.__class__, args=args, kwargs=kwargs) signals.pre_init.send(self.__class__, args=args, kwargs=kwargs)

View File

@ -57,7 +57,8 @@ class Field(NamedType, OrderedType):
@property @property
def resolver(self): def resolver(self):
return self.resolver_fn or self.get_resolver_fn() resolver = self.get_resolver_fn()
return resolver
@property @property
def default(self): def default(self):
@ -70,6 +71,9 @@ class Field(NamedType, OrderedType):
self._default = value self._default = value
def get_resolver_fn(self): def get_resolver_fn(self):
if self.resolver_fn:
return self.resolver_fn
resolve_fn_name = 'resolve_%s' % self.attname resolve_fn_name = 'resolve_%s' % self.attname
if hasattr(self.object_type, resolve_fn_name): if hasattr(self.object_type, resolve_fn_name):
return getattr(self.object_type, resolve_fn_name) return getattr(self.object_type, resolve_fn_name)