add queryset_resolver to DjangoConnectionField and fix test failures

This commit is contained in:
Jason Kraus 2018-10-05 09:35:34 -07:00
parent eb938da353
commit 857b20a2f2

View File

@ -66,6 +66,20 @@ class DjangoConnectionField(ConnectionField):
return getattr(self.model, self.on) return getattr(self.model, self.on)
else: else:
return self.model._default_manager return self.model._default_manager
@classmethod
def resolve_queryset(cls, connection, queryset, info):
default_queryset = maybe_queryset(connection._meta.node.get_queryset(info))
if queryset:
return cls.merge_querysets(default_queryset, queryset)
return default_queryset
def get_queryset_resolver(self):
return partial(
self.resolve_queryset,
self.type,
maybe_queryset(getattr(self.model, self.on)) if self.on else None,
)
@classmethod @classmethod
def merge_querysets(cls, default_queryset, queryset): def merge_querysets(cls, default_queryset, queryset):
@ -106,7 +120,7 @@ class DjangoConnectionField(ConnectionField):
cls, cls,
resolver, resolver,
connection, connection,
default_manager, queryset_resolver,
max_limit, max_limit,
enforce_first_or_last, enforce_first_or_last,
root, root,
@ -135,7 +149,13 @@ class DjangoConnectionField(ConnectionField):
args["last"] = min(last, max_limit) args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)
queryset = connection._meta.node.get_queryset(info) if callable(queryset_resolver):
queryset = queryset_resolver(info)
else:
assert isinstance(queryset_resolver, QuerySet), "The type {} is not a QuerySet".format(
type(queryset_resolver)
)
queryset = queryset_resolver
on_resolve = partial(cls.resolve_connection, connection, queryset, args) on_resolve = partial(cls.resolve_connection, connection, queryset, args)
if Promise.is_thenable(iterable): if Promise.is_thenable(iterable):
@ -148,7 +168,7 @@ class DjangoConnectionField(ConnectionField):
self.connection_resolver, self.connection_resolver,
parent_resolver, parent_resolver,
self.type, self.type,
self.get_manager(), self.get_queryset_resolver(),
self.max_limit, self.max_limit,
self.enforce_first_or_last, self.enforce_first_or_last,
) )