diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 77785dae..3c222470 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -1,19 +1,14 @@ import warnings -import six - -from graphene.contrib.django.filterset import setup_filterset +from graphene.contrib.django.utils import get_filtering_args_from_filterset +from .resolvers import FilterConnectionResolver +from .utils import get_type_for_model from ...core.exceptions import SkipField from ...core.fields import Field -from ...core.types import Argument, String from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node -from .form_converter import convert_form_field -from .resolvers import FilterConnectionResolver -from .utils import get_type_for_model -from .filterset import custom_filterset_factory class DjangoConnectionField(ConnectionField): @@ -69,39 +64,21 @@ class DjangoModelField(FieldType): class DjangoFilterConnectionField(DjangoConnectionField): - def __init__(self, type, filterset_class=None, resolver=None, on=None, - fields=None, order_by=None, extra_filter_meta=None, + def __init__(self, type, on=None, fields=None, order_by=None, + extra_filter_meta=None, filterset_class=None, resolver=None, *args, **kwargs): - if not filterset_class: - # If no filter class is specified then create one given the - # information provided - meta = dict( - model=type._meta.model, + if not resolver: + resolver = FilterConnectionResolver( + node=type, + on=on, + filterset_class=filterset_class, fields=fields, order_by=order_by, + extra_filter_meta=extra_filter_meta, ) - if extra_filter_meta: - meta.update(extra_filter_meta) - filterset_class = custom_filterset_factory(**meta) - else: - filterset_class = setup_filterset(filterset_class) - - if not resolver: - resolver = FilterConnectionResolver(type, on, filterset_class) + filtering_args = get_filtering_args_from_filterset(resolver.get_filterset_class(), type) kwargs.setdefault('args', {}) - kwargs['args'].update(**self.get_filtering_args(type, filterset_class)) + kwargs['args'].update(**filtering_args) super(DjangoFilterConnectionField, self).__init__(type, resolver, *args, **kwargs) - - def get_filtering_args(self, type, filterset_class): - args = {} - for name, filter_field in six.iteritems(filterset_class.base_filters): - field_type = Argument(convert_form_field(filter_field.field)) - # Is this correct? I don't quite grok the 'parent' system yet - field_type.mount(type) - args[name] = field_type - - # Also add the 'order_by' field - args[filterset_class.order_by_field] = Argument(String) - return args diff --git a/graphene/contrib/django/resolvers.py b/graphene/contrib/django/resolvers.py index 7960106b..39daecea 100644 --- a/graphene/contrib/django/resolvers.py +++ b/graphene/contrib/django/resolvers.py @@ -1,5 +1,6 @@ from django.core.exceptions import ImproperlyConfigured -from django_filters.filterset import filterset_factory + +from graphene.contrib.django.filterset import setup_filterset, custom_filterset_factory class BaseQuerySetConnectionResolver(object): @@ -50,8 +51,13 @@ class SimpleQuerySetConnectionResolver(BaseQuerySetConnectionResolver): class FilterConnectionResolver(BaseQuerySetConnectionResolver): # Querying using django-filter - def __init__(self, node, on=None, filterset_class=None): + def __init__(self, node, on=None, filterset_class=None, + fields=None, order_by=None, extra_filter_meta=None): self.filterset_class = filterset_class + self.fields = fields + self.order_by = order_by + self.extra_filter_meta = extra_filter_meta or {} + self._filterset_class = None super(FilterConnectionResolver, self).__init__(node, on) def make_query(self): @@ -60,19 +66,40 @@ class FilterConnectionResolver(BaseQuerySetConnectionResolver): return filterset.qs def get_filterset_class(self): + """Get the class to be used as the FilterSet""" + if self._filterset_class: + return self._filterset_class + if self.filterset_class: - return self.filterset_class + # If were given a FilterSet class, then set it up and + # return it + self._filterset_class = setup_filterset(self.filterset_class) elif self.model: - return filterset_factory(self.model) + # If no filter class was specified then create one given the + # other information provided + meta = dict( + model=self.model, + fields=self.fields, + order_by=self.order_by, + ) + meta.update(self.extra_filter_meta) + self._filterset_class = custom_filterset_factory(**meta) else: - msg = "'%s' must define 'filterset_class' or 'model'" + msg = "Neither 'filterset_class' or 'model' available in '%s'. " \ + "Either pass in 'filterset_class' or 'model' when " \ + "initialising, or extend this class and override " \ + "get_filterset() or get_filterset_class()" raise ImproperlyConfigured(msg % self.__class__.__name__) + return self._filterset_class + def get_filterset(self, filterset_class): + """Get an instance of the FilterSet""" kwargs = self.get_filterset_kwargs(filterset_class) return filterset_class(**kwargs) def get_filterset_kwargs(self, filterset_class): + """Get the kwargs to use when initialising the FilterSet class""" kwargs = { 'data': self.args or None, 'queryset': self.get_manager() diff --git a/graphene/contrib/django/tests/test_fields.py b/graphene/contrib/django/tests/test_fields.py index d5aeb22a..3c703796 100644 --- a/graphene/contrib/django/tests/test_fields.py +++ b/graphene/contrib/django/tests/test_fields.py @@ -98,7 +98,7 @@ def test_filter_shortcut_filterset_extra_meta(): def test_global_id_field_implicit(): field = DjangoFilterConnectionField(ArticleNode, fields=['id']) - filterset_class = field.resolver_fn.filterset_class + filterset_class = field.resolver_fn.get_filterset_class() id_filter = filterset_class.base_filters['id'] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField @@ -111,7 +111,7 @@ def test_global_id_field_explicit(): fields = ['id'] field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) - filterset_class = field.resolver_fn.filterset_class + filterset_class = field.resolver_fn.get_filterset_class() id_filter = filterset_class.base_filters['id'] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField @@ -119,7 +119,7 @@ def test_global_id_field_explicit(): def test_global_id_field_relation(): field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) - filterset_class = field.resolver_fn.filterset_class + filterset_class = field.resolver_fn.get_filterset_class() id_filter = filterset_class.base_filters['reporter'] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField diff --git a/graphene/contrib/django/tests/test_resolvers.py b/graphene/contrib/django/tests/test_resolvers.py index 9daa55b0..8f7ab7d6 100644 --- a/graphene/contrib/django/tests/test_resolvers.py +++ b/graphene/contrib/django/tests/test_resolvers.py @@ -66,7 +66,7 @@ def test_filter_get_filterset_class_explicit(): resolver = FilterConnectionResolver(ReporterNode, filterset_class=ReporterFilter) resolver(inst=reporter, args={}, info=None) - assert resolver.get_filterset_class() == ReporterFilter, \ + assert issubclass(resolver.get_filterset_class(), ReporterFilter), \ 'ReporterFilter not returned' @@ -83,7 +83,7 @@ def test_filter_get_filterset_class_error(): resolver.model = None with raises(ImproperlyConfigured) as excinfo: resolver(inst=reporter, args={}, info=None) - assert "must define 'filterset_class' or 'model'" in str(excinfo.value) + assert "Neither 'filterset_class' or 'model' available" in str(excinfo.value) def test_filter_filter(): diff --git a/graphene/contrib/django/utils.py b/graphene/contrib/django/utils.py index 54c6420c..2b4519fc 100644 --- a/graphene/contrib/django/utils.py +++ b/graphene/contrib/django/utils.py @@ -1,6 +1,10 @@ +import six from django.db import models from django.db.models.manager import Manager +from graphene import Argument, String +from graphene.contrib.django.form_converter import convert_form_field + def get_type_for_model(schema, model): schema = schema @@ -23,3 +27,20 @@ def maybe_queryset(value): if isinstance(value, Manager): value = value.get_queryset() return value + + +def get_filtering_args_from_filterset(filterset_class, type): + """ Inspect a FilterSet and produce the arguments to pass to + a Graphene Field. These arguments will be available to + filter against in the GraphQL + """ + args = {} + for name, filter_field in six.iteritems(filterset_class.base_filters): + field_type = Argument(convert_form_field(filter_field.field)) + # Is this correct? I don't quite grok the 'parent' system yet + field_type.mount(type) + args[name] = field_type + + # Also add the 'order_by' field + args[filterset_class.order_by_field] = Argument(String) + return args