feat!: check django model has a default ordering when used in a relay connection

This commit is contained in:
Thomas Leonard 2024-01-29 16:13:01 +01:00
parent b85177cebf
commit d12ea31f88
6 changed files with 69 additions and 16 deletions

View File

@ -24,6 +24,9 @@ class Faction(models.Model):
class Ship(models.Model): class Ship(models.Model):
class Meta:
ordering = ["pk"]
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name="ships") faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name="ships")

View File

@ -101,13 +101,19 @@ class DjangoConnectionField(ConnectionField):
non_null = True non_null = True
assert issubclass( assert issubclass(
_type, DjangoObjectType _type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types" ), "DjangoConnectionField only accepts DjangoObjectType types as underlying type"
assert _type._meta.connection, "The type {} doesn't have a connection".format( assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__ _type.__name__
) )
connection_type = _type._meta.connection connection_type = _type._meta.connection
if non_null: if non_null:
return NonNull(connection_type) return NonNull(connection_type)
# Since Relay Connections require to have a predictible ordering for pagination,
# check on init that the Django model provided has a default ordering declared.
model = connection_type._meta.node._meta.model
assert (
len(getattr(model._meta, "ordering", [])) > 0
), f"Django model {model._meta.app_label}.{model.__name__} has to have a default ordering to be used in a Connection."
return connection_type return connection_type
@property @property

View File

@ -26,6 +26,9 @@ else:
class Event(models.Model): class Event(models.Model):
class Meta:
ordering = ["pk"]
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50)) tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField()) tag_ids = ArrayField(models.IntegerField())

View File

@ -5,6 +5,9 @@ CHOICES = ((1, "this"), (2, _("that")))
class Person(models.Model): class Person(models.Model):
class Meta:
ordering = ["pk"]
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
parent = models.ForeignKey( parent = models.ForeignKey(
"self", on_delete=models.CASCADE, null=True, blank=True, related_name="children" "self", on_delete=models.CASCADE, null=True, blank=True, related_name="children"
@ -12,6 +15,9 @@ class Person(models.Model):
class Pet(models.Model): class Pet(models.Model):
class Meta:
ordering = ["pk"]
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
age = models.PositiveIntegerField() age = models.PositiveIntegerField()
owner = models.ForeignKey( owner = models.ForeignKey(
@ -31,6 +37,9 @@ class FilmDetails(models.Model):
class Film(models.Model): class Film(models.Model):
class Meta:
ordering = ["pk"]
genre = models.CharField( genre = models.CharField(
max_length=2, max_length=2,
help_text="Genre", help_text="Genre",
@ -46,6 +55,9 @@ class DoeReporterManager(models.Manager):
class Reporter(models.Model): class Reporter(models.Model):
class Meta:
ordering = ["pk"]
first_name = models.CharField(max_length=30) first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30)
email = models.EmailField() email = models.EmailField()

View File

@ -2,11 +2,12 @@ import datetime
import re import re
import pytest import pytest
from django.db.models import Count, Prefetch from django.db.models import Count, Model, Prefetch
from graphene import List, NonNull, ObjectType, Schema, String from graphene import List, NonNull, ObjectType, Schema, String
from graphene.relay import Node
from ..fields import DjangoListField from ..fields import DjangoConnectionField, DjangoListField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import ( from .models import (
Article as ArticleModel, Article as ArticleModel,
@ -716,3 +717,34 @@ class TestDjangoListField:
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"', r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
captured.captured_queries[1]["sql"], captured.captured_queries[1]["sql"],
) )
class TestDjangoConnectionField:
def test_model_ordering_assertion(self):
class Chaos(Model):
class Meta:
app_label = "test"
class ChaosType(DjangoObjectType):
class Meta:
model = Chaos
interfaces = (Node,)
class Query(ObjectType):
chaos = DjangoConnectionField(ChaosType)
with pytest.raises(
TypeError,
match=r"Django model test\.Chaos has to have a default ordering to be used in a Connection\.",
):
Schema(query=Query)
def test_only_django_object_types(self):
class Query(ObjectType):
something = DjangoConnectionField(String)
with pytest.raises(
TypeError,
match="DjangoConnectionField only accepts DjangoObjectType types as underlying type",
):
Schema(query=Query)

View File

@ -1,3 +1,4 @@
import warnings
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from textwrap import dedent from textwrap import dedent
from unittest.mock import patch from unittest.mock import patch
@ -399,7 +400,7 @@ def test_django_objecttype_fields_exist_on_model():
with pytest.warns( with pytest.warns(
UserWarning, UserWarning,
match=r"Field name .* matches an attribute on Django model .* but it's not a model field", match=r"Field name .* matches an attribute on Django model .* but it's not a model field",
) as record: ):
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
@ -407,7 +408,8 @@ def test_django_objecttype_fields_exist_on_model():
fields = ["first_name", "some_method", "email"] fields = ["first_name", "some_method", "email"]
# Don't warn if selecting a custom field # Don't warn if selecting a custom field
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
custom_field = String() custom_field = String()
@ -416,8 +418,6 @@ def test_django_objecttype_fields_exist_on_model():
model = ReporterModel model = ReporterModel
fields = ["first_name", "custom_field", "email"] fields = ["first_name", "custom_field", "email"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_exclude_fields_exist_on_model(): def test_django_objecttype_exclude_fields_exist_on_model():
@ -445,15 +445,14 @@ def test_django_objecttype_exclude_fields_exist_on_model():
exclude = ["custom_field"] exclude = ["custom_field"]
# Don't warn on exclude fields # Don't warn on exclude fields
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter4(DjangoObjectType): class Reporter4(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email", "first_name"] exclude = ["email", "first_name"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_neither_fields_nor_exclude(): def test_django_objecttype_neither_fields_nor_exclude():
@ -467,24 +466,22 @@ def test_django_objecttype_neither_fields_nor_exclude():
class Meta: class Meta:
model = ReporterModel model = ReporterModel
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
fields = ["email"] fields = ["email"]
assert len(record) == 0 with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(None) as record:
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email"] exclude = ["email"]
assert len(record) == 0
def custom_enum_name(field): def custom_enum_name(field):
return f"CustomEnum{field.name.title()}" return f"CustomEnum{field.name.title()}"