diff --git a/rest_framework/compat.py b/rest_framework/compat.py index af3a4b007..c56604862 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -6,9 +6,13 @@ versions of Django/Python, and compatibility wrappers around optional packages. # flake8: noqa from __future__ import unicode_literals +import inspect + import django +from django.apps import apps from django.conf import settings -from django.db import connection, transaction +from django.core.exceptions import ImproperlyConfigured +from django.db import connection, models, transaction from django.template import Context, RequestContext, Template from django.utils import six from django.views.generic import View @@ -88,6 +92,36 @@ def get_remote_field(field, **kwargs): return field.remote_field +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + resolved_model = apps.get_model(app_name, model_name) + if resolved_model is None: + msg = "Django did not return a model for {0}.{1}" + raise ImproperlyConfigured(msg.format(app_name, model_name)) + return resolved_model + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_related_model(field): + if django.VERSION < (1, 9): + return _resolve_model(field.rel.to) + return field.remote_field.model + + # contrib.postgres only supported from 1.8 onwards. try: from django.contrib.postgres import fields as postgres_fields diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 975729a47..94aa46e72 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -5,15 +5,9 @@ relationships and their associated metadata. Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ -import inspect from collections import OrderedDict, namedtuple -from django.apps import apps -from django.core.exceptions import ImproperlyConfigured -from django.db import models -from django.utils import six - -from rest_framework.compat import get_remote_field +from rest_framework.compat import get_related_model, get_remote_field FieldInfo = namedtuple('FieldResult', [ 'pk', # Model field instance @@ -33,30 +27,6 @@ RelationInfo = namedtuple('RelationInfo', [ ]) -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. - - `obj` must be a Django model class itself, or a string - representation of one. Useful in situations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - resolved_model = apps.get_model(app_name, model_name) - if resolved_model is None: - msg = "Django did not return a model for {0}.{1}" - raise ImproperlyConfigured(msg.format(app_name, model_name)) - return resolved_model - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - raise ValueError("{0} is not a Django model".format(obj)) - - def get_field_info(model): """ Given a model class, returns a `FieldInfo` instance, which is a @@ -82,7 +52,7 @@ def _get_pk(opts): while rel and rel.parent_link: # If model is a child via multi-table inheritance, use parent's pk. - pk = rel.to._meta.pk + pk = get_related_model(pk)._meta.pk rel = get_remote_field(pk) return pk @@ -108,7 +78,7 @@ def _get_forward_relationships(opts): for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: forward_relations[field.name] = RelationInfo( model_field=field, - related_model=_resolve_model(get_remote_field(field).to), + related_model=get_related_model(field), to_many=False, to_field=_get_to_field(field), has_through_model=False @@ -118,7 +88,7 @@ def _get_forward_relationships(opts): for field in [field for field in opts.many_to_many if field.serialize]: forward_relations[field.name] = RelationInfo( model_field=field, - related_model=_resolve_model(get_remote_field(field).to), + related_model=get_related_model(field), to_many=True, # manytomany do not have to_fields to_field=None, diff --git a/tests/test_utils.py b/tests/test_utils.py index cd6a6af9a..23b0013b2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,8 +6,8 @@ from django.test import TestCase, override_settings from django.utils import six import rest_framework.utils.model_meta +from rest_framework.compat import _resolve_model from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework.utils.model_meta import _resolve_model from rest_framework.views import APIView from tests.models import BasicModel @@ -166,16 +166,16 @@ class ResolveModelWithPatchedDjangoTests(TestCase): def setUp(self): """Monkeypatch get_model.""" - self.get_model = rest_framework.utils.model_meta.apps.get_model + self.get_model = rest_framework.compat.apps.get_model def get_model(app_label, model_name): return None - rest_framework.utils.model_meta.apps.get_model = get_model + rest_framework.compat.apps.get_model = get_model def tearDown(self): """Revert monkeypatching.""" - rest_framework.utils.model_meta.apps.get_model = self.get_model + rest_framework.compat.apps.get_model = self.get_model def test_blows_up_if_model_does_not_resolve(self): with self.assertRaises(ImproperlyConfigured):