diff --git a/graphene/contrib/django/filter/__init__.py b/graphene/contrib/django/filter/__init__.py index 9df733cd..21e65b56 100644 --- a/graphene/contrib/django/filter/__init__.py +++ b/graphene/contrib/django/filter/__init__.py @@ -1,6 +1,7 @@ from .fields import DjangoFilterConnectionField -from .filterset import GrapheneFilterSet, GlobalIDFilter +from .filterset import GrapheneFilterSet, GlobalIDFilter, GlobalIDMultipleChoiceFilter from .resolvers import FilterConnectionResolver __all__ = ['DjangoFilterConnectionField', 'GrapheneFilterSet', - 'GlobalIDFilter', 'FilterConnectionResolver'] + 'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter', + 'FilterConnectionResolver'] diff --git a/graphene/contrib/django/filter/filterset.py b/graphene/contrib/django/filter/filterset.py index 430c0eb4..50fe288c 100644 --- a/graphene/contrib/django/filter/filterset.py +++ b/graphene/contrib/django/filter/filterset.py @@ -1,11 +1,12 @@ import six from django.conf import settings from django.db import models -from django_filters import Filter +from django.utils.text import capfirst +from django_filters import Filter, MultipleChoiceFilter from django_filters.filterset import FilterSetMetaclass, FilterSet from graphql_relay.node.node import from_global_id -from graphene.contrib.django.forms import GlobalIDFormField +from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField class GlobalIDFilter(Filter): @@ -16,6 +17,14 @@ class GlobalIDFilter(Filter): return super(GlobalIDFilter, self).filter(qs, gid.id) +class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): + field_class = GlobalIDMultipleChoiceField + + def filter(self, qs, value): + gids = [from_global_id(v).id for v in value] + return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) + + ORDER_BY_FIELD = getattr(settings, 'GRAPHENE_ORDER_BY_FIELD', 'order') @@ -28,8 +37,10 @@ GRAPHENE_FILTER_SET_OVERRIDES = { }, models.ForeignKey: { 'filter_class': GlobalIDFilter, + }, + models.ManyToManyField: { + 'filter_class': GlobalIDMultipleChoiceFilter, } - # TODO: Support ManyToManyFields. GlobalIDFilterList? } @@ -42,14 +53,30 @@ class GrapheneFilterSetMetaclass(FilterSetMetaclass): return new_class -class GrapheneFilterSet(six.with_metaclass(GrapheneFilterSetMetaclass, FilterSet)): +class GrapheneFilterSetMixin(object): + order_by_field = ORDER_BY_FIELD + + @classmethod + def filter_for_reverse_field(cls, f, name): + rel = f.field.rel + default = { + 'name': name, + 'label': capfirst(rel.related_name) + } + if rel.multiple: + return GlobalIDMultipleChoiceFilter(**default) + else: + return GlobalIDFilter(**default) + + +class GrapheneFilterSet(six.with_metaclass(GrapheneFilterSetMetaclass, GrapheneFilterSetMixin, FilterSet)): """ Base class for FilterSets used by Graphene You shouldn't usually need to use this class. The DjangoFilterConnectionField will wrap FilterSets with this class as necessary """ - order_by_field = ORDER_BY_FIELD + pass def setup_filterset(filterset_class): @@ -57,10 +84,8 @@ def setup_filterset(filterset_class): """ return type( 'Graphene{}'.format(filterset_class.__name__), - (six.with_metaclass(GrapheneFilterSetMetaclass, filterset_class),), - { - 'order_by_field': ORDER_BY_FIELD - }, + (six.with_metaclass(GrapheneFilterSetMetaclass, GrapheneFilterSetMixin, filterset_class),), + {}, ) diff --git a/graphene/contrib/django/form_converter.py b/graphene/contrib/django/form_converter.py index 7229b968..f5acf202 100644 --- a/graphene/contrib/django/form_converter.py +++ b/graphene/contrib/django/form_converter.py @@ -3,6 +3,8 @@ from django.forms.fields import BaseTemporalField from singledispatch import singledispatch from graphene import String, Int, Boolean, Float, ID +from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField +from graphene.core.types.definitions import List try: UUIDField = forms.UUIDField @@ -57,13 +59,12 @@ def convert_form_field_to_float(field): @convert_form_field.register(forms.ModelMultipleChoiceField) +@convert_form_field.register(GlobalIDMultipleChoiceField) def convert_form_field_to_list_or_connection(field): - # TODO: Consider how filtering on a many-to-many should work - from .fields import DjangoModelField, ConnectionOrListField - model_field = DjangoModelField(field.queryset.model) - return ConnectionOrListField(model_field) + return List(ID()) @convert_form_field.register(forms.ModelChoiceField) +@convert_form_field.register(GlobalIDFormField) def convert_form_field_to_djangomodel(field): return ID() diff --git a/graphene/contrib/django/forms.py b/graphene/contrib/django/forms.py index b8fc9106..f971897b 100644 --- a/graphene/contrib/django/forms.py +++ b/graphene/contrib/django/forms.py @@ -1,7 +1,7 @@ import binascii from django.core.exceptions import ValidationError -from django.forms import Field, IntegerField, CharField +from django.forms import Field, IntegerField, CharField, MultipleChoiceField from django.utils.translation import ugettext_lazy as _ from graphql_relay import from_global_id @@ -28,3 +28,15 @@ class GlobalIDFormField(Field): raise ValidationError(self.error_messages['invalid']) return value + + +class GlobalIDMultipleChoiceField(MultipleChoiceField): + default_error_messages = { + 'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), + 'invalid_list': _('Enter a list of values.'), + } + + def valid_value(self, value): + # Clean will raise a validation error if there is a problem + GlobalIDFormField().clean(value) + return True diff --git a/graphene/contrib/django/tests/filter/test_fields.py b/graphene/contrib/django/tests/filter/test_fields.py index 14170977..fc99a273 100644 --- a/graphene/contrib/django/tests/filter/test_fields.py +++ b/graphene/contrib/django/tests/filter/test_fields.py @@ -5,12 +5,13 @@ try: except ImportError: pytestmark = pytest.mark.skipif(True, reason='django_filters not installed') else: - from graphene.contrib.django.filter import GlobalIDFilter, DjangoFilterConnectionField + from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField, + GlobalIDMultipleChoiceFilter) from graphene.contrib.django.tests.filter.filters import ArticleFilter, PetFilter from graphene.contrib.django import DjangoNode -from graphene.contrib.django.forms import GlobalIDFormField -from graphene.contrib.django.tests.models import Article, Pet +from graphene.contrib.django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField +from graphene.contrib.django.tests.models import Article, Pet, Reporter class ArticleNode(DjangoNode): @@ -18,6 +19,11 @@ class ArticleNode(DjangoNode): model = Article +class ReporterNode(DjangoNode): + class Meta: + model = Reporter + + class PetNode(DjangoNode): class Meta: model = Pet @@ -129,3 +135,51 @@ def test_global_id_field_relation(): id_filter = filterset_class.base_filters['reporter'] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField + + +def test_global_id_multiple_field_implicit(): + field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) + filterset_class = field.resolver_fn.get_filterset_class() + multiple_filter = filterset_class.base_filters['pets'] + assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) + assert multiple_filter.field_class == GlobalIDMultipleChoiceField + + +def test_global_id_multiple_field_explicit(): + class ReporterPetsFilter(django_filters.FilterSet): + class Meta: + model = Reporter + fields = ['pets'] + + field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + filterset_class = field.resolver_fn.get_filterset_class() + multiple_filter = filterset_class.base_filters['pets'] + assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) + assert multiple_filter.field_class == GlobalIDMultipleChoiceField + + +@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field" + "but django has not loaded the models, so the test fails as " + "reverse relations are not ready yet") +def test_global_id_multiple_field_implicit_reverse(): + field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) + filterset_class = field.resolver_fn.get_filterset_class() + multiple_filter = filterset_class.base_filters['articles'] + assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) + assert multiple_filter.field_class == GlobalIDMultipleChoiceField + + +@pytest.mark.skipif(True, reason="Trying to test GrapheneFilterSetMixin.filter_for_reverse_field" + "but django has not loaded the models, so the test fails as " + "reverse relations are not ready yet") +def test_global_id_multiple_field_explicit_reverse(): + class ReporterPetsFilter(django_filters.FilterSet): + class Meta: + model = Reporter + fields = ['articles'] + + field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + filterset_class = field.resolver_fn.get_filterset_class() + multiple_filter = filterset_class.base_filters['articles'] + assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) + assert multiple_filter.field_class == GlobalIDMultipleChoiceField diff --git a/graphene/contrib/django/tests/test_form_converter.py b/graphene/contrib/django/tests/test_form_converter.py index 451f91ad..7492fc51 100644 --- a/graphene/contrib/django/tests/test_form_converter.py +++ b/graphene/contrib/django/tests/test_form_converter.py @@ -1,10 +1,10 @@ from django import forms +from graphene.core.types import List, ID from py.test import raises import graphene from graphene.contrib.django.form_converter import convert_form_field -from graphene.contrib.django.fields import (ConnectionOrListField, - DjangoModelField) + from .models import Reporter @@ -94,9 +94,8 @@ def test_should_decimal_convert_float(): def test_should_multiple_choice_convert_connectionorlist(): field = forms.ModelMultipleChoiceField(Reporter.objects.all()) graphene_type = convert_form_field(field) - assert isinstance(graphene_type, ConnectionOrListField) - assert isinstance(graphene_type.type, DjangoModelField) - assert graphene_type.type.model == Reporter + assert isinstance(graphene_type, List) + assert isinstance(graphene_type.of_type, ID) def test_should_manytoone_convert_connectionorlist():