Use '.remote_field' and '.model' in preference to '.rel' and '.to' when inspecting model fields.

This commit is contained in:
Tom Christie 2016-06-02 14:05:16 +01:00
parent b91aaa56cb
commit 745a28e8b1
3 changed files with 43 additions and 39 deletions

View File

@ -6,9 +6,13 @@ versions of Django/Python, and compatibility wrappers around optional packages.
# flake8: noqa # flake8: noqa
from __future__ import unicode_literals from __future__ import unicode_literals
import inspect
import django import django
from django.apps import apps
from django.conf import settings from django.conf import settings
from django.db import connection, transaction from django.core.exceptions import ImproperlyConfigured
from django.db import connection, models, transaction
from django.template import Context, RequestContext, Template from django.template import Context, RequestContext, Template
from django.utils import six from django.utils import six
from django.views.generic import View from django.views.generic import View
@ -88,6 +92,36 @@ def get_remote_field(field, **kwargs):
return field.remote_field return field.remote_field
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
resolved_model = apps.get_model(app_name, model_name)
if resolved_model is None:
msg = "Django did not return a model for {0}.{1}"
raise ImproperlyConfigured(msg.format(app_name, model_name))
return resolved_model
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
def get_related_model(field):
if django.VERSION < (1, 9):
return _resolve_model(field.rel.to)
return field.remote_field.model
# contrib.postgres only supported from 1.8 onwards. # contrib.postgres only supported from 1.8 onwards.
try: try:
from django.contrib.postgres import fields as postgres_fields from django.contrib.postgres import fields as postgres_fields

View File

@ -5,15 +5,9 @@ relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance. Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
import inspect
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from django.apps import apps from rest_framework.compat import get_related_model, get_remote_field
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.utils import six
from rest_framework.compat import get_remote_field
FieldInfo = namedtuple('FieldResult', [ FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance 'pk', # Model field instance
@ -33,30 +27,6 @@ RelationInfo = namedtuple('RelationInfo', [
]) ])
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
resolved_model = apps.get_model(app_name, model_name)
if resolved_model is None:
msg = "Django did not return a model for {0}.{1}"
raise ImproperlyConfigured(msg.format(app_name, model_name))
return resolved_model
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
def get_field_info(model): def get_field_info(model):
""" """
Given a model class, returns a `FieldInfo` instance, which is a Given a model class, returns a `FieldInfo` instance, which is a
@ -82,7 +52,7 @@ def _get_pk(opts):
while rel and rel.parent_link: while rel and rel.parent_link:
# If model is a child via multi-table inheritance, use parent's pk. # If model is a child via multi-table inheritance, use parent's pk.
pk = rel.to._meta.pk pk = get_related_model(pk)._meta.pk
rel = get_remote_field(pk) rel = get_remote_field(pk)
return pk return pk
@ -108,7 +78,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=_resolve_model(get_remote_field(field).to), related_model=get_related_model(field),
to_many=False, to_many=False,
to_field=_get_to_field(field), to_field=_get_to_field(field),
has_through_model=False has_through_model=False
@ -118,7 +88,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.many_to_many if field.serialize]: for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=_resolve_model(get_remote_field(field).to), related_model=get_related_model(field),
to_many=True, to_many=True,
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,

View File

@ -6,8 +6,8 @@ from django.test import TestCase, override_settings
from django.utils import six from django.utils import six
import rest_framework.utils.model_meta import rest_framework.utils.model_meta
from rest_framework.compat import _resolve_model
from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.utils.model_meta import _resolve_model
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
@ -166,16 +166,16 @@ class ResolveModelWithPatchedDjangoTests(TestCase):
def setUp(self): def setUp(self):
"""Monkeypatch get_model.""" """Monkeypatch get_model."""
self.get_model = rest_framework.utils.model_meta.apps.get_model self.get_model = rest_framework.compat.apps.get_model
def get_model(app_label, model_name): def get_model(app_label, model_name):
return None return None
rest_framework.utils.model_meta.apps.get_model = get_model rest_framework.compat.apps.get_model = get_model
def tearDown(self): def tearDown(self):
"""Revert monkeypatching.""" """Revert monkeypatching."""
rest_framework.utils.model_meta.apps.get_model = self.get_model rest_framework.compat.apps.get_model = self.get_model
def test_blows_up_if_model_does_not_resolve(self): def test_blows_up_if_model_does_not_resolve(self):
with self.assertRaises(ImproperlyConfigured): with self.assertRaises(ImproperlyConfigured):