diff --git a/rest_framework/models.py b/rest_framework/models.py index 5b53a5264..249cdd828 100644 --- a/rest_framework/models.py +++ b/rest_framework/models.py @@ -1 +1,22 @@ -# Just to keep things like ./manage.py test happy +import inspect + +from django.db import models + + +def resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class, or a string representation + of one. + + String representations should have the format: + 'appname.ModelName' + """ + if type(obj) == str and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + else: + raise ValueError("{0} is not a valid Django model".format(obj)) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b22ca5783..6b31c3043 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model, six +from rest_framework.models import resolve_model # Note: We do the following so that users of the framework can use this style: # @@ -656,7 +657,7 @@ class ModelSerializer(Serializer): if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - related_model = model_field.rel.to + related_model = resolve_model(model_field.rel.to) if to_many and not model_field.rel.through._meta.auto_created: has_through_model = True diff --git a/rest_framework/tests/test_models.py b/rest_framework/tests/test_models.py new file mode 100644 index 000000000..5e92d48ae --- /dev/null +++ b/rest_framework/tests/test_models.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.models import resolve_model +from rest_framework.tests.models import BasicModel + + +class ResolveModelTests(TestCase): + """ + `resolve_model` should return a Django model class given the + provided argument is a Django model class itself, or a properly + formatted string representation of one. + """ + def test_resolve_django_model(self): + resolved_model = resolve_model(BasicModel) + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_string_representation(self): + resolved_model = resolve_model('tests.BasicModel') + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_non_django_model(self): + with self.assertRaises(ValueError): + resolve_model(TestCase) + + def test_resolve_with_improper_string_representation(self): + with self.assertRaises(ValueError): + resolve_model('BasicModel')