From 264b84aed88b3ef77621e5b9c366aec73f402947 Mon Sep 17 00:00:00 2001 From: Tom Dror Date: Tue, 10 Jan 2023 19:41:57 -0500 Subject: [PATCH] support local many to many in model inheritance --- graphene_django/tests/models.py | 1 + graphene_django/tests/test_query.py | 161 ++++++++++++++++++++++++++- graphene_django/tests/test_schema.py | 1 + graphene_django/tests/test_types.py | 1 + graphene_django/utils/utils.py | 48 ++++++-- 5 files changed, 199 insertions(+), 13 deletions(-) diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 705cedf..69b9a95 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -46,6 +46,7 @@ class Reporter(models.Model): a_choice = models.CharField(max_length=30, choices=CHOICES, blank=True) objects = models.Manager() doe_objects = DoeReporterManager() + fans = models.ManyToManyField(Person) reporter_type = models.IntegerField( "Reporter Type", diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 43e7d54..a01e017 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1076,7 +1076,7 @@ def test_model_inheritance_support_reverse_relationships(): class Meta: model = Film fields = "__all__" - + class ReporterType(DjangoObjectType): class Meta: model = Reporter @@ -1209,6 +1209,165 @@ def test_model_inheritance_support_reverse_relationships(): assert result.data == expected +def test_model_inheritance_support_local_relationships(): + """ + This test asserts that we can query local relationships for all Reporters and proxied Reporters and multi table Reporters. + """ + + class PersonType(DjangoObjectType): + class Meta: + model = Person + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class CNNReporterType(DjangoObjectType): + class Meta: + model = CNNReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class APNewsReporterType(DjangoObjectType): + class Meta: + model = APNewsReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + film = Film.objects.create(genre="do") + + reporter = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + + reporter_fan = Person.objects.create( + name="Reporter Fan" + ) + + reporter.fans.add(reporter_fan) + reporter.save() + + cnn_reporter = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + cnn_fan = Person.objects.create( + name="CNN Fan" + ) + cnn_reporter.fans.add(cnn_fan) + cnn_reporter.save() + + ap_news_reporter = APNewsReporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + ap_news_fan = Person.objects.create( + name="AP News Fan" + ) + ap_news_reporter.fans.add(ap_news_fan) + ap_news_reporter.save() + + film.reporters.add(cnn_reporter, ap_news_reporter) + film.save() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + cnn_reporters = DjangoConnectionField(CNNReporterType) + ap_news_reporters = DjangoConnectionField(APNewsReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + fans { + name + } + } + } + } + cnnReporters { + edges { + node { + id + fans { + name + } + } + } + } + apNewsReporters { + edges { + node { + id + fans { + name + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": to_global_id("ReporterType", reporter.id), + "fans": [{"name": f"{reporter_fan.name}"}], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", cnn_reporter.id), + "fans": [{"name": f"{cnn_fan.name}"}], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", ap_news_reporter.id), + "fans": [{"name": f"{ap_news_fan.name}"}], + }, + }, + ] + }, + "cnnReporters": { + "edges": [ + { + "node": { + "id": to_global_id("CNNReporterType", cnn_reporter.id), + "fans": [{"name": f"{cnn_fan.name}"}], + } + } + ] + }, + "apNewsReporters": { + "edges": [ + { + "node": { + "id": to_global_id("APNewsReporterType", ap_news_reporter.id), + "fans": [{"name": f"{ap_news_fan.name}"}], + } + } + ] + }, + } + + result = schema.execute(query) + assert result.data == expected + def test_should_resolve_get_queryset_connectionfields(): reporter_1 = Reporter.objects.create( first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index 88cabe9..93cbd9f 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -40,6 +40,7 @@ def test_should_map_fields_correctly(): "email", "pets", "a_choice", + "fans", "reporter_type", ] diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 7d75267..fd85ef1 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -74,6 +74,7 @@ def test_django_objecttype_map_correct_fields(): "email", "pets", "a_choice", + "fans", "reporter_type", ] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index 51abeb5..84f88b6 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -37,12 +37,21 @@ def camelize(data): return data -def get_reverse_fields(model, local_field_names): +def _get_model_ancestry(model): model_ancestry = [model] for base in model.__bases__: - if is_valid_django_model(base): + if is_valid_django_model(base) and getattr(base, "_meta", False): model_ancestry.append(base) + return model_ancestry + + +def get_reverse_fields(model, local_field_names): + """ + Searches through the model's ancestry and gets reverse relationships the models + Yields a tuple of (field.name, field) + """ + model_ancestry = _get_model_ancestry(model) for _model in model_ancestry: for name, attr in _model.__dict__.items(): @@ -58,6 +67,24 @@ def get_reverse_fields(model, local_field_names): yield (name, related) +def get_local_fields(model): + """ + Searches through the model's ancestry and gets the fields on the models + Returns a dict of {field.name: field} + """ + model_ancestry = _get_model_ancestry(model) + + local_fields_dict = {} + for _model in model_ancestry: + for field in sorted( + list(_model._meta.fields) + list(_model._meta.local_many_to_many) + ): + if field.name not in local_fields_dict: + local_fields_dict[field.name] = field + + return list(local_fields_dict.items()) + + def maybe_queryset(value): if isinstance(value, Manager): value = value.get_queryset() @@ -65,17 +92,14 @@ def maybe_queryset(value): def get_model_fields(model): - local_fields = [ - (field.name, field) - for field in sorted( - list(model._meta.fields) + list(model._meta.local_many_to_many) - ) - ] - - # Make sure we don't duplicate local fields with "reverse" version - local_field_names = [field[0] for field in local_fields] + """ + Gets all the fields and relationships on the Django model and its ancestry. + Prioritizes local fields and relationships over the reverse relationships of the same name + Returns a tuple of (field.name, field) + """ + local_fields = get_local_fields(model) + local_field_names = {field[0] for field in local_fields} reverse_fields = get_reverse_fields(model, local_field_names) - all_fields = local_fields + list(reverse_fields) return all_fields