diff --git a/graphene/contrib/django/filter/tests/test_fields.py b/graphene/contrib/django/filter/tests/test_fields.py index 45c1f0d0..b2591d1e 100644 --- a/graphene/contrib/django/filter/tests/test_fields.py +++ b/graphene/contrib/django/filter/tests/test_fields.py @@ -1,4 +1,7 @@ +from datetime import datetime + import pytest +from graphql.core.execution.base import ResolveInfo, ExecutionContext from graphene import ObjectType, Schema from graphene.contrib.django import DjangoNode @@ -7,6 +10,7 @@ from graphene.contrib.django.forms import (GlobalIDFormField, from graphene.contrib.django.tests.models import Article, Pet, Reporter from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene.relay import NodeField +from graphene.utils import ProxySnakeDict pytestmark = [] if DJANGO_FILTER_INSTALLED: @@ -160,6 +164,57 @@ def test_filter_filterset_information_on_meta_related(): assert_orderable(articles_field) +def test_filter_filterset_related_results(): + class ReporterFilterNode(DjangoNode): + + class Meta: + model = Reporter + filter_fields = ['first_name', 'articles'] + filter_order_by = True + + class ArticleFilterNode(DjangoNode): + + class Meta: + model = Article + filter_fields = ['headline', 'reporter'] + filter_order_by = True + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = NodeField(ReporterFilterNode) + article = NodeField(ArticleFilterNode) + + r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') + r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') + a1 = Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1) + a2 = Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2) + + query = ''' + query { + allReporters { + edges { + node { + articles { + edges { + node { + headline + } + } + } + } + } + } + } + ''' + schema = Schema(query=Query) + result = schema.execute(query) + assert not result.errors + # We should only get back a single article for each reporter + assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 + assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 + + def test_global_id_field_implicit(): field = DjangoFilterConnectionField(ArticleNode, fields=['id']) filterset_class = field.resolver_fn.get_filterset_class()