Support base class relations and reverse for proxy models (#1380)

* support reverse relationship for proxy models

* support multi table inheritence

* update query test for multi table inheritance

* remove debugger

* support local many to many in model inheritance

* format and lint

---------

Co-authored-by: Firas K <3097061+firaskafri@users.noreply.github.com>
This commit is contained in:
Tom Dror 2023-07-18 13:17:45 -04:00 committed by GitHub
parent 0de35ca3b0
commit b1abebdb97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 388 additions and 28 deletions

View File

@ -46,6 +46,7 @@ class Reporter(models.Model):
a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
fans = models.ManyToManyField(Person)
reporter_type = models.IntegerField( reporter_type = models.IntegerField(
"Reporter Type", "Reporter Type",
@ -90,6 +91,16 @@ class CNNReporter(Reporter):
objects = CNNReporterManager() objects = CNNReporterManager()
class APNewsReporter(Reporter):
"""
This class only inherits from Reporter for testing multi table inheritence
similar to what you'd see in django-polymorphic
"""
alias = models.CharField(max_length=30)
objects = models.Manager()
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField(auto_now_add=True) pub_date = models.DateField(auto_now_add=True)

View File

@ -15,7 +15,16 @@ from ..compat import IntegerRangeField, MissingType
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from ..utils import DJANGO_FILTER_INSTALLED from ..utils import DJANGO_FILTER_INSTALLED
from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter from .models import (
Article,
CNNReporter,
Film,
FilmDetails,
Person,
Pet,
Reporter,
APNewsReporter,
)
def test_should_query_only_fields(): def test_should_query_only_fields():
@ -1064,6 +1073,301 @@ def test_proxy_model_support():
assert result.data == expected assert result.data == expected
def test_model_inheritance_support_reverse_relationships():
"""
This test asserts that we can query reverse relationships for all Reporters and proxied Reporters and multi table Reporters.
"""
class FilmType(DjangoObjectType):
class Meta:
model = Film
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class APNewsReporterType(DjangoObjectType):
class Meta:
model = APNewsReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
film = Film.objects.create(genre="do")
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
cnn_reporter = CNNReporter.objects.create(
first_name="Some",
last_name="Guy",
email="someguy@cnn.com",
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
ap_news_reporter = APNewsReporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
film.reporters.add(cnn_reporter, ap_news_reporter)
film.save()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
cnn_reporters = DjangoConnectionField(CNNReporterType)
ap_news_reporters = DjangoConnectionField(APNewsReporterType)
schema = graphene.Schema(query=Query)
query = """
query ProxyModelQuery {
allReporters {
edges {
node {
id
films {
id
}
}
}
}
cnnReporters {
edges {
node {
id
films {
id
}
}
}
}
apNewsReporters {
edges {
node {
id
films {
id
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": to_global_id("ReporterType", reporter.id),
"films": [],
},
},
{
"node": {
"id": to_global_id("ReporterType", cnn_reporter.id),
"films": [{"id": f"{film.id}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", ap_news_reporter.id),
"films": [{"id": f"{film.id}"}],
},
},
]
},
"cnnReporters": {
"edges": [
{
"node": {
"id": to_global_id("CNNReporterType", cnn_reporter.id),
"films": [{"id": f"{film.id}"}],
}
}
]
},
"apNewsReporters": {
"edges": [
{
"node": {
"id": to_global_id("APNewsReporterType", ap_news_reporter.id),
"films": [{"id": f"{film.id}"}],
}
}
]
},
}
result = schema.execute(query)
assert result.data == expected
def test_model_inheritance_support_local_relationships():
"""
This test asserts that we can query local relationships for all Reporters and proxied Reporters and multi table Reporters.
"""
class PersonType(DjangoObjectType):
class Meta:
model = Person
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class APNewsReporterType(DjangoObjectType):
class Meta:
model = APNewsReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
film = Film.objects.create(genre="do")
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
reporter_fan = Person.objects.create(name="Reporter Fan")
reporter.fans.add(reporter_fan)
reporter.save()
cnn_reporter = CNNReporter.objects.create(
first_name="Some",
last_name="Guy",
email="someguy@cnn.com",
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
cnn_fan = Person.objects.create(name="CNN Fan")
cnn_reporter.fans.add(cnn_fan)
cnn_reporter.save()
ap_news_reporter = APNewsReporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
ap_news_fan = Person.objects.create(name="AP News Fan")
ap_news_reporter.fans.add(ap_news_fan)
ap_news_reporter.save()
film.reporters.add(cnn_reporter, ap_news_reporter)
film.save()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
cnn_reporters = DjangoConnectionField(CNNReporterType)
ap_news_reporters = DjangoConnectionField(APNewsReporterType)
schema = graphene.Schema(query=Query)
query = """
query ProxyModelQuery {
allReporters {
edges {
node {
id
fans {
name
}
}
}
}
cnnReporters {
edges {
node {
id
fans {
name
}
}
}
}
apNewsReporters {
edges {
node {
id
fans {
name
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": to_global_id("ReporterType", reporter.id),
"fans": [{"name": f"{reporter_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
},
},
]
},
"cnnReporters": {
"edges": [
{
"node": {
"id": to_global_id("CNNReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
}
}
]
},
"apNewsReporters": {
"edges": [
{
"node": {
"id": to_global_id("APNewsReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
}
}
]
},
}
result = schema.execute(query)
assert result.data == expected
def test_should_resolve_get_queryset_connectionfields(): def test_should_resolve_get_queryset_connectionfields():
reporter_1 = Reporter.objects.create( reporter_1 = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1

View File

@ -33,17 +33,18 @@ def test_should_map_fields_correctly():
fields = "__all__" fields = "__all__"
fields = list(ReporterType2._meta.fields.keys()) fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [ assert fields[:-3] == [
"id", "id",
"first_name", "first_name",
"last_name", "last_name",
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"fans",
"reporter_type", "reporter_type",
] ]
assert sorted(fields[-2:]) == ["articles", "films"] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():

View File

@ -67,16 +67,17 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields(): def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields fields = Reporter._meta.fields
fields = list(fields.keys()) fields = list(fields.keys())
assert fields[:-2] == [ assert fields[:-3] == [
"id", "id",
"first_name", "first_name",
"last_name", "last_name",
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"fans",
"reporter_type", "reporter_type",
] ]
assert sorted(fields[-2:]) == ["articles", "films"] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_django_objecttype_with_node_have_correct_fields(): def test_django_objecttype_with_node_have_correct_fields():

View File

@ -4,8 +4,8 @@ import pytest
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from unittest.mock import patch from unittest.mock import patch
from ..utils import camelize, get_model_fields, GraphQLTestCase from ..utils import camelize, get_model_fields, get_reverse_fields, GraphQLTestCase
from .models import Film, Reporter from .models import Film, Reporter, CNNReporter, APNewsReporter
from ..utils.testing import graphql_query from ..utils.testing import graphql_query
@ -19,6 +19,18 @@ def test_get_model_fields_no_duplication():
assert len(film_fields) == len(film_name_set) assert len(film_fields) == len(film_name_set)
def test_get_reverse_fields_includes_proxied_models():
reporter_fields = get_reverse_fields(Reporter, [])
cnn_reporter_fields = get_reverse_fields(CNNReporter, [])
ap_news_reporter_fields = get_reverse_fields(APNewsReporter, [])
assert (
len(list(reporter_fields))
== len(list(cnn_reporter_fields))
== len(list(ap_news_reporter_fields))
)
def test_camelize(): def test_camelize():
assert camelize({}) == {} assert camelize({}) == {}
assert camelize("value_a") == "value_a" assert camelize("value_a") == "value_a"

View File

@ -37,8 +37,24 @@ def camelize(data):
return data return data
def _get_model_ancestry(model):
model_ancestry = [model]
for base in model.__bases__:
if is_valid_django_model(base) and getattr(base, "_meta", False):
model_ancestry.append(base)
return model_ancestry
def get_reverse_fields(model, local_field_names): def get_reverse_fields(model, local_field_names):
for name, attr in model.__dict__.items(): """
Searches through the model's ancestry and gets reverse relationships the models
Yields a tuple of (field.name, field)
"""
model_ancestry = _get_model_ancestry(model)
for _model in model_ancestry:
for name, attr in _model.__dict__.items():
# Don't duplicate any local fields # Don't duplicate any local fields
if name in local_field_names: if name in local_field_names:
continue continue
@ -51,6 +67,24 @@ def get_reverse_fields(model, local_field_names):
yield (name, related) yield (name, related)
def get_local_fields(model):
"""
Searches through the model's ancestry and gets the fields on the models
Returns a dict of {field.name: field}
"""
model_ancestry = _get_model_ancestry(model)
local_fields_dict = {}
for _model in model_ancestry:
for field in sorted(
list(_model._meta.fields) + list(_model._meta.local_many_to_many)
):
if field.name not in local_fields_dict:
local_fields_dict[field.name] = field
return list(local_fields_dict.items())
def maybe_queryset(value): def maybe_queryset(value):
if isinstance(value, Manager): if isinstance(value, Manager):
value = value.get_queryset() value = value.get_queryset()
@ -58,17 +92,14 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
local_fields = [ """
(field.name, field) Gets all the fields and relationships on the Django model and its ancestry.
for field in sorted( Prioritizes local fields and relationships over the reverse relationships of the same name
list(model._meta.fields) + list(model._meta.local_many_to_many) Returns a tuple of (field.name, field)
) """
] local_fields = get_local_fields(model)
local_field_names = {field[0] for field in local_fields}
# Make sure we don't duplicate local fields with "reverse" version
local_field_names = [field[0] for field in local_fields]
reverse_fields = get_reverse_fields(model, local_field_names) reverse_fields = get_reverse_fields(model, local_field_names)
all_fields = local_fields + list(reverse_fields) all_fields = local_fields + list(reverse_fields)
return all_fields return all_fields