diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index df339d8..87cc3e0 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1067,6 +1067,110 @@ def test_proxy_model_support(): assert result.data == expected +def test_proxy_model_support_reverse_relationships(): + """ + This test asserts that we can query reverse relationships for all Reporters and proxied Reporters. + """ + + class FilmType(DjangoObjectType): + class Meta: + model = Film + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class CNNReporterType(DjangoObjectType): + class Meta: + model = CNNReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + film = Film.objects.create(genre="do") + + reporter = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + cnn_reporter = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + + film.reporters.add(cnn_reporter) + film.save() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + cnn_reporters = DjangoConnectionField(CNNReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + films { + id + } + } + } + } + cnnReporters { + edges { + node { + id + films { + id + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": to_global_id("ReporterType", reporter.id), + "films": [], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", cnn_reporter.id), + "films": [{"id": f"{film.id}"}], + }, + }, + ] + }, + "cnnReporters": { + "edges": [ + { + "node": { + "id": to_global_id("CNNReporterType", cnn_reporter.id), + "films": [{"id": f"{film.id}"}], + } + } + ] + }, + } + + result = schema.execute(query) + assert result.data == expected + + def test_should_resolve_get_queryset_connectionfields(): reporter_1 = Reporter.objects.create( first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index fa269b4..22f87c7 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -4,8 +4,8 @@ import pytest from django.utils.translation import gettext_lazy from unittest.mock import patch -from ..utils import camelize, get_model_fields, GraphQLTestCase -from .models import Film, Reporter +from ..utils import camelize, get_model_fields, get_reverse_fields, GraphQLTestCase +from .models import Film, Reporter, CNNReporter from ..utils.testing import graphql_query @@ -19,6 +19,13 @@ def test_get_model_fields_no_duplication(): assert len(film_fields) == len(film_name_set) +def test_get_reverse_fields_includes_proxied_models(): + reporter_fields = get_reverse_fields(Reporter, []) + cnn_reporter_fields = get_reverse_fields(CNNReporter, []) + + assert len(list(reporter_fields)) == len(list(cnn_reporter_fields)) + + def test_camelize(): assert camelize({}) == {} assert camelize("value_a") == "value_a" diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index 343a3a7..228b41a 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -38,17 +38,23 @@ def camelize(data): def get_reverse_fields(model, local_field_names): - for name, attr in model.__dict__.items(): - # Don't duplicate any local fields - if name in local_field_names: - continue + model_ancestry = [model] + # Include proxy models when getting related fields + if model._meta.proxy: + model_ancestry.append(model._meta.proxy_for_model) - # "rel" for FK and M2M relations and "related" for O2O Relations - related = getattr(attr, "rel", None) or getattr(attr, "related", None) - if isinstance(related, models.ManyToOneRel): - yield (name, related) - elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: - yield (name, related) + for _model in model_ancestry: + for name, attr in _model.__dict__.items(): + # Don't duplicate any local fields + if name in local_field_names: + continue + + # "rel" for FK and M2M relations and "related" for O2O Relations + related = getattr(attr, "rel", None) or getattr(attr, "related", None) + if isinstance(related, models.ManyToOneRel): + yield (name, related) + elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: + yield (name, related) def maybe_queryset(value):