diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 1ecce45..3a0493c 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,10 +1,12 @@ from functools import partial +from collections import OrderedDict from django.db.models.query import QuerySet from promise import Promise from graphene.types import Field, List +from graphene.types.argument import to_arguments from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -29,7 +31,7 @@ class DjangoListField(Field): class DjangoConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): + def __init__(self, type, *args, **kwargs): self.on = kwargs.pop("on", False) self.max_limit = kwargs.pop( "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT @@ -38,7 +40,15 @@ class DjangoConnectionField(ConnectionField): "enforce_first_or_last", graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, ) - super(DjangoConnectionField, self).__init__(*args, **kwargs) + super(DjangoConnectionField, self).__init__(type, *args, **kwargs) + + @property + def args(self): + return to_arguments(self._base_args or OrderedDict(), self.node_type.get_connection_parameters()) + + @args.setter + def args(self, args): + self._base_args = args @property def type(self): @@ -76,7 +86,7 @@ class DjangoConnectionField(ConnectionField): return queryset & default_queryset @classmethod - def resolve_connection(cls, connection, default_manager, args, iterable): + def resolve_connection(cls, connection, node, default_manager, args, info, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -84,6 +94,9 @@ class DjangoConnectionField(ConnectionField): if iterable is not default_manager: default_queryset = maybe_queryset(default_manager) iterable = cls.merge_querysets(default_queryset, iterable) + from .types import DjangoObjectType + if issubclass(node, DjangoObjectType): + iterable = node.refine_queryset(iterable, info, **args) _len = iterable.count() else: _len = len(iterable) @@ -103,15 +116,16 @@ class DjangoConnectionField(ConnectionField): @classmethod def connection_resolver( - cls, - resolver, - connection, - default_manager, - max_limit, - enforce_first_or_last, - root, - info, - **args + cls, + resolver, + connection, + node, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **args ): first = args.get("first") last = args.get("last") @@ -135,7 +149,7 @@ class DjangoConnectionField(ConnectionField): args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + on_resolve = partial(cls.resolve_connection, connection, node, default_manager, args, info) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) @@ -147,6 +161,7 @@ class DjangoConnectionField(ConnectionField): self.connection_resolver, parent_resolver, self.type, + self.node_type, self.get_manager(), self.max_limit, self.enforce_first_or_last, diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index cb42543..340eb37 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -76,17 +76,18 @@ class DjangoFilterConnectionField(DjangoConnectionField): @classmethod def connection_resolver( - cls, - resolver, - connection, - default_manager, - max_limit, - enforce_first_or_last, - filterset_class, - filtering_args, - root, - info, - **args + cls, + resolver, + connection, + node, + default_manager, + max_limit, + enforce_first_or_last, + filterset_class, + filtering_args, + root, + info, + **args ): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( @@ -98,6 +99,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): return super(DjangoFilterConnectionField, cls).connection_resolver( resolver, connection, + node, qs, max_limit, enforce_first_or_last, @@ -111,6 +113,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.connection_resolver, parent_resolver, self.type, + self.node_type, self.get_manager(), self.max_limit, self.enforce_first_or_last, diff --git a/graphene_django/types.py b/graphene_django/types.py index aa8b5a3..1afd045 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -111,6 +111,14 @@ class DjangoObjectType(ObjectType): if not skip_registry: registry.register(cls) + @classmethod + def refine_queryset(cls, qs, info, **kwargs): + return qs + + @classmethod + def get_connection_parameters(cls): + return {} + def resolve_id(self, info): return self.pk @@ -130,6 +138,6 @@ class DjangoObjectType(ObjectType): @classmethod def get_node(cls, info, id): try: - return cls._meta.model.objects.get(pk=id) + return cls.refine_queryset(cls._meta.model.objects, info).get(pk=id) except cls._meta.model.DoesNotExist: return None