get_base_field() refactor

This commit is contained in:
Tom Christie 2014-09-18 11:20:56 +01:00
parent c0155fd9dc
commit 5b7e4af0d6
9 changed files with 379 additions and 390 deletions

View File

@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value dictionary[keys[-1]] = value
def field_name_to_label(field_name):
return field_name.replace('_', ' ').capitalize()
class SkipField(Exception): class SkipField(Exception):
pass pass
@ -162,7 +158,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name. # `self.label` should deafult to being based on the field name.
if self.label is None: if self.label is None:
self.label = field_name_to_label(self.field_name) self.label = field_name.replace('_', ' ').capitalize()
# self.source should default to being the same as the field name. # self.source should default to being the same as the field name.
if self.source is None: if self.source is None:

View File

@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
} }
def __init__(self, view_name, **kwargs): def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
self.view_name = view_name self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
URL of relationships to other objects. URL of relationships to other objects.
""" """
def __init__(self, view_name, **kwargs): def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True kwargs['read_only'] = True
kwargs['source'] = '*' kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
'invalid': _('Invalid value.'), 'invalid': _('Invalid value.'),
} }
def __init__(self, slug_field, **kwargs): def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs) super(SlugRelatedField, self).__init__(**kwargs)

View File

@ -10,17 +10,19 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.text import capfirst
from collections import namedtuple from collections import namedtuple
from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import (
get_url_kwargs, get_field_kwargs,
get_relation_kwargs, get_nested_relation_kwargs,
lookup_class
)
import copy import copy
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
@ -126,7 +128,7 @@ class SerializerMetaclass(type):
""" """
@classmethod @classmethod
def _get_fields(cls, bases, attrs): def _get_declared_fields(cls, bases, attrs):
fields = [(field_name, attrs.pop(field_name)) fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(attrs.items()) for field_name, obj in list(attrs.items())
if isinstance(obj, Field)] if isinstance(obj, Field)]
@ -136,25 +138,18 @@ class SerializerMetaclass(type):
# fields. Note that we loop over the bases in *reverse*. This is necessary # fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields. # in order to maintain the correct order of fields.
for base in bases[::-1]: for base in bases[::-1]:
if hasattr(base, 'base_fields'): if hasattr(base, '_declared_fields'):
fields = list(base.base_fields.items()) + fields fields = list(base._declared_fields.items()) + fields
return SortedDict(fields) return SortedDict(fields)
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
attrs['base_fields'] = cls._get_fields(bases, attrs) attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(SerializerMetaclass) @six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer): class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs):
if kwargs.pop('many', False):
kwargs['child'] = cls()
return ListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {}) self.context = kwargs.pop('context', {})
kwargs.pop('partial', None) kwargs.pop('partial', None)
@ -165,14 +160,22 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances. # Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer # This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class. # instance without affecting every other serializer class.
self.fields = self.get_fields() self.fields = self._get_base_fields()
# Setup all the child fields, to provide them with the current context. # Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field.bind(field_name, self, self) field.bind(field_name, self, self)
def get_fields(self): def __new__(cls, *args, **kwargs):
return copy.deepcopy(self.base_fields) # We override this method in order to automagically create
# `ListSerializer` classes instead when `many=True` is set.
if kwargs.pop('many', False):
kwargs['child'] = cls()
return ListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls, *args, **kwargs)
def _get_base_fields(self):
return copy.deepcopy(self._declared_fields)
def bind(self, field_name, parent, root): def bind(self, field_name, parent, root):
# If the serializer is used as a field then when it becomes bound # If the serializer is used as a field then when it becomes bound
@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer):
return representation.list_repr(self, indent=1) return representation.list_repr(self, indent=1)
class ModelSerializerOptions(object):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
self.model = getattr(meta, 'model')
self.fields = getattr(meta, 'fields', ())
self.depth = getattr(meta, 'depth', 0)
def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
for cls in inspect.getmro(instance.__class__):
if cls in mapping:
return mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
return capfirst(model_field.verbose_name) != field_name_to_label(field_name)
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
field_mapping = { _field_mapping = {
models.AutoField: IntegerField, models.AutoField: IntegerField,
models.BigIntegerField: IntegerField, models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
@ -368,16 +340,10 @@ class ModelSerializer(Serializer):
models.TimeField: TimeField, models.TimeField: TimeField,
models.URLField: URLField, models.URLField: URLField,
} }
nested_class = None # We fill this in at the end of this module. _related_class = PrimaryKeyRelatedField
_options_class = ModelSerializerOptions
def __init__(self, *args, **kwargs):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
def create(self, attrs): def create(self, attrs):
ModelClass = self.opts.model ModelClass = self.Meta.model
return ModelClass.objects.create(**attrs) return ModelClass.objects.create(**attrs)
def update(self, obj, attrs): def update(self, obj, attrs):
@ -385,319 +351,97 @@ class ModelSerializer(Serializer):
setattr(obj, attr, value) setattr(obj, attr, value)
obj.save() obj.save()
def get_fields(self): def _get_base_fields(self):
# Get the explicitly declared fields. declared_fields = copy.deepcopy(self._declared_fields)
fields = copy.deepcopy(self.base_fields)
# Add in the default fields.
for key, val in self.get_default_fields().items():
if key not in fields:
fields[key] = val
# If `fields` is set on the `Meta` class,
# then use only those fields, and in that order.
if self.opts.fields:
fields = SortedDict([
(key, fields[key]) for key in self.opts.fields
])
return fields
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
info = model_meta.get_field_info(self.opts.model)
ret = SortedDict() ret = SortedDict()
model = getattr(self.Meta, 'model')
fields = getattr(self.Meta, 'fields', None)
depth = getattr(self.Meta, 'depth', 0)
# URL field # Retrieve metadata about fields & relationships on the model class.
serializer_url_field = self.get_url_field() info = model_meta.get_field_info(model)
if serializer_url_field:
ret[api_settings.URL_FIELD_NAME] = serializer_url_field
# Primary key field # Use the default set of fields if none is supplied explicitly.
field_name = info.pk.name if fields is None:
serializer_pk_field = self.get_pk_field(field_name, info.pk) fields = self._get_default_field_names(declared_fields, info)
if serializer_pk_field:
ret[field_name] = serializer_pk_field
# Regular fields for field_name in fields:
for field_name, field in info.fields.items(): if field_name in declared_fields:
ret[field_name] = self.get_field(field_name, field) # Field is explicitly declared on the class, use that.
ret[field_name] = declared_fields[field_name]
continue
# Forward relations elif field_name == api_settings.URL_FIELD_NAME:
for field_name, relation_info in info.forward_relations.items(): # Create the URL field.
if self.opts.depth: field_cls = HyperlinkedIdentityField
ret[field_name] = self.get_nested_field(field_name, *relation_info) kwargs = get_url_kwargs(model)
elif field_name in info.fields_and_pk:
# Create regular model fields.
model_field = info.fields_and_pk[field_name]
field_cls = lookup_class(self._field_mapping, model_field)
kwargs = get_field_kwargs(field_name, model_field)
if 'choices' in kwargs:
# Fields with choices get coerced into `ChoiceField`
# instead of using their regular typed field.
field_cls = ChoiceField
if not issubclass(field_cls, ModelField):
# `model_field` is only valid for the fallback case of
# `ModelField`, which is used when no other typed field
# matched to the model field.
kwargs.pop('model_field', None)
elif field_name in info.relations:
# Create forward and reverse relationships.
relation_info = info.relations[field_name]
if depth:
field_cls = self._get_nested_class(depth, relation_info)
kwargs = get_nested_relation_kwargs(relation_info)
else: else:
ret[field_name] = self.get_related_field(field_name, *relation_info) field_cls = self._related_class
kwargs = get_relation_kwargs(field_name, relation_info)
# `view_name` is only valid for hyperlinked relationships.
if not issubclass(field_cls, HyperlinkedRelatedField):
kwargs.pop('view_name', None)
# Reverse relations
for accessor_name, relation_info in info.reverse_relations.items():
if accessor_name in self.opts.fields:
if self.opts.depth:
ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info)
else: else:
ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) assert False, 'Field name `%s` is not valid.' % field_name
ret[field_name] = field_cls(**kwargs)
return ret return ret
def get_url_field(self): def _get_default_field_names(self, declared_fields, model_info):
return None return (
[model_info.pk.name] +
list(declared_fields.keys()) +
list(model_info.fields.keys()) +
list(model_info.forward_relations.keys())
)
def get_pk_field(self, field_name, model_field): def _get_nested_class(self, nested_depth, relation_info):
""" class NestedSerializer(ModelSerializer):
Returns a default instance of the pk field.
"""
return self.get_field(field_name, model_field)
def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(self.nested_class):
class Meta: class Meta:
model = related_model model = relation_info.related
depth = self.opts.depth - 1 depth = nested_depth
return NestedSerializer
kwargs = {'read_only': True}
if to_many:
kwargs['many'] = True
return NestedModelSerializer(**kwargs)
def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
kwargs = {
'queryset': related_model._default_manager,
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
serializer_cls = lookup_class(self.field_mapping, model_field)
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.flatchoices:
# If this model field contains choices, then use a ChoiceField,
# rather than the standard serializer field for this type.
# Note that we return this prior to setting any validation type
# keyword arguments, as those are not valid initializers.
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
# Also note that Django<1.6 uses `default=False` for
# models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
if issubclass(serializer_cls, ModelField):
kwargs['model_field'] = model_field
return serializer_cls(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
Options for HyperlinkedModelSerializer
"""
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions _related_class = HyperlinkedRelatedField
def get_url_field(self): def _get_default_field_names(self, declared_fields, model_info):
if self.opts.view_name is not None: return (
view_name = self.opts.view_name [api_settings.URL_FIELD_NAME] +
else: list(declared_fields.keys()) +
view_name = self.get_default_view_name(self.opts.model) list(model_info.fields.keys()) +
list(model_info.forward_relations.keys())
)
kwargs = { def _get_nested_class(self, nested_depth, relation_info):
'view_name': view_name class NestedSerializer(HyperlinkedModelSerializer):
} class Meta:
if self.opts.lookup_field: model = relation_info.related
kwargs['lookup_field'] = self.opts.lookup_field depth = nested_depth
return NestedSerializer
return HyperlinkedIdentityField(**kwargs)
def get_pk_field(self, field_name, model_field):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
kwargs = {
'queryset': related_model._default_manager,
'view_name': self.get_default_view_name(related_model),
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return HyperlinkedRelatedField(**kwargs)
def get_default_view_name(self, model):
"""
Return the view name to use for related models.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
ModelSerializer.nested_class = ModelSerializer
HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer

View File

@ -0,0 +1,215 @@
"""
Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivelent serializer fields.
"""
from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
import inspect
def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
for cls in inspect.getmro(instance.__class__):
if cls in mapping:
return mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
return capfirst(model_field.verbose_name) != default_label
def get_detail_view_name(model):
"""
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.flatchoices:
# If this model field contains choices, then return now,
# any further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
# Also note that Django<1.6 uses `default=False` for
# models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
return kwargs
def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, has_through_model = relation_info
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
if relation_info.to_many:
kwargs['many'] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}

View File

@ -1,7 +1,9 @@
""" """
Helper functions for returning the field information that is associated Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse with a model class. This includes returning all the forward and reverse
relationships and their associated metadata. relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import namedtuple from collections import namedtuple
from django.db import models from django.db import models
@ -9,8 +11,22 @@ from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
import inspect import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related',
'to_many',
'has_through_model'
])
def _resolve_model(obj): def _resolve_model(obj):
@ -55,7 +71,7 @@ def get_field_info(model):
forward_relations = SortedDict() forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]: for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
field=field, model_field=field,
related=_resolve_model(field.rel.to), related=_resolve_model(field.rel.to),
to_many=False, to_many=False,
has_through_model=False has_through_model=False
@ -64,7 +80,7 @@ def get_field_info(model):
# Deal with forward many-to-many relationships. # Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]: for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
field=field, model_field=field,
related=_resolve_model(field.rel.to), related=_resolve_model(field.rel.to),
to_many=True, to_many=True,
has_through_model=( has_through_model=(
@ -77,7 +93,7 @@ def get_field_info(model):
for relation in opts.get_all_related_objects(): for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
field=None, model_field=None,
related=relation.model, related=relation.model,
to_many=relation.field.rel.multiple, to_many=relation.field.rel.multiple,
has_through_model=False has_through_model=False
@ -87,7 +103,7 @@ def get_field_info(model):
for relation in opts.get_all_related_many_to_many_objects(): for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
field=None, model_field=None,
related=relation.model, related=relation.model,
to_many=True, to_many=True,
has_through_model=( has_through_model=(
@ -96,4 +112,18 @@ def get_field_info(model):
) )
) )
return FieldInfo(pk, fields, forward_relations, reverse_relations) # Shortcut that merges both regular fields and the pk,
# for simplifying regular field lookup.
fields_and_pk = SortedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
# Shortcut that merges both forward and reverse relationships
relations = SortedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
def foobar(): def foobar():
@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True, target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source') related_name='nullable_source')
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel

View File

@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
foreign_key = NestedModelSerializer(read_only=True): foreign_key = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True): one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True): many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True): through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail') url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedModelSerializer(read_only=True): foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True): one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True): many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True): through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)

View File

@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.conf.urls import patterns, url, include from django.conf.urls import patterns, url, include
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from tests.models import BasicModel, BasicModelSerializer from tests.models import BasicModel
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework import generics from rest_framework import generics
from rest_framework import routers from rest_framework import routers
from rest_framework import serializers
from rest_framework import status from rest_framework import status
from rest_framework.renderers import ( from rest_framework.renderers import (
BaseRenderer, BaseRenderer,
@ -17,6 +18,12 @@ from rest_framework import viewsets
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class MockPickleRenderer(BaseRenderer): class MockPickleRenderer(BaseRenderer):
media_type = 'application/pickle' media_type = 'application/pickle'

View File

@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
def setUp(self): def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer): class NoteSerializer(serializers.HyperlinkedModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
class Meta: class Meta:
model = RouterTestModel model = RouterTestModel
lookup_field = 'uuid'
fields = ('url', 'uuid', 'text') fields = ('url', 'uuid', 'text')
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
serializer_class = NoteSerializer serializer_class = NoteSerializer
lookup_field = 'uuid' lookup_field = 'uuid'
RouterTestModel.objects.create(uuid='123', text='foo bar')
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes', NoteViewSet)
@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
) )
RouterTestModel.objects.create(uuid='123', text='foo bar')
def test_custom_lookup_field_route(self): def test_custom_lookup_field_route(self):
detail_route = self.router.urls[-1] detail_route = self.router.urls[-1]
detail_url_pattern = detail_route.regex.pattern detail_url_pattern = detail_route.regex.pattern