diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 41e3bd68..973edc17 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -1,7 +1,7 @@ import warnings import six - +from django_filters import FilterSet from ...core.exceptions import SkipField from ...core.fields import Field @@ -66,12 +66,36 @@ class DjangoModelField(FieldType): return get_type_for_model(schema, self.model) +def custom_filterset_factory(model, filter_base_class=FilterSet, **meta): + meta.update({ + 'model': model, + }) + meta_class = type(str('Meta'), (object,), meta) + filterset = type(str('%sFilterSet' % model._meta.object_name), + (filter_base_class,), {'Meta': meta_class}) + return filterset + + class DjangoFilterConnectionField(DjangoConnectionField): - def __init__(self, type, filterset_class, resolver=None, on=None, *args, **kwargs): + def __init__(self, type, filterset_class=None, resolver=None, on=None, + fields=None, order_by=None, extra_filter_meta=None, + *args, **kwargs): if not resolver: resolver = FilterConnectionResolver(type, on, filterset_class) + if not filterset_class: + # If no filter class is specified then create one given the + # information provided + meta = dict( + model=type._meta.model, + fields=fields, + order_by=order_by, + ) + if extra_filter_meta: + meta.update(extra_filter_meta) + filterset_class = custom_filterset_factory(**meta) + kwargs.setdefault('args', {}) kwargs['args'].update(**self.get_filtering_args(type, filterset_class)) super(DjangoFilterConnectionField, self).__init__(type, resolver, *args, **kwargs) diff --git a/graphene/contrib/django/tests/test_fields.py b/graphene/contrib/django/tests/test_fields.py index 0305d6f5..1797527a 100644 --- a/graphene/contrib/django/tests/test_fields.py +++ b/graphene/contrib/django/tests/test_fields.py @@ -46,11 +46,47 @@ def test_filter_explicit_filterset_arguments(): ) +def test_filter_shortcut_filterset_arguments_list(): + field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) + assert_arguments(field, + 'pubDate', + 'reporter', + ) + + +def test_filter_shortcut_filterset_arguments_dict(): + field = DjangoFilterConnectionField(ArticleNode, fields={ + 'headline': ['exact', 'icontains'], + 'reporter': ['exact'], + }) + assert_arguments(field, + 'headline', 'headlineIcontains', + 'reporter', + ) + + def test_filter_explicit_filterset_orderable(): field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) assert_orderable(field) +def test_filter_shortcut_filterset_orderable_true(): + field = DjangoFilterConnectionField(ArticleNode, order_by=True) + assert_orderable(field) + + +def test_filter_shortcut_filterset_orderable_headline(): + field = DjangoFilterConnectionField(ArticleNode, order_by=['headline']) + assert_orderable(field) + + def test_filter_explicit_filterset_not_orderable(): field = DjangoFilterConnectionField(PetNode, filterset_class=PetFilter) assert_not_orderable(field) + + +def test_filter_shortcut_filterset_extra_meta(): + field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ + 'ordering': True + }) + assert_orderable(field)