diff --git a/graphene-django/graphene_django/filter/fields.py b/graphene-django/graphene_django/filter/fields.py index c984f913..f4f84e29 100644 --- a/graphene-django/graphene_django/filter/fields.py +++ b/graphene-django/graphene_django/filter/fields.py @@ -1,7 +1,7 @@ +from functools import partial from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class -# from graphene.types.argument import to_arguments class DjangoFilterConnectionField(DjangoConnectionField): @@ -20,19 +20,20 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type) kwargs.setdefault('args', {}) kwargs['args'].update(self.filtering_args) - # kwargs['args'].update(to_arguments(self.filtering_args)) super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) - def get_queryset(self, qs, args, info): - filterset_class = self.filterset_class - filter_kwargs = self.get_filter_kwargs(args) - order = self.get_order(args) + @staticmethod + def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, + root, args, context, info): + filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} + order = args.get('order_by', None) + qs = default_manager.get_queryset() if order: qs = qs.order_by(order) - return filterset_class(data=filter_kwargs, queryset=qs) + qs = filterset_class(data=filter_kwargs, queryset=qs) - def get_filter_kwargs(self, args): - return {k: v for k, v in args.items() if k in self.filtering_args} + return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) - def get_order(self, args): - return args.get('order_by', None) + def get_resolver(self, parent_resolver): + return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), + self.filterset_class, self.filtering_args) diff --git a/graphene-django/graphene_django/filter/tests/test_fields.py b/graphene-django/graphene_django/filter/tests/test_fields.py index 1d2fe078..4735ee0a 100644 --- a/graphene-django/graphene_django/filter/tests/test_fields.py +++ b/graphene-django/graphene_django/filter/tests/test_fields.py @@ -302,3 +302,38 @@ def test_global_id_multiple_field_explicit_reverse(): multiple_filter = filterset_class.base_filters['articles'] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField + + +def test_filter_filterset_related_results(): + class ReporterFilterNode(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + filter_fields = { + 'first_name': ['icontains'] + } + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + + r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') + r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') + r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') + + query = ''' + query { + allReporters(firstName_Icontains: "test") { + edges { + node { + id + } + } + } + } + ''' + schema = Schema(query=Query) + result = schema.execute(query) + assert not result.errors + # We should only get two reporters + assert len(result.data['allReporters']['edges']) == 2