mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-06-09 08:03:19 +03:00
Add a sanity check to avoid running into unresolved related models.
This commit is contained in:
parent
bf5b77ce6d
commit
2332382b51
|
@ -1 +1,22 @@
|
||||||
# Just to keep things like ./manage.py test happy
|
import inspect
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model(obj):
|
||||||
|
"""
|
||||||
|
Resolve supplied `obj` to a Django model class.
|
||||||
|
|
||||||
|
`obj` must be a Django model class, or a string representation
|
||||||
|
of one.
|
||||||
|
|
||||||
|
String representations should have the format:
|
||||||
|
'appname.ModelName'
|
||||||
|
"""
|
||||||
|
if type(obj) == str and len(obj.split('.')) == 2:
|
||||||
|
app_name, model_name = obj.split('.')
|
||||||
|
return models.get_model(app_name, model_name)
|
||||||
|
elif inspect.isclass(obj) and issubclass(obj, models.Model):
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
raise ValueError("{0} is not a valid Django model".format(obj))
|
||||||
|
|
|
@ -20,6 +20,7 @@ from django.db import models
|
||||||
from django.forms import widgets
|
from django.forms import widgets
|
||||||
from django.utils.datastructures import SortedDict
|
from django.utils.datastructures import SortedDict
|
||||||
from rest_framework.compat import get_concrete_model, six
|
from rest_framework.compat import get_concrete_model, six
|
||||||
|
from rest_framework.models import resolve_model
|
||||||
|
|
||||||
# Note: We do the following so that users of the framework can use this style:
|
# Note: We do the following so that users of the framework can use this style:
|
||||||
#
|
#
|
||||||
|
@ -656,7 +657,7 @@ class ModelSerializer(Serializer):
|
||||||
if model_field.rel:
|
if model_field.rel:
|
||||||
to_many = isinstance(model_field,
|
to_many = isinstance(model_field,
|
||||||
models.fields.related.ManyToManyField)
|
models.fields.related.ManyToManyField)
|
||||||
related_model = model_field.rel.to
|
related_model = resolve_model(model_field.rel.to)
|
||||||
|
|
||||||
if to_many and not model_field.rel.through._meta.auto_created:
|
if to_many and not model_field.rel.through._meta.auto_created:
|
||||||
has_through_model = True
|
has_through_model = True
|
||||||
|
|
28
rest_framework/tests/test_models.py
Normal file
28
rest_framework/tests/test_models.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from django.db import models
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from rest_framework.models import resolve_model
|
||||||
|
from rest_framework.tests.models import BasicModel
|
||||||
|
|
||||||
|
|
||||||
|
class ResolveModelTests(TestCase):
|
||||||
|
"""
|
||||||
|
`resolve_model` should return a Django model class given the
|
||||||
|
provided argument is a Django model class itself, or a properly
|
||||||
|
formatted string representation of one.
|
||||||
|
"""
|
||||||
|
def test_resolve_django_model(self):
|
||||||
|
resolved_model = resolve_model(BasicModel)
|
||||||
|
self.assertEqual(resolved_model, BasicModel)
|
||||||
|
|
||||||
|
def test_resolve_string_representation(self):
|
||||||
|
resolved_model = resolve_model('tests.BasicModel')
|
||||||
|
self.assertEqual(resolved_model, BasicModel)
|
||||||
|
|
||||||
|
def test_resolve_non_django_model(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
resolve_model(TestCase)
|
||||||
|
|
||||||
|
def test_resolve_with_improper_string_representation(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
resolve_model('BasicModel')
|
Loading…
Reference in New Issue
Block a user