diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 636f74c..705cedf 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -89,6 +89,14 @@ class CNNReporter(Reporter): objects = CNNReporterManager() +class APNewsReporter(Reporter): + """ + This class only inherits from Reporter for testing multi table inheritence + similar to what you'd see in django-polymorphic + """ + alias = models.CharField(max_length=30) + objects = models.Manager() + class Article(models.Model): headline = models.CharField(max_length=100) diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index ff2d8a6..88cabe9 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -33,7 +33,7 @@ def test_should_map_fields_correctly(): fields = "__all__" fields = list(ReporterType2._meta.fields.keys()) - assert fields[:-2] == [ + assert fields[:-3] == [ "id", "first_name", "last_name", @@ -43,7 +43,7 @@ def test_should_map_fields_correctly(): "reporter_type", ] - assert sorted(fields[-2:]) == ["articles", "films"] + assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] def test_should_map_only_few_fields(): diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index fad26e2..7d75267 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -67,7 +67,7 @@ def test_django_get_node(get): def test_django_objecttype_map_correct_fields(): fields = Reporter._meta.fields fields = list(fields.keys()) - assert fields[:-2] == [ + assert fields[:-3] == [ "id", "first_name", "last_name", @@ -76,7 +76,7 @@ def test_django_objecttype_map_correct_fields(): "a_choice", "reporter_type", ] - assert sorted(fields[-2:]) == ["articles", "films"] + assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] def test_django_objecttype_with_node_have_correct_fields(): diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index 22f87c7..5a4db8d 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -5,7 +5,7 @@ from django.utils.translation import gettext_lazy from unittest.mock import patch from ..utils import camelize, get_model_fields, get_reverse_fields, GraphQLTestCase -from .models import Film, Reporter, CNNReporter +from .models import Film, Reporter, CNNReporter, APNewsReporter from ..utils.testing import graphql_query @@ -22,8 +22,9 @@ def test_get_model_fields_no_duplication(): def test_get_reverse_fields_includes_proxied_models(): reporter_fields = get_reverse_fields(Reporter, []) cnn_reporter_fields = get_reverse_fields(CNNReporter, []) + ap_news_reporter_fields = get_reverse_fields(APNewsReporter, []) - assert len(list(reporter_fields)) == len(list(cnn_reporter_fields)) + assert len(list(reporter_fields)) == len(list(cnn_reporter_fields)) == len(list(ap_news_reporter_fields)) def test_camelize(): diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index 228b41a..51abeb5 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -39,9 +39,10 @@ def camelize(data): def get_reverse_fields(model, local_field_names): model_ancestry = [model] - # Include proxy models when getting related fields - if model._meta.proxy: - model_ancestry.append(model._meta.proxy_for_model) + + for base in model.__bases__: + if is_valid_django_model(base): + model_ancestry.append(base) for _model in model_ancestry: for name, attr in _model.__dict__.items():