Test manual optimization

This commit is contained in:
Jacob Foster 2017-07-19 22:47:27 -05:00
parent 34e6a90df8
commit 582958c642

View File

@ -3,22 +3,15 @@ from django.db import connection
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
import graphene
import pytest
from .. import registry
from ..fields import DjangoConnectionField
from ..fields import DjangoConnectionField, DjangoListField
from ..optimization import optimize_queryset
from ..types import DjangoObjectType
from .models import (
Article as ArticleModel,
Reporter as ReporterModel,
Pet as PetModel
Reporter as ReporterModel
)
pytestmark = pytest.mark.django_db
registry.reset_global_registry()
class Article(DjangoObjectType):
class Meta:
@ -27,24 +20,36 @@ class Article(DjangoObjectType):
class Reporter(DjangoObjectType):
favorite_pet = graphene.Field(lambda: Reporter)
class Meta:
model = ReporterModel
#interfaces = (graphene.relay.Node,)
optimizations = {
'favorite_pet': {
'prefetch': ['pets']
}
}
class Pet(DjangoObjectType):
class Meta:
model = PetModel
def resolve_favorite_pet(self, *args):
for pet in self.pets.all():
if pet.last_name == 'Kent':
return pet
class RootQuery(graphene.ObjectType):
article = graphene.Field(Article, id=graphene.ID())
articles = DjangoConnectionField(Article)
reporters = DjangoListField(Reporter)
def resolve_article(self, args, context, info):
qs = ArticleModel.objects
qs = optimize_queryset(ArticleModel, qs, info.field_asts[0])
return qs.get(**args)
def resolve_reporters(self, args, context, info):
return ReporterModel.objects
schema = graphene.Schema(query=RootQuery)
@ -124,3 +129,25 @@ class TestOptimization(TestCase):
assert len(returned_articles) == 2
self.assertEqual(len(query_context.captured_queries), 4)
def test_manual(self):
query = """query {
reporters {
email
favoritePet {
email
}
}
}"""
with CaptureQueriesContext(connection) as query_context:
results = schema.execute(query)
returned_reporters = results.data['reporters']
assert len(returned_reporters) == 2
returned_editor = [reporter for reporter in returned_reporters
if reporter['email'] == self.editor.email][0]
assert returned_editor['favoritePet']['email'] == self.reporter.email
self.assertEqual(len(query_context.captured_queries), 2)