Add a sanity check to avoid running into unresolved related models.

This commit is contained in:
Dustin Farris 2014-01-12 20:28:19 -05:00
parent bf5b77ce6d
commit 2332382b51
3 changed files with 52 additions and 2 deletions

View File

@ -1 +1,22 @@
# Just to keep things like ./manage.py test happy
import inspect
from django.db import models
def resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class, or a string representation
of one.
String representations should have the format:
'appname.ModelName'
"""
if type(obj) == str and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a valid Django model".format(obj))

View File

@ -20,6 +20,7 @@ from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model, six
from rest_framework.models import resolve_model
# Note: We do the following so that users of the framework can use this style:
#
@ -656,7 +657,7 @@ class ModelSerializer(Serializer):
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
related_model = model_field.rel.to
related_model = resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True

View File

@ -0,0 +1,28 @@
from django.db import models
from django.test import TestCase
from rest_framework.models import resolve_model
from rest_framework.tests.models import BasicModel
class ResolveModelTests(TestCase):
"""
`resolve_model` should return a Django model class given the
provided argument is a Django model class itself, or a properly
formatted string representation of one.
"""
def test_resolve_django_model(self):
resolved_model = resolve_model(BasicModel)
self.assertEqual(resolved_model, BasicModel)
def test_resolve_string_representation(self):
resolved_model = resolve_model('tests.BasicModel')
self.assertEqual(resolved_model, BasicModel)
def test_resolve_non_django_model(self):
with self.assertRaises(ValueError):
resolve_model(TestCase)
def test_resolve_with_improper_string_representation(self):
with self.assertRaises(ValueError):
resolve_model('BasicModel')