mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 01:47:57 +03:00 
			
		
		
		
	replace merge_queryset with resolve_queryset pattern (#796)
* replace merge_queryset with resolve_queryset pattern * skip double limit test * Update graphene_django/fields.py Co-Authored-By: Jonathan Kim <jkimbo@gmail.com> * yank skipped test * fix bad variable ref * add test for annotations * add test for using queryset with django filters * document ththat one should use defer instead of values with queysets and DjangoObjectTypes
This commit is contained in:
		
							parent
							
								
									3ce44908c9
								
							
						
					
					
						commit
						a818ec9017
					
				| 
						 | 
				
			
			@ -282,6 +282,13 @@ of Django's ``HTTPRequest`` in your resolve methods, such as checking for authen
 | 
			
		|||
            return Question.objects.none()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DjangoObjectTypes
 | 
			
		||||
~~~~~~~~~~~~~~~~~
 | 
			
		||||
 | 
			
		||||
A Resolver that maps to a defined `DjangoObjectType` should only use methods that return a queryset.
 | 
			
		||||
Queryset methods like `values` will return dictionaries, use `defer` instead.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Plain ObjectTypes
 | 
			
		||||
-----------------
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,9 +39,9 @@ class DjangoListField(Field):
 | 
			
		|||
        if queryset is None:
 | 
			
		||||
            # Default to Django Model queryset
 | 
			
		||||
            # N.B. This happens if DjangoListField is used in the top level Query object
 | 
			
		||||
            model = django_object_type._meta.model
 | 
			
		||||
            model_manager = django_object_type._meta.model.objects
 | 
			
		||||
            queryset = maybe_queryset(
 | 
			
		||||
                django_object_type.get_queryset(model.objects, info)
 | 
			
		||||
                django_object_type.get_queryset(model_manager, info)
 | 
			
		||||
            )
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -108,25 +108,13 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def resolve_queryset(cls, connection, queryset, info, args):
 | 
			
		||||
        # queryset is the resolved iterable from ObjectType
 | 
			
		||||
        return connection._meta.node.get_queryset(queryset, info)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def merge_querysets(cls, default_queryset, queryset):
 | 
			
		||||
        if default_queryset.query.distinct and not queryset.query.distinct:
 | 
			
		||||
            queryset = queryset.distinct()
 | 
			
		||||
        elif queryset.query.distinct and not default_queryset.query.distinct:
 | 
			
		||||
            default_queryset = default_queryset.distinct()
 | 
			
		||||
        return queryset & default_queryset
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def resolve_connection(cls, connection, default_manager, args, iterable):
 | 
			
		||||
        if iterable is None:
 | 
			
		||||
            iterable = default_manager
 | 
			
		||||
    def resolve_connection(cls, connection, args, iterable):
 | 
			
		||||
        iterable = maybe_queryset(iterable)
 | 
			
		||||
        if isinstance(iterable, QuerySet):
 | 
			
		||||
            if iterable.model.objects is not default_manager:
 | 
			
		||||
                default_queryset = maybe_queryset(default_manager)
 | 
			
		||||
                iterable = cls.merge_querysets(default_queryset, iterable)
 | 
			
		||||
            _len = iterable.count()
 | 
			
		||||
        else:
 | 
			
		||||
            _len = len(iterable)
 | 
			
		||||
| 
						 | 
				
			
			@ -150,6 +138,7 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
        resolver,
 | 
			
		||||
        connection,
 | 
			
		||||
        default_manager,
 | 
			
		||||
        queryset_resolver,
 | 
			
		||||
        max_limit,
 | 
			
		||||
        enforce_first_or_last,
 | 
			
		||||
        root,
 | 
			
		||||
| 
						 | 
				
			
			@ -177,9 +166,15 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
                ).format(last, info.field_name, max_limit)
 | 
			
		||||
                args["last"] = min(last, max_limit)
 | 
			
		||||
 | 
			
		||||
        # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
 | 
			
		||||
        # or a resolve_foo (does not accept queryset)
 | 
			
		||||
        iterable = resolver(root, info, **args)
 | 
			
		||||
        queryset = cls.resolve_queryset(connection, default_manager, info, args)
 | 
			
		||||
        on_resolve = partial(cls.resolve_connection, connection, queryset, args)
 | 
			
		||||
        if iterable is None:
 | 
			
		||||
            iterable = default_manager
 | 
			
		||||
        # thus the iterable gets refiltered by resolve_queryset
 | 
			
		||||
        # but iterable might be promise
 | 
			
		||||
        iterable = queryset_resolver(connection, iterable, info, args)
 | 
			
		||||
        on_resolve = partial(cls.resolve_connection, connection, args)
 | 
			
		||||
 | 
			
		||||
        if Promise.is_thenable(iterable):
 | 
			
		||||
            return Promise.resolve(iterable).then(on_resolve)
 | 
			
		||||
