replace merge_queryset with resolve_queryset pattern

This commit is contained in:
Jason Kraus 2019-10-11 14:15:27 -07:00
parent 8d95596ffb
commit c29aa7c8c4
3 changed files with 30 additions and 81 deletions

View File

@ -39,10 +39,8 @@ class DjangoListField(Field):
if queryset is None: if queryset is None:
# Default to Django Model queryset # Default to Django Model queryset
# N.B. This happens if DjangoListField is used in the top level Query object # N.B. This happens if DjangoListField is used in the top level Query object
model = django_object_type._meta.model model = django_object_type._meta.model.objects
queryset = maybe_queryset( queryset = maybe_queryset(django_object_type.get_queryset(model, info))
django_object_type.get_queryset(model.objects, info)
)
return queryset return queryset
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
@ -108,25 +106,13 @@ class DjangoConnectionField(ConnectionField):
@classmethod @classmethod
def resolve_queryset(cls, connection, queryset, info, args): def resolve_queryset(cls, connection, queryset, info, args):
# queryset is the resolved iterable from ObjectType
return connection._meta.node.get_queryset(queryset, info) return connection._meta.node.get_queryset(queryset, info)
@classmethod @classmethod
def merge_querysets(cls, default_queryset, queryset): def resolve_connection(cls, connection, args, iterable):
if default_queryset.query.distinct and not queryset.query.distinct:
queryset = queryset.distinct()
elif queryset.query.distinct and not default_queryset.query.distinct:
default_queryset = default_queryset.distinct()
return queryset & default_queryset
@classmethod
def resolve_connection(cls, connection, default_manager, args, iterable):
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable) iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet): if isinstance(iterable, QuerySet):
if iterable.model.objects is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
_len = iterable.count() _len = iterable.count()
else: else:
_len = len(iterable) _len = len(iterable)
@ -150,6 +136,7 @@ class DjangoConnectionField(ConnectionField):
resolver, resolver,
connection, connection,
default_manager, default_manager,
queryset_resolver,
max_limit, max_limit,
enforce_first_or_last, enforce_first_or_last,
root, root,
@ -177,9 +164,15 @@ class DjangoConnectionField(ConnectionField):
).format(last, info.field_name, max_limit) ).format(last, info.field_name, max_limit)
args["last"] = min(last, max_limit) args["last"] = min(last, max_limit)
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)
queryset = cls.resolve_queryset(connection, default_manager, info, args) if iterable is None:
on_resolve = partial(cls.resolve_connection, connection, queryset, args) iterable = default_manager
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)
on_resolve = partial(cls.resolve_connection, connection, args)
if Promise.is_thenable(iterable): if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve) return Promise.resolve(iterable).then(on_resolve)
@ -192,6 +185,10 @@ class DjangoConnectionField(ConnectionField):
parent_resolver, parent_resolver,
self.connection_type, self.connection_type,
self.get_manager(), self.get_manager(),
self.get_queryset_resolver(),
self.max_limit, self.max_limit,
self.enforce_first_or_last, self.enforce_first_or_last,
) )
def get_queryset_resolver(self):
return self.resolve_queryset

View File

@ -52,69 +52,17 @@ class DjangoFilterConnectionField(DjangoConnectionField):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type) return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
@classmethod @classmethod
def merge_querysets(cls, default_queryset, queryset): def resolve_queryset(
# There could be the case where the default queryset (returned from the filterclass) cls, connection, iterable, info, args, filtering_args, filterset_class
# and the resolver queryset have some limits on it.
# We only would be able to apply one of those, but not both
# at the same time.
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (
default_queryset.query.low_mark and queryset.query.low_mark
), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (
default_queryset.query.high_mark and queryset.query.high_mark
), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits()
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high)
return queryset
@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
): ):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( return filterset_class(
data=filter_kwargs, data=filter_kwargs, queryset=iterable, request=info.context
queryset=default_manager.get_queryset(),
request=info.context,
).qs ).qs
return super(DjangoFilterConnectionField, cls).connection_resolver( def get_queryset_resolver(self):
resolver,
connection,
qs,
max_limit,
enforce_first_or_last,
root,
info,
**args
)
def get_resolver(self, parent_resolver):
return partial( return partial(
self.connection_resolver, self.resolve_queryset,
parent_resolver, filterset_class=self.filterset_class,
self.connection_type, filtering_args=self.filtering_args,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last,
self.filterset_class,
self.filtering_args,
) )

View File

@ -638,6 +638,8 @@ def test_should_error_if_first_is_greater_than_max():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
assert Query.all_reporters.max_limit == 100
r = Reporter.objects.create( r = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -679,6 +681,8 @@ def test_should_error_if_last_is_greater_than_max():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
assert Query.all_reporters.max_limit == 100
r = Reporter.objects.create( r = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -804,7 +808,7 @@ def test_should_query_connectionfields_with_manager():
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
query = """ query = """
query ReporterLastQuery { query ReporterLastQuery {
allReporters(first: 2) { allReporters(first: 1) {
edges { edges {
node { node {
id id