Merge pull request #2086 from beck/doug/blow-up-with-bad-models

Ensure _resolve_model does not return None
This commit is contained in:
Tom Christie 2014-11-28 15:31:51 +00:00
commit 6fbd23ab34
2 changed files with 38 additions and 1 deletions

View File

@ -43,7 +43,11 @@ def _resolve_model(obj):
""" """
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.') app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name) resolved_model = models.get_model(app_name, model_name)
if not resolved_model:
raise ValueError("Django did not return a model for "
"{0}.{1}".format(app_name, model_name))
return resolved_model
elif inspect.isclass(obj) and issubclass(obj, models.Model): elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj return obj
raise ValueError("{0} is not a Django model".format(obj)) raise ValueError("{0} is not a Django model".format(obj))

View File

@ -7,6 +7,8 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
import rest_framework.utils.model_meta
class Root(APIView): class Root(APIView):
pass pass
@ -130,3 +132,34 @@ class ResolveModelTests(TestCase):
def test_resolve_improper_string_representation(self): def test_resolve_improper_string_representation(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_resolve_model('BasicModel') _resolve_model('BasicModel')
class ResolveModelWithPatchedDjangoTests(TestCase):
"""
Test coverage for when Django's `get_model` returns `None`.
Under certain circumstances Django may return `None` with `get_model`:
http://git.io/get-model-source
It usually happens with circular imports so it is important that DRF
excepts early, otherwise fault happens downstream and is much more
difficult to debug.
"""
def setUp(self):
"""Monkeypatch get_model."""
self.get_model = rest_framework.utils.model_meta.models.get_model
def get_model(app_label, model_name):
return None
rest_framework.utils.model_meta.models.get_model = get_model
def tearDown(self):
"""Revert monkeypatching."""
rest_framework.utils.model_meta.models.get_model = self.get_model
def test_blows_up_if_model_does_not_resolve(self):
with self.assertRaises(ValueError):
_resolve_model('tests.BasicModel')