support multi table inheritence

This commit is contained in:
Tom Dror 2023-01-10 12:50:01 -05:00
parent 88fe24764b
commit d009f3502e
5 changed files with 19 additions and 9 deletions

View File

@ -89,6 +89,14 @@ class CNNReporter(Reporter):
objects = CNNReporterManager()
class APNewsReporter(Reporter):
"""
This class only inherits from Reporter for testing multi table inheritence
similar to what you'd see in django-polymorphic
"""
alias = models.CharField(max_length=30)
objects = models.Manager()
class Article(models.Model):
headline = models.CharField(max_length=100)

View File

@ -33,7 +33,7 @@ def test_should_map_fields_correctly():
fields = "__all__"
fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [
assert fields[:-3] == [
"id",
"first_name",
"last_name",
@ -43,7 +43,7 @@ def test_should_map_fields_correctly():
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_should_map_only_few_fields():

View File

@ -67,7 +67,7 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields
fields = list(fields.keys())
assert fields[:-2] == [
assert fields[:-3] == [
"id",
"first_name",
"last_name",
@ -76,7 +76,7 @@ def test_django_objecttype_map_correct_fields():
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_django_objecttype_with_node_have_correct_fields():

View File

@ -5,7 +5,7 @@ from django.utils.translation import gettext_lazy
from unittest.mock import patch
from ..utils import camelize, get_model_fields, get_reverse_fields, GraphQLTestCase
from .models import Film, Reporter, CNNReporter
from .models import Film, Reporter, CNNReporter, APNewsReporter
from ..utils.testing import graphql_query
@ -22,8 +22,9 @@ def test_get_model_fields_no_duplication():
def test_get_reverse_fields_includes_proxied_models():
reporter_fields = get_reverse_fields(Reporter, [])
cnn_reporter_fields = get_reverse_fields(CNNReporter, [])
ap_news_reporter_fields = get_reverse_fields(APNewsReporter, [])
assert len(list(reporter_fields)) == len(list(cnn_reporter_fields))
assert len(list(reporter_fields)) == len(list(cnn_reporter_fields)) == len(list(ap_news_reporter_fields))
def test_camelize():

View File

@ -39,9 +39,10 @@ def camelize(data):
def get_reverse_fields(model, local_field_names):
model_ancestry = [model]
# Include proxy models when getting related fields
if model._meta.proxy:
model_ancestry.append(model._meta.proxy_for_model)
for base in model.__bases__:
if is_valid_django_model(base):
model_ancestry.append(base)
for _model in model_ancestry:
for name, attr in _model.__dict__.items():