mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
get_base_field() refactor
This commit is contained in:
parent
c0155fd9dc
commit
5b7e4af0d6
|
@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
|
|||
dictionary[keys[-1]] = value
|
||||
|
||||
|
||||
def field_name_to_label(field_name):
|
||||
return field_name.replace('_', ' ').capitalize()
|
||||
|
||||
|
||||
class SkipField(Exception):
|
||||
pass
|
||||
|
||||
|
@ -162,7 +158,7 @@ class Field(object):
|
|||
|
||||
# `self.label` should deafult to being based on the field name.
|
||||
if self.label is None:
|
||||
self.label = field_name_to_label(self.field_name)
|
||||
self.label = field_name.replace('_', ' ').capitalize()
|
||||
|
||||
# self.source should default to being the same as the field name.
|
||||
if self.source is None:
|
||||
|
|
|
@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
|
|||
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
|
||||
}
|
||||
|
||||
def __init__(self, view_name, **kwargs):
|
||||
def __init__(self, view_name=None, **kwargs):
|
||||
assert view_name is not None, 'The `view_name` argument is required.'
|
||||
self.view_name = view_name
|
||||
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
|
||||
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
|
||||
|
@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
|
|||
URL of relationships to other objects.
|
||||
"""
|
||||
|
||||
def __init__(self, view_name, **kwargs):
|
||||
def __init__(self, view_name=None, **kwargs):
|
||||
assert view_name is not None, 'The `view_name` argument is required.'
|
||||
kwargs['read_only'] = True
|
||||
kwargs['source'] = '*'
|
||||
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
|
||||
|
@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
|
|||
'invalid': _('Invalid value.'),
|
||||
}
|
||||
|
||||
def __init__(self, slug_field, **kwargs):
|
||||
def __init__(self, slug_field=None, **kwargs):
|
||||
assert slug_field is not None, 'The `slug_field` argument is required.'
|
||||
self.slug_field = slug_field
|
||||
super(SlugRelatedField, self).__init__(**kwargs)
|
||||
|
||||
|
|
|
@ -10,17 +10,19 @@ python primitives.
|
|||
2. The process of marshalling between python primitives and request and
|
||||
response content is handled by parsers and renderers.
|
||||
"""
|
||||
from django.core import validators
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.utils import six
|
||||
from django.utils.datastructures import SortedDict
|
||||
from django.utils.text import capfirst
|
||||
from collections import namedtuple
|
||||
from rest_framework.compat import clean_manytomany_helptext
|
||||
from rest_framework.fields import empty, set_value, Field, SkipField
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html, model_meta, representation
|
||||
from rest_framework.utils.field_mapping import (
|
||||
get_url_kwargs, get_field_kwargs,
|
||||
get_relation_kwargs, get_nested_relation_kwargs,
|
||||
lookup_class
|
||||
)
|
||||
import copy
|
||||
|
||||
# Note: We do the following so that users of the framework can use this style:
|
||||
|
@ -126,7 +128,7 @@ class SerializerMetaclass(type):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, bases, attrs):
|
||||
def _get_declared_fields(cls, bases, attrs):
|
||||
fields = [(field_name, attrs.pop(field_name))
|
||||
for field_name, obj in list(attrs.items())
|
||||
if isinstance(obj, Field)]
|
||||
|
@ -136,25 +138,18 @@ class SerializerMetaclass(type):
|
|||
# fields. Note that we loop over the bases in *reverse*. This is necessary
|
||||
# in order to maintain the correct order of fields.
|
||||
for base in bases[::-1]:
|
||||
if hasattr(base, 'base_fields'):
|
||||
fields = list(base.base_fields.items()) + fields
|
||||
if hasattr(base, '_declared_fields'):
|
||||
fields = list(base._declared_fields.items()) + fields
|
||||
|
||||
return SortedDict(fields)
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
attrs['base_fields'] = cls._get_fields(bases, attrs)
|
||||
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
|
||||
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
@six.add_metaclass(SerializerMetaclass)
|
||||
class Serializer(BaseSerializer):
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if kwargs.pop('many', False):
|
||||
kwargs['child'] = cls()
|
||||
return ListSerializer(*args, **kwargs)
|
||||
return super(Serializer, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.context = kwargs.pop('context', {})
|
||||
kwargs.pop('partial', None)
|
||||
|
@ -165,14 +160,22 @@ class Serializer(BaseSerializer):
|
|||
# Every new serializer is created with a clone of the field instances.
|
||||
# This allows users to dynamically modify the fields on a serializer
|
||||
# instance without affecting every other serializer class.
|
||||
self.fields = self.get_fields()
|
||||
self.fields = self._get_base_fields()
|
||||
|
||||
# Setup all the child fields, to provide them with the current context.
|
||||
for field_name, field in self.fields.items():
|
||||
field.bind(field_name, self, self)
|
||||
|
||||
def get_fields(self):
|
||||
return copy.deepcopy(self.base_fields)
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We override this method in order to automagically create
|
||||
# `ListSerializer` classes instead when `many=True` is set.
|
||||
if kwargs.pop('many', False):
|
||||
kwargs['child'] = cls()
|
||||
return ListSerializer(*args, **kwargs)
|
||||
return super(Serializer, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
def _get_base_fields(self):
|
||||
return copy.deepcopy(self._declared_fields)
|
||||
|
||||
def bind(self, field_name, parent, root):
|
||||
# If the serializer is used as a field then when it becomes bound
|
||||
|
@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer):
|
|||
return representation.list_repr(self, indent=1)
|
||||
|
||||
|
||||
class ModelSerializerOptions(object):
|
||||
"""
|
||||
Meta class options for ModelSerializer
|
||||
"""
|
||||
def __init__(self, meta):
|
||||
self.model = getattr(meta, 'model')
|
||||
self.fields = getattr(meta, 'fields', ())
|
||||
self.depth = getattr(meta, 'depth', 0)
|
||||
|
||||
|
||||
def lookup_class(mapping, instance):
|
||||
"""
|
||||
Takes a dictionary with classes as keys, and an object.
|
||||
Traverses the object's inheritance hierarchy in method
|
||||
resolution order, and returns the first matching value
|
||||
from the dictionary or raises a KeyError if nothing matches.
|
||||
"""
|
||||
for cls in inspect.getmro(instance.__class__):
|
||||
if cls in mapping:
|
||||
return mapping[cls]
|
||||
raise KeyError('Class %s not found in lookup.', cls.__name__)
|
||||
|
||||
|
||||
def needs_label(model_field, field_name):
|
||||
"""
|
||||
Returns `True` if the label based on the model's verbose name
|
||||
is not equal to the default label it would have based on it's field name.
|
||||
"""
|
||||
return capfirst(model_field.verbose_name) != field_name_to_label(field_name)
|
||||
|
||||
|
||||
class ModelSerializer(Serializer):
|
||||
field_mapping = {
|
||||
_field_mapping = {
|
||||
models.AutoField: IntegerField,
|
||||
models.BigIntegerField: IntegerField,
|
||||
models.BooleanField: BooleanField,
|
||||
|
@ -368,16 +340,10 @@ class ModelSerializer(Serializer):
|
|||
models.TimeField: TimeField,
|
||||
models.URLField: URLField,
|
||||
}
|
||||
nested_class = None # We fill this in at the end of this module.
|
||||
|
||||
_options_class = ModelSerializerOptions
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.opts = self._options_class(self.Meta)
|
||||
super(ModelSerializer, self).__init__(*args, **kwargs)
|
||||
_related_class = PrimaryKeyRelatedField
|
||||
|
||||
def create(self, attrs):
|
||||
ModelClass = self.opts.model
|
||||
ModelClass = self.Meta.model
|
||||
return ModelClass.objects.create(**attrs)
|
||||
|
||||
def update(self, obj, attrs):
|
||||
|
@ -385,319 +351,97 @@ class ModelSerializer(Serializer):
|
|||
setattr(obj, attr, value)
|
||||
obj.save()
|
||||
|
||||
def get_fields(self):
|
||||
# Get the explicitly declared fields.
|
||||
fields = copy.deepcopy(self.base_fields)
|
||||
def _get_base_fields(self):
|
||||
declared_fields = copy.deepcopy(self._declared_fields)
|
||||
|
||||
# Add in the default fields.
|
||||
for key, val in self.get_default_fields().items():
|
||||
if key not in fields:
|
||||
fields[key] = val
|
||||
|
||||
# If `fields` is set on the `Meta` class,
|
||||
# then use only those fields, and in that order.
|
||||
if self.opts.fields:
|
||||
fields = SortedDict([
|
||||
(key, fields[key]) for key in self.opts.fields
|
||||
])
|
||||
|
||||
return fields
|
||||
|
||||
def get_default_fields(self):
|
||||
"""
|
||||
Return all the fields that should be serialized for the model.
|
||||
"""
|
||||
info = model_meta.get_field_info(self.opts.model)
|
||||
ret = SortedDict()
|
||||
model = getattr(self.Meta, 'model')
|
||||
fields = getattr(self.Meta, 'fields', None)
|
||||
depth = getattr(self.Meta, 'depth', 0)
|
||||
|
||||
# URL field
|
||||
serializer_url_field = self.get_url_field()
|
||||
if serializer_url_field:
|
||||
ret[api_settings.URL_FIELD_NAME] = serializer_url_field
|
||||
# Retrieve metadata about fields & relationships on the model class.
|
||||
info = model_meta.get_field_info(model)
|
||||
|
||||
# Primary key field
|
||||
field_name = info.pk.name
|
||||
serializer_pk_field = self.get_pk_field(field_name, info.pk)
|
||||
if serializer_pk_field:
|
||||
ret[field_name] = serializer_pk_field
|
||||
# Use the default set of fields if none is supplied explicitly.
|
||||
if fields is None:
|
||||
fields = self._get_default_field_names(declared_fields, info)
|
||||
|
||||
# Regular fields
|
||||
for field_name, field in info.fields.items():
|
||||
ret[field_name] = self.get_field(field_name, field)
|
||||
for field_name in fields:
|
||||
if field_name in declared_fields:
|
||||
# Field is explicitly declared on the class, use that.
|
||||
ret[field_name] = declared_fields[field_name]
|
||||
continue
|
||||
|
||||
# Forward relations
|
||||
for field_name, relation_info in info.forward_relations.items():
|
||||
if self.opts.depth:
|
||||
ret[field_name] = self.get_nested_field(field_name, *relation_info)
|
||||
else:
|
||||
ret[field_name] = self.get_related_field(field_name, *relation_info)
|
||||
elif field_name == api_settings.URL_FIELD_NAME:
|
||||
# Create the URL field.
|
||||
field_cls = HyperlinkedIdentityField
|
||||
kwargs = get_url_kwargs(model)
|
||||
|
||||
# Reverse relations
|
||||
for accessor_name, relation_info in info.reverse_relations.items():
|
||||
if accessor_name in self.opts.fields:
|
||||
if self.opts.depth:
|
||||
ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info)
|
||||
elif field_name in info.fields_and_pk:
|
||||
# Create regular model fields.
|
||||
model_field = info.fields_and_pk[field_name]
|
||||
field_cls = lookup_class(self._field_mapping, model_field)
|
||||
kwargs = get_field_kwargs(field_name, model_field)
|
||||
if 'choices' in kwargs:
|
||||
# Fields with choices get coerced into `ChoiceField`
|
||||
# instead of using their regular typed field.
|
||||
field_cls = ChoiceField
|
||||
if not issubclass(field_cls, ModelField):
|
||||
# `model_field` is only valid for the fallback case of
|
||||
# `ModelField`, which is used when no other typed field
|
||||
# matched to the model field.
|
||||
kwargs.pop('model_field', None)
|
||||
|
||||
elif field_name in info.relations:
|
||||
# Create forward and reverse relationships.
|
||||
relation_info = info.relations[field_name]
|
||||
if depth:
|
||||
field_cls = self._get_nested_class(depth, relation_info)
|
||||
kwargs = get_nested_relation_kwargs(relation_info)
|
||||
else:
|
||||
ret[accessor_name] = self.get_related_field(accessor_name, *relation_info)
|
||||
field_cls = self._related_class
|
||||
kwargs = get_relation_kwargs(field_name, relation_info)
|
||||
# `view_name` is only valid for hyperlinked relationships.
|
||||
if not issubclass(field_cls, HyperlinkedRelatedField):
|
||||
kwargs.pop('view_name', None)
|
||||
|
||||
else:
|
||||
assert False, 'Field name `%s` is not valid.' % field_name
|
||||
|
||||
ret[field_name] = field_cls(**kwargs)
|
||||
|
||||
return ret
|
||||
|
||||
def get_url_field(self):
|
||||
return None
|
||||
def _get_default_field_names(self, declared_fields, model_info):
|
||||
return (
|
||||
[model_info.pk.name] +
|
||||
list(declared_fields.keys()) +
|
||||
list(model_info.fields.keys()) +
|
||||
list(model_info.forward_relations.keys())
|
||||
)
|
||||
|
||||
def get_pk_field(self, field_name, model_field):
|
||||
"""
|
||||
Returns a default instance of the pk field.
|
||||
"""
|
||||
return self.get_field(field_name, model_field)
|
||||
|
||||
def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a nested relational field.
|
||||
|
||||
Note that model_field will be `None` for reverse relationships.
|
||||
"""
|
||||
class NestedModelSerializer(self.nested_class):
|
||||
def _get_nested_class(self, nested_depth, relation_info):
|
||||
class NestedSerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = related_model
|
||||
depth = self.opts.depth - 1
|
||||
|
||||
kwargs = {'read_only': True}
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
return NestedModelSerializer(**kwargs)
|
||||
|
||||
def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
|
||||
Note that model_field will be `None` for reverse relationships.
|
||||
"""
|
||||
kwargs = {
|
||||
'queryset': related_model._default_manager,
|
||||
}
|
||||
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
if has_through_model:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
if model_field:
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
if model_field.verbose_name and needs_label(model_field, field_name):
|
||||
kwargs['label'] = capfirst(model_field.verbose_name)
|
||||
if not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
help_text = clean_manytomany_helptext(model_field.help_text)
|
||||
if help_text:
|
||||
kwargs['help_text'] = help_text
|
||||
|
||||
return PrimaryKeyRelatedField(**kwargs)
|
||||
|
||||
def get_field(self, field_name, model_field):
|
||||
"""
|
||||
Creates a default instance of a basic non-relational field.
|
||||
"""
|
||||
serializer_cls = lookup_class(self.field_mapping, model_field)
|
||||
kwargs = {}
|
||||
validator_kwarg = model_field.validators
|
||||
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
|
||||
if model_field.verbose_name and needs_label(model_field, field_name):
|
||||
kwargs['label'] = capfirst(model_field.verbose_name)
|
||||
|
||||
if model_field.help_text:
|
||||
kwargs['help_text'] = model_field.help_text
|
||||
|
||||
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
# Read only implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if model_field.has_default():
|
||||
kwargs['default'] = model_field.get_default()
|
||||
# Having a default implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if model_field.flatchoices:
|
||||
# If this model field contains choices, then use a ChoiceField,
|
||||
# rather than the standard serializer field for this type.
|
||||
# Note that we return this prior to setting any validation type
|
||||
# keyword arguments, as those are not valid initializers.
|
||||
kwargs['choices'] = model_field.flatchoices
|
||||
return ChoiceField(**kwargs)
|
||||
|
||||
# Ensure that max_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_length = getattr(model_field, 'max_length', None)
|
||||
if max_length is not None:
|
||||
kwargs['max_length'] = max_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that min_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_length = getattr(model_field, 'min_length', None)
|
||||
if min_length is not None:
|
||||
kwargs['min_length'] = min_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MaxValueValidator)
|
||||
), None)
|
||||
if max_value is not None:
|
||||
kwargs['max_value'] = max_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxValueValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MinValueValidator)
|
||||
), None)
|
||||
if min_value is not None:
|
||||
kwargs['min_value'] = min_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinValueValidator)
|
||||
]
|
||||
|
||||
# URLField does not need to include the URLValidator argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.URLField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.URLValidator)
|
||||
]
|
||||
|
||||
# EmailField does not need to include the validate_email argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.EmailField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_email
|
||||
]
|
||||
|
||||
# SlugField do not need to include the 'validate_slug' argument,
|
||||
if isinstance(model_field, models.SlugField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_slug
|
||||
]
|
||||
|
||||
max_digits = getattr(model_field, 'max_digits', None)
|
||||
if max_digits is not None:
|
||||
kwargs['max_digits'] = max_digits
|
||||
|
||||
decimal_places = getattr(model_field, 'decimal_places', None)
|
||||
if decimal_places is not None:
|
||||
kwargs['decimal_places'] = decimal_places
|
||||
|
||||
if isinstance(model_field, models.BooleanField):
|
||||
# models.BooleanField has `blank=True`, but *is* actually
|
||||
# required *unless* a default is provided.
|
||||
# Also note that Django<1.6 uses `default=False` for
|
||||
# models.BooleanField, but Django>=1.6 uses `default=None`.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if validator_kwarg:
|
||||
kwargs['validators'] = validator_kwarg
|
||||
|
||||
if issubclass(serializer_cls, ModelField):
|
||||
kwargs['model_field'] = model_field
|
||||
|
||||
return serializer_cls(**kwargs)
|
||||
|
||||
|
||||
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
|
||||
"""
|
||||
Options for HyperlinkedModelSerializer
|
||||
"""
|
||||
def __init__(self, meta):
|
||||
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
|
||||
self.view_name = getattr(meta, 'view_name', None)
|
||||
self.lookup_field = getattr(meta, 'lookup_field', None)
|
||||
model = relation_info.related
|
||||
depth = nested_depth
|
||||
return NestedSerializer
|
||||
|
||||
|
||||
class HyperlinkedModelSerializer(ModelSerializer):
|
||||
_options_class = HyperlinkedModelSerializerOptions
|
||||
_related_class = HyperlinkedRelatedField
|
||||
|
||||
def get_url_field(self):
|
||||
if self.opts.view_name is not None:
|
||||
view_name = self.opts.view_name
|
||||
else:
|
||||
view_name = self.get_default_view_name(self.opts.model)
|
||||
def _get_default_field_names(self, declared_fields, model_info):
|
||||
return (
|
||||
[api_settings.URL_FIELD_NAME] +
|
||||
list(declared_fields.keys()) +
|
||||
list(model_info.fields.keys()) +
|
||||
list(model_info.forward_relations.keys())
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
'view_name': view_name
|
||||
}
|
||||
if self.opts.lookup_field:
|
||||
kwargs['lookup_field'] = self.opts.lookup_field
|
||||
|
||||
return HyperlinkedIdentityField(**kwargs)
|
||||
|
||||
def get_pk_field(self, field_name, model_field):
|
||||
if self.opts.fields and model_field.name in self.opts.fields:
|
||||
return self.get_field(model_field)
|
||||
|
||||
def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
"""
|
||||
kwargs = {
|
||||
'queryset': related_model._default_manager,
|
||||
'view_name': self.get_default_view_name(related_model),
|
||||
}
|
||||
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
if has_through_model:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
if model_field:
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
if model_field.verbose_name and needs_label(model_field, field_name):
|
||||
kwargs['label'] = capfirst(model_field.verbose_name)
|
||||
if not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
help_text = clean_manytomany_helptext(model_field.help_text)
|
||||
if help_text:
|
||||
kwargs['help_text'] = help_text
|
||||
|
||||
return HyperlinkedRelatedField(**kwargs)
|
||||
|
||||
def get_default_view_name(self, model):
|
||||
"""
|
||||
Return the view name to use for related models.
|
||||
"""
|
||||
return '%(model_name)s-detail' % {
|
||||
'app_label': model._meta.app_label,
|
||||
'model_name': model._meta.object_name.lower()
|
||||
}
|
||||
|
||||
|
||||
ModelSerializer.nested_class = ModelSerializer
|
||||
HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer
|
||||
def _get_nested_class(self, nested_depth, relation_info):
|
||||
class NestedSerializer(HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = relation_info.related
|
||||
depth = nested_depth
|
||||
return NestedSerializer
|
||||
|
|
215
rest_framework/utils/field_mapping.py
Normal file
215
rest_framework/utils/field_mapping.py
Normal file
|
@ -0,0 +1,215 @@
|
|||
"""
|
||||
Helper functions for mapping model fields to a dictionary of default
|
||||
keyword arguments that should be used for their equivelent serializer fields.
|
||||
"""
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
from django.utils.text import capfirst
|
||||
from rest_framework.compat import clean_manytomany_helptext
|
||||
import inspect
|
||||
|
||||
|
||||
def lookup_class(mapping, instance):
|
||||
"""
|
||||
Takes a dictionary with classes as keys, and an object.
|
||||
Traverses the object's inheritance hierarchy in method
|
||||
resolution order, and returns the first matching value
|
||||
from the dictionary or raises a KeyError if nothing matches.
|
||||
"""
|
||||
for cls in inspect.getmro(instance.__class__):
|
||||
if cls in mapping:
|
||||
return mapping[cls]
|
||||
raise KeyError('Class %s not found in lookup.', cls.__name__)
|
||||
|
||||
|
||||
def needs_label(model_field, field_name):
|
||||
"""
|
||||
Returns `True` if the label based on the model's verbose name
|
||||
is not equal to the default label it would have based on it's field name.
|
||||
"""
|
||||
default_label = field_name.replace('_', ' ').capitalize()
|
||||
return capfirst(model_field.verbose_name) != default_label
|
||||
|
||||
|
||||
def get_detail_view_name(model):
|
||||
"""
|
||||
Given a model class, return the view name to use for URL relationships
|
||||
that refer to instances of the model.
|
||||
"""
|
||||
return '%(model_name)s-detail' % {
|
||||
'app_label': model._meta.app_label,
|
||||
'model_name': model._meta.object_name.lower()
|
||||
}
|
||||
|
||||
|
||||
def get_field_kwargs(field_name, model_field):
|
||||
"""
|
||||
Creates a default instance of a basic non-relational field.
|
||||
"""
|
||||
kwargs = {}
|
||||
validator_kwarg = model_field.validators
|
||||
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
|
||||
if model_field.verbose_name and needs_label(model_field, field_name):
|
||||
kwargs['label'] = capfirst(model_field.verbose_name)
|
||||
|
||||
if model_field.help_text:
|
||||
kwargs['help_text'] = model_field.help_text
|
||||
|
||||
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
# Read only implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if model_field.has_default():
|
||||
kwargs['default'] = model_field.get_default()
|
||||
# Having a default implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if model_field.flatchoices:
|
||||
# If this model field contains choices, then return now,
|
||||
# any further keyword arguments are not valid.
|
||||
kwargs['choices'] = model_field.flatchoices
|
||||
return kwargs
|
||||
|
||||
# Ensure that max_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_length = getattr(model_field, 'max_length', None)
|
||||
if max_length is not None:
|
||||
kwargs['max_length'] = max_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that min_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_length = getattr(model_field, 'min_length', None)
|
||||
if min_length is not None:
|
||||
kwargs['min_length'] = min_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MaxValueValidator)
|
||||
), None)
|
||||
if max_value is not None:
|
||||
kwargs['max_value'] = max_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxValueValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MinValueValidator)
|
||||
), None)
|
||||
if min_value is not None:
|
||||
kwargs['min_value'] = min_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinValueValidator)
|
||||
]
|
||||
|
||||
# URLField does not need to include the URLValidator argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.URLField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.URLValidator)
|
||||
]
|
||||
|
||||
# EmailField does not need to include the validate_email argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.EmailField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_email
|
||||
]
|
||||
|
||||
# SlugField do not need to include the 'validate_slug' argument,
|
||||
if isinstance(model_field, models.SlugField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_slug
|
||||
]
|
||||
|
||||
max_digits = getattr(model_field, 'max_digits', None)
|
||||
if max_digits is not None:
|
||||
kwargs['max_digits'] = max_digits
|
||||
|
||||
decimal_places = getattr(model_field, 'decimal_places', None)
|
||||
if decimal_places is not None:
|
||||
kwargs['decimal_places'] = decimal_places
|
||||
|
||||
if isinstance(model_field, models.BooleanField):
|
||||
# models.BooleanField has `blank=True`, but *is* actually
|
||||
# required *unless* a default is provided.
|
||||
# Also note that Django<1.6 uses `default=False` for
|
||||
# models.BooleanField, but Django>=1.6 uses `default=None`.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if validator_kwarg:
|
||||
kwargs['validators'] = validator_kwarg
|
||||
|
||||
# The following will only be used by ModelField classes.
|
||||
# Gets removed for everything else.
|
||||
kwargs['model_field'] = model_field
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_relation_kwargs(field_name, relation_info):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
"""
|
||||
model_field, related_model, to_many, has_through_model = relation_info
|
||||
kwargs = {
|
||||
'queryset': related_model._default_manager,
|
||||
'view_name': get_detail_view_name(related_model)
|
||||
}
|
||||
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
if has_through_model:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
if model_field:
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
if model_field.verbose_name and needs_label(model_field, field_name):
|
||||
kwargs['label'] = capfirst(model_field.verbose_name)
|
||||
if not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
help_text = clean_manytomany_helptext(model_field.help_text)
|
||||
if help_text:
|
||||
kwargs['help_text'] = help_text
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_nested_relation_kwargs(relation_info):
|
||||
kwargs = {'read_only': True}
|
||||
if relation_info.to_many:
|
||||
kwargs['many'] = True
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_url_kwargs(model_field):
|
||||
return {
|
||||
'view_name': get_detail_view_name(model_field)
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
"""
|
||||
Helper functions for returning the field information that is associated
|
||||
Helper function for returning the field information that is associated
|
||||
with a model class. This includes returning all the forward and reverse
|
||||
relationships and their associated metadata.
|
||||
|
||||
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
from django.db import models
|
||||
|
@ -9,8 +11,22 @@ from django.utils import six
|
|||
from django.utils.datastructures import SortedDict
|
||||
import inspect
|
||||
|
||||
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
|
||||
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
|
||||
|
||||
FieldInfo = namedtuple('FieldResult', [
|
||||
'pk', # Model field instance
|
||||
'fields', # Dict of field name -> model field instance
|
||||
'forward_relations', # Dict of field name -> RelationInfo
|
||||
'reverse_relations', # Dict of field name -> RelationInfo
|
||||
'fields_and_pk', # Shortcut for 'pk' + 'fields'
|
||||
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
|
||||
])
|
||||
|
||||
RelationInfo = namedtuple('RelationInfo', [
|
||||
'model_field',
|
||||
'related',
|
||||
'to_many',
|
||||
'has_through_model'
|
||||
])
|
||||
|
||||
|
||||
def _resolve_model(obj):
|
||||
|
@ -55,7 +71,7 @@ def get_field_info(model):
|
|||
forward_relations = SortedDict()
|
||||
for field in [field for field in opts.fields if field.serialize and field.rel]:
|
||||
forward_relations[field.name] = RelationInfo(
|
||||
field=field,
|
||||
model_field=field,
|
||||
related=_resolve_model(field.rel.to),
|
||||
to_many=False,
|
||||
has_through_model=False
|
||||
|
@ -64,7 +80,7 @@ def get_field_info(model):
|
|||
# Deal with forward many-to-many relationships.
|
||||
for field in [field for field in opts.many_to_many if field.serialize]:
|
||||
forward_relations[field.name] = RelationInfo(
|
||||
field=field,
|
||||
model_field=field,
|
||||
related=_resolve_model(field.rel.to),
|
||||
to_many=True,
|
||||
has_through_model=(
|
||||
|
@ -77,7 +93,7 @@ def get_field_info(model):
|
|||
for relation in opts.get_all_related_objects():
|
||||
accessor_name = relation.get_accessor_name()
|
||||
reverse_relations[accessor_name] = RelationInfo(
|
||||
field=None,
|
||||
model_field=None,
|
||||
related=relation.model,
|
||||
to_many=relation.field.rel.multiple,
|
||||
has_through_model=False
|
||||
|
@ -87,7 +103,7 @@ def get_field_info(model):
|
|||
for relation in opts.get_all_related_many_to_many_objects():
|
||||
accessor_name = relation.get_accessor_name()
|
||||
reverse_relations[accessor_name] = RelationInfo(
|
||||
field=None,
|
||||
model_field=None,
|
||||
related=relation.model,
|
||||
to_many=True,
|
||||
has_through_model=(
|
||||
|
@ -96,4 +112,18 @@ def get_field_info(model):
|
|||
)
|
||||
)
|
||||
|
||||
return FieldInfo(pk, fields, forward_relations, reverse_relations)
|
||||
# Shortcut that merges both regular fields and the pk,
|
||||
# for simplifying regular field lookup.
|
||||
fields_and_pk = SortedDict()
|
||||
fields_and_pk['pk'] = pk
|
||||
fields_and_pk[pk.name] = pk
|
||||
fields_and_pk.update(fields)
|
||||
|
||||
# Shortcut that merges both forward and reverse relationships
|
||||
|
||||
relations = SortedDict(
|
||||
list(forward_relations.items()) +
|
||||
list(reverse_relations.items())
|
||||
)
|
||||
|
||||
return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
from django.db import models
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
def foobar():
|
||||
|
@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
|
|||
name = models.CharField(max_length=100)
|
||||
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
|
||||
related_name='nullable_source')
|
||||
|
||||
|
||||
# Serializer used to test BasicModel
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
|
|
|
@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):
|
|||
expected = dedent("""
|
||||
TestSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
foreign_key = NestedModelSerializer(read_only=True):
|
||||
foreign_key = NestedSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
one_to_one = NestedModelSerializer(read_only=True):
|
||||
one_to_one = NestedSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
many_to_many = NestedModelSerializer(many=True, read_only=True):
|
||||
many_to_many = NestedSerializer(many=True, read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
through = NestedModelSerializer(many=True, read_only=True):
|
||||
through = NestedSerializer(many=True, read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
""")
|
||||
|
@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):
|
|||
expected = dedent("""
|
||||
TestSerializer():
|
||||
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
|
||||
foreign_key = NestedModelSerializer(read_only=True):
|
||||
foreign_key = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
one_to_one = NestedModelSerializer(read_only=True):
|
||||
one_to_one = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
many_to_many = NestedModelSerializer(many=True, read_only=True):
|
||||
many_to_many = NestedSerializer(many=True, read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
through = NestedModelSerializer(many=True, read_only=True):
|
||||
through = NestedSerializer(many=True, read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
""")
|
||||
|
|
|
@ -2,11 +2,12 @@ from __future__ import unicode_literals
|
|||
from django.conf.urls import patterns, url, include
|
||||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
from tests.models import BasicModel, BasicModelSerializer
|
||||
from tests.models import BasicModel
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework import generics
|
||||
from rest_framework import routers
|
||||
from rest_framework import serializers
|
||||
from rest_framework import status
|
||||
from rest_framework.renderers import (
|
||||
BaseRenderer,
|
||||
|
@ -17,6 +18,12 @@ from rest_framework import viewsets
|
|||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
# Serializer used to test BasicModel
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
|
||||
|
||||
class MockPickleRenderer(BaseRenderer):
|
||||
media_type = 'application/pickle'
|
||||
|
||||
|
|
|
@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
|
|||
|
||||
def setUp(self):
|
||||
class NoteSerializer(serializers.HyperlinkedModelSerializer):
|
||||
url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
|
||||
|
||||
class Meta:
|
||||
model = RouterTestModel
|
||||
lookup_field = 'uuid'
|
||||
fields = ('url', 'uuid', 'text')
|
||||
|
||||
class NoteViewSet(viewsets.ModelViewSet):
|
||||
|
@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
|
|||
serializer_class = NoteSerializer
|
||||
lookup_field = 'uuid'
|
||||
|
||||
RouterTestModel.objects.create(uuid='123', text='foo bar')
|
||||
|
||||
self.router = SimpleRouter()
|
||||
self.router.register(r'notes', NoteViewSet)
|
||||
|
||||
|
@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
|
|||
url(r'^', include(self.router.urls)),
|
||||
)
|
||||
|
||||
RouterTestModel.objects.create(uuid='123', text='foo bar')
|
||||
|
||||
def test_custom_lookup_field_route(self):
|
||||
detail_route = self.router.urls[-1]
|
||||
detail_url_pattern = detail_route.regex.pattern
|
||||
|
|
Loading…
Reference in New Issue
Block a user