| 
						 | 
				
			
			@ -192,6 +187,10 @@ class DjangoConnectionField(ConnectionField):
 | 
			
		|||
            parent_resolver,
 | 
			
		||||
            self.connection_type,
 | 
			
		||||
            self.get_manager(),
 | 
			
		||||
            self.get_queryset_resolver(),
 | 
			
		||||
            self.max_limit,
 | 
			
		||||
            self.enforce_first_or_last,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_queryset_resolver(self):
 | 
			
		||||
        return self.resolve_queryset
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,69 +52,17 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
        return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def merge_querysets(cls, default_queryset, queryset):
 | 
			
		||||
        # There could be the case where the default queryset (returned from the filterclass)
 | 
			
		||||
        # and the resolver queryset have some limits on it.
 | 
			
		||||
        # We only would be able to apply one of those, but not both
 | 
			
		||||
        # at the same time.
 | 
			
		||||
 | 
			
		||||
        # See related PR: https://github.com/graphql-python/graphene-django/pull/126
 | 
			
		||||
 | 
			
		||||
        assert not (
 | 
			
		||||
            default_queryset.query.low_mark and queryset.query.low_mark
 | 
			
		||||
        ), "Received two sliced querysets (low mark) in the connection, please slice only in one."
 | 
			
		||||
        assert not (
 | 
			
		||||
            default_queryset.query.high_mark and queryset.query.high_mark
 | 
			
		||||
        ), "Received two sliced querysets (high mark) in the connection, please slice only in one."
 | 
			
		||||
        low = default_queryset.query.low_mark or queryset.query.low_mark
 | 
			
		||||
        high = default_queryset.query.high_mark or queryset.query.high_mark
 | 
			
		||||
        default_queryset.query.clear_limits()
 | 
			
		||||
        queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
 | 
			
		||||
            default_queryset, queryset
 | 
			
		||||
        )
 | 
			
		||||
        queryset.query.set_limits(low, high)
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def connection_resolver(
 | 
			
		||||
        cls,
 | 
			
		||||
        resolver,
 | 
			
		||||
        connection,
 | 
			
		||||
        default_manager,
 | 
			
		||||
        max_limit,
 | 
			
		||||
        enforce_first_or_last,
 | 
			
		||||
        filterset_class,
 | 
			
		||||
        filtering_args,
 | 
			
		||||
        root,
 | 
			
		||||
        info,
 | 
			
		||||
        **args
 | 
			
		||||
    def resolve_queryset(
 | 
			
		||||
        cls, connection, iterable, info, args, filtering_args, filterset_class
 | 
			
		||||
    ):
 | 
			
		||||
        filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
 | 
			
		||||
        qs = filterset_class(
 | 
			
		||||
            data=filter_kwargs,
 | 
			
		||||
            queryset=default_manager.get_queryset(),
 | 
			
		||||
            request=info.context,
 | 
			
		||||
        return filterset_class(
 | 
			
		||||
            data=filter_kwargs, queryset=iterable, request=info.context
 | 
			
		||||
        ).qs
 | 
			
		||||
 | 
			
		||||
        return super(DjangoFilterConnectionField, cls).connection_resolver(
 | 
			
		||||
            resolver,
 | 
			
		||||
            connection,
 | 
			
		||||
            qs,
 | 
			
		||||
            max_limit,
 | 
			
		||||
            enforce_first_or_last,
 | 
			
		||||
            root,
 | 
			
		||||
            info,
 | 
			
		||||
            **args
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_resolver(self, parent_resolver):
 | 
			
		||||
    def get_queryset_resolver(self):
 | 
			
		||||
        return partial(
 | 
			
		||||
            self.connection_resolver,
 | 
			
		||||
            parent_resolver,
 | 
			
		||||
            self.connection_type,
 | 
			
		||||
            self.get_manager(),
 | 
			
		||||
            self.max_limit,
 | 
			
		||||
            self.enforce_first_or_last,
 | 
			
		||||
            self.filterset_class,
 | 
			
		||||
            self.filtering_args,
 | 
			
		||||
            self.resolve_queryset,
 | 
			
		||||
            filterset_class=self.filterset_class,
 | 
			
		||||
            filtering_args=self.filtering_args,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -608,58 +608,6 @@ def test_should_query_filter_node_limit():
 | 
			
		|||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_query_filter_node_double_limit_raises():
 | 
			
		||||
    class ReporterFilter(FilterSet):
 | 
			
		||||
        limit = NumberFilter(method="filter_limit")
 | 
			
		||||
 | 
			
		||||
        def filter_limit(self, queryset, name, value):
 | 
			
		||||
            return queryset[:value]
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            fields = ["first_name"]
 | 
			
		||||
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(
 | 
			
		||||
            ReporterType, filterset_class=ReporterFilter
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.order_by("a_choice")[:2]
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
 | 
			
		||||
    )
 | 
			
		||||
    Reporter.objects.create(
 | 
			
		||||
        first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(limit: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
                    node {
 | 
			
		||||
                        id
 | 
			
		||||
                        firstName
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert len(result.errors) == 1
 | 
			
		||||
    assert str(result.errors[0]) == (
 | 
			
		||||
        "Received two sliced querysets (high mark) in the connection, please slice only in one."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_order_by_is_perserved():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
| 
						 | 
				
			
			@ -721,7 +669,7 @@ def test_order_by_is_perserved():
 | 
			
		|||
    assert reverse_result.data == reverse_expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_annotation_is_perserved():
 | 
			
		||||
def test_annotation_is_preserved():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        full_name = String()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -766,6 +714,48 @@ def test_annotation_is_perserved():
 | 
			
		|||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_annotation_with_only():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        full_name = String()
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (Node,)
 | 
			
		||||
            filter_fields = ()
 | 
			
		||||
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterType)
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, info, **args):
 | 
			
		||||
            return Reporter.objects.only("first_name", "last_name").annotate(
 | 
			
		||||
                full_name=Concat(
 | 
			
		||||
                    "first_name", Value(" "), "last_name", output_field=TextField()
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    Reporter.objects.create(first_name="John", last_name="Doe")
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
        query NodeFilteringQuery {
 | 
			
		||||
            allReporters(first: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
                    node {
 | 
			
		||||
                        fullName
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    """
 | 
			
		||||
    expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
 | 
			
		||||
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_integer_field_filter_type():
 | 
			
		||||
    class PetType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -638,6 +638,8 @@ def test_should_error_if_first_is_greater_than_max():
 | 
			
		|||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoConnectionField(ReporterType)
 | 
			
		||||
 | 
			
		||||
    assert Query.all_reporters.max_limit == 100
 | 
			
		||||
 | 
			
		||||
    r = Reporter.objects.create(
 | 
			
		||||
        first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -679,6 +681,8 @@ def test_should_error_if_last_is_greater_than_max():
 | 
			
		|||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoConnectionField(ReporterType)
 | 
			
		||||
 | 
			
		||||
    assert Query.all_reporters.max_limit == 100
 | 
			
		||||
 | 
			
		||||
    r = Reporter.objects.create(
 | 
			
		||||
        first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -804,7 +808,7 @@ def test_should_query_connectionfields_with_manager():
 | 
			
		|||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    query = """
 | 
			
		||||
        query ReporterLastQuery {
 | 
			
		||||
            allReporters(first: 2) {
 | 
			
		||||
            allReporters(first: 1) {
 | 
			
		||||
                edges {
 | 
			
		||||
                    node {
 | 
			
		||||
                        id
 | 
			
		||||
| 
						 | 
				
			
			@ -1116,3 +1120,55 @@ def test_should_preserve_prefetch_related(django_assert_num_queries):
 | 
			
		|||
    with django_assert_num_queries(3) as captured:
 | 
			
		||||
        result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_should_preserve_annotations():
 | 
			
		||||
    class ReporterType(DjangoObjectType):
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Reporter
 | 
			
		||||
            interfaces = (graphene.relay.Node,)
 | 
			
		||||
 | 
			
		||||
    class FilmType(DjangoObjectType):
 | 
			
		||||
        reporters = DjangoConnectionField(ReporterType)
 | 
			
		||||
        reporters_count = graphene.Int()
 | 
			
		||||
 | 
			
		||||
        class Meta:
 | 
			
		||||
            model = Film
 | 
			
		||||
            interfaces = (graphene.relay.Node,)
 | 
			
		||||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        films = DjangoConnectionField(FilmType)
 | 
			
		||||
 | 
			
		||||
        def resolve_films(root, info):
 | 
			
		||||
            qs = Film.objects.prefetch_related("reporters")
 | 
			
		||||
            return qs.annotate(reporters_count=models.Count("reporters"))
 | 
			
		||||
 | 
			
		||||
    r1 = Reporter.objects.create(first_name="Dave", last_name="Smith")
 | 
			
		||||
    r2 = Reporter.objects.create(first_name="Jane", last_name="Doe")
 | 
			
		||||
 | 
			
		||||
    f1 = Film.objects.create()
 | 
			
		||||
    f1.reporters.set([r1, r2])
 | 
			
		||||
    f2 = Film.objects.create()
 | 
			
		||||
    f2.reporters.set([r2])
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
        query {
 | 
			
		||||
            films {
 | 
			
		||||
                edges {
 | 
			
		||||
                    node {
 | 
			
		||||
                        reportersCount
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    """
 | 
			
		||||
    schema = graphene.Schema(query=Query)
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors, str(result)
 | 
			
		||||
 | 
			
		||||
    expected = {
 | 
			
		||||
        "films": {
 | 
			
		||||
            "edges": [{"node": {"reportersCount": 2}}, {"node": {"reportersCount": 1}}]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    assert result.data == expected, str(result.data)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user