diff --git a/graphene/contrib/django/tests/test_optimizequery.py b/graphene/contrib/django/tests/test_optimizequery.py index 0a7c4028..dfde81f1 100644 --- a/graphene/contrib/django/tests/test_optimizequery.py +++ b/graphene/contrib/django/tests/test_optimizequery.py @@ -1,9 +1,13 @@ +from functools import wraps from graphql.core.utils.get_field_def import get_field_def import pytest import graphene from graphene.contrib.django import DjangoObjectType +from graphql.core.type.definition import GraphQLList, GraphQLNonNull + + from ..tests.models import Reporter from ..debug.plugin import DjangoDebugPlugin @@ -11,6 +15,27 @@ from ..debug.plugin import DjangoDebugPlugin pytestmark = pytest.mark.django_db +def get_fields(info): + field_asts = info.field_asts[0].selection_set.selections + only_args = [] + _type = info.return_type + if isinstance(_type, (GraphQLList, GraphQLNonNull)): + _type = _type.of_type + + for field in field_asts: + field_def = get_field_def(info.schema, _type, field) + f = field_def.resolver + fetch_field = getattr(f, 'django_fetch_field', None) + if fetch_field: + only_args.append(fetch_field) + return only_args + +def fetch_only_required(f): + @wraps(f) + def wrapper(*args): + info = args[-1] + return f(*args).only(*get_fields(info)) + return wrapper def test_should_query_well(): r1 = Reporter(last_name='ABA') @@ -27,24 +52,12 @@ def test_should_query_well(): reporter = graphene.Field(ReporterType) all_reporters = ReporterType.List() + @fetch_only_required def resolve_all_reporters(self, args, info): - queryset = Reporter.objects.all() - # from graphql.core.execution.base import collect_fields - # print info.field_asts[0], info.parent_type, info.return_type.of_type - # field_asts = collect_fields(info.context, info.parent_type, info.field_asts[0], {}, set()) - # field_asts = info.field_asts - field_asts = info.field_asts[0].selection_set.selections - only_args = [] - for field in field_asts: - field_def = get_field_def(info.schema, info.return_type.of_type, field) - f = field_def.resolver - fetch_field = getattr(f, 'django_fetch_field') - only_args.append(fetch_field) - queryset = queryset.only(*only_args) - return queryset + return Reporter.objects.all() - def resolve_reporter(self, *args, **kwargs): - return Reporter.objects.first() + def resolve_reporter(self, args, info): + return Reporter.objects.only(*get_fields(info)).first() query = ''' query ReporterQuery { @@ -52,6 +65,9 @@ def test_should_query_well(): lastName email } + reporter { + email + } __debug { sql { rawSql @@ -67,9 +83,14 @@ def test_should_query_well(): 'lastName': 'Griffin', 'email': '', }], + 'reporter': { + 'email': '' + }, '__debug': { 'sql': [{ 'rawSql': str(Reporter.objects.all().only('last_name', 'email').query) + }, { + 'rawSql': str(Reporter.objects.only('email').order_by('pk')[:1].query) }] } }