support local many to many in model inheritance

This commit is contained in:
Tom Dror 2023-01-10 19:41:57 -05:00
parent 50ebf76789
commit 264b84aed8
5 changed files with 199 additions and 13 deletions

View File

@ -46,6 +46,7 @@ class Reporter(models.Model):
a_choice = models.CharField(max_length=30, choices=CHOICES, blank=True) a_choice = models.CharField(max_length=30, choices=CHOICES, blank=True)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
fans = models.ManyToManyField(Person)
reporter_type = models.IntegerField( reporter_type = models.IntegerField(
"Reporter Type", "Reporter Type",

View File

@ -1209,6 +1209,165 @@ def test_model_inheritance_support_reverse_relationships():
assert result.data == expected assert result.data == expected
def test_model_inheritance_support_local_relationships():
"""
This test asserts that we can query local relationships for all Reporters and proxied Reporters and multi table Reporters.
"""
class PersonType(DjangoObjectType):
class Meta:
model = Person
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class APNewsReporterType(DjangoObjectType):
class Meta:
model = APNewsReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
film = Film.objects.create(genre="do")
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
reporter_fan = Person.objects.create(
name="Reporter Fan"
)
reporter.fans.add(reporter_fan)
reporter.save()
cnn_reporter = CNNReporter.objects.create(
first_name="Some",
last_name="Guy",
email="someguy@cnn.com",
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
cnn_fan = Person.objects.create(
name="CNN Fan"
)
cnn_reporter.fans.add(cnn_fan)
cnn_reporter.save()
ap_news_reporter = APNewsReporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
ap_news_fan = Person.objects.create(
name="AP News Fan"
)
ap_news_reporter.fans.add(ap_news_fan)
ap_news_reporter.save()
film.reporters.add(cnn_reporter, ap_news_reporter)
film.save()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
cnn_reporters = DjangoConnectionField(CNNReporterType)
ap_news_reporters = DjangoConnectionField(APNewsReporterType)
schema = graphene.Schema(query=Query)
query = """
query ProxyModelQuery {
allReporters {
edges {
node {
id
fans {
name
}
}
}
}
cnnReporters {
edges {
node {
id
fans {
name
}
}
}
}
apNewsReporters {
edges {
node {
id
fans {
name
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": to_global_id("ReporterType", reporter.id),
"fans": [{"name": f"{reporter_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
},
},
]
},
"cnnReporters": {
"edges": [
{
"node": {
"id": to_global_id("CNNReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
}
}
]
},
"apNewsReporters": {
"edges": [
{
"node": {
"id": to_global_id("APNewsReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
}
}
]
},
}
result = schema.execute(query)
assert result.data == expected
def test_should_resolve_get_queryset_connectionfields(): def test_should_resolve_get_queryset_connectionfields():
reporter_1 = Reporter.objects.create( reporter_1 = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1

View File

@ -40,6 +40,7 @@ def test_should_map_fields_correctly():
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"fans",
"reporter_type", "reporter_type",
] ]

View File

@ -74,6 +74,7 @@ def test_django_objecttype_map_correct_fields():
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"fans",
"reporter_type", "reporter_type",
] ]
assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]

View File

@ -37,12 +37,21 @@ def camelize(data):
return data return data
def get_reverse_fields(model, local_field_names): def _get_model_ancestry(model):
model_ancestry = [model] model_ancestry = [model]
for base in model.__bases__: for base in model.__bases__:
if is_valid_django_model(base): if is_valid_django_model(base) and getattr(base, "_meta", False):
model_ancestry.append(base) model_ancestry.append(base)
return model_ancestry
def get_reverse_fields(model, local_field_names):
"""
Searches through the model's ancestry and gets reverse relationships the models
Yields a tuple of (field.name, field)
"""
model_ancestry = _get_model_ancestry(model)
for _model in model_ancestry: for _model in model_ancestry:
for name, attr in _model.__dict__.items(): for name, attr in _model.__dict__.items():
@ -58,6 +67,24 @@ def get_reverse_fields(model, local_field_names):
yield (name, related) yield (name, related)
def get_local_fields(model):
"""
Searches through the model's ancestry and gets the fields on the models
Returns a dict of {field.name: field}
"""
model_ancestry = _get_model_ancestry(model)
local_fields_dict = {}
for _model in model_ancestry:
for field in sorted(
list(_model._meta.fields) + list(_model._meta.local_many_to_many)
):
if field.name not in local_fields_dict:
local_fields_dict[field.name] = field
return list(local_fields_dict.items())
def maybe_queryset(value): def maybe_queryset(value):
if isinstance(value, Manager): if isinstance(value, Manager):
value = value.get_queryset() value = value.get_queryset()
@ -65,17 +92,14 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
local_fields = [ """
(field.name, field) Gets all the fields and relationships on the Django model and its ancestry.
for field in sorted( Prioritizes local fields and relationships over the reverse relationships of the same name
list(model._meta.fields) + list(model._meta.local_many_to_many) Returns a tuple of (field.name, field)
) """
] local_fields = get_local_fields(model)
local_field_names = {field[0] for field in local_fields}
# Make sure we don't duplicate local fields with "reverse" version
local_field_names = [field[0] for field in local_fields]
reverse_fields = get_reverse_fields(model, local_field_names) reverse_fields = get_reverse_fields(model, local_field_names)
all_fields = local_fields + list(reverse_fields) all_fields = local_fields + list(reverse_fields)
return all_fields return all_fields