Clean up field classes

This commit is contained in:
Tom Christie 2012-10-04 13:28:14 +01:00
parent d89d6887d2
commit 3a06dde884
2 changed files with 177 additions and 155 deletions

View File

@ -26,21 +26,88 @@ def is_simple_callable(obj):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
empty = '' empty = ''
def __init__(self, source=None, readonly=False, required=None, def __init__(self, source=None):
validators=[], error_messages=None):
self.parent = None self.parent = None
self.creation_counter = Field.creation_counter self.creation_counter = Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
self.source = source self.source = source
def initialize(self, parent):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
def field_from_native(self, data, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
return
def field_to_native(self, obj, field_name):
"""
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
if getattr(self, 'type_name', None):
return {'type': self.type_name}
return {}
class WritableField(Field):
"""
Base for read/write fields.
"""
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
def __init__(self, source=None, readonly=False, required=None,
validators=[], error_messages=None):
super(WritableField, self).__init__(source=source)
self.readonly = readonly self.readonly = readonly
if required is None: if required is None:
self.required = not(readonly) self.required = not(readonly)
@ -56,19 +123,6 @@ class Field(object):
self.validators = self.default_validators + validators self.validators = self.default_validators + validators
def initialize(self, parent, model_field=None):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
if model_field:
self.model_field = model_field
def validate(self, value): def validate(self, value):
if value in validators.EMPTY_VALUES and self.required: if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
@ -117,96 +171,75 @@ class Field(object):
""" """
Reverts a simple representation back to the field's value. Reverts a simple representation back to the field's value.
""" """
if hasattr(self, 'model_field'): return value
class ModelField(WritableField):
"""
A generic field that can be used against an arbirtrary model field.
"""
def __init__(self, *args, **kwargs):
try: try:
return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value) self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
super(ModelField, self).__init__(*args, **kwargs)
def from_native(self, value):
try:
rel = self.model_field.rel
except: except:
return self.model_field.to_python(value) return self.model_field.to_python(value)
return value return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
""" value = self.model_field._get_val_from_obj(obj)
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
self.obj = obj # Need to hang onto this in the case of model fields
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value): if is_protected_type(value):
return value return value
elif hasattr(self, 'model_field'):
return self.model_field.value_to_string(self.obj) return self.model_field.value_to_string(self.obj)
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self): def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
try:
return { return {
"type": self.model_field.get_internal_type() "type": self.model_field.get_internal_type()
} }
except AttributeError:
return {} ##### Relational fields #####
class HyperlinkedIdentityField(Field): class RelatedField(WritableField):
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
view_name = self.parent.opts.view_name
view_kwargs = {'pk': obj.pk}
return reverse(view_name, kwargs=view_kwargs, request=request)
class RelatedField(Field):
""" """
A base class for model related fields or related managers. Base class for related model fields.
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None) self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs) super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): return self.to_native(value)
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def attributes(self): def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[(self.source or field_name) + '_id'] = self.from_native(value)
class ManyRelatedField(RelatedField):
"""
Base class for related model managers.
"""
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into):
try: try:
return { value = data.getlist(self.source or field_name)
"rel": self.model_field.rel.__class__.__name__, except:
"to": smart_unicode(self.model_field.rel.to._meta) value = data.get(self.source or field_name)
} else:
except AttributeError: if value == ['']:
return {} value = []
into[field_name] = [self.from_native(item) for item in value]
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
@ -215,20 +248,11 @@ class PrimaryKeyRelatedField(RelatedField):
""" """
def to_native(self, pk): def to_native(self, pk):
"""
You can subclass this method to provide different serialization
behavior based on the pk.
"""
return pk return pk
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
# This is only implemented for performance reasons
#
# We could leave the default `RelatedField.field_to_native()` in place,
# and inside just implement `to_native()` as `return obj.pk`
#
# That would involve an extra database lookup.
try: try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name) pk = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
# RelatedObject (reverse relationship) # RelatedObject (reverse relationship)
@ -237,18 +261,17 @@ class PrimaryKeyRelatedField(RelatedField):
# Forward relationship # Forward relationship
return self.to_native(pk) return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[field_name + '_id'] = self.from_native(value)
class ManyPrimaryKeyRelatedField(ManyRelatedField):
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
""" """
Serializes a to-many related field or related manager to a pk value. Serializes a to-many related field or related manager to a pk value.
""" """
def to_native(self, pk):
return pk
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
try: try:
# Prefer obj.serializable_value for performance reasons
queryset = obj.serializable_value(self.source or field_name) queryset = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
# RelatedManager (reverse relationship) # RelatedManager (reverse relationship)
@ -257,40 +280,25 @@ class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
# Forward relationship # Forward relationship
return [self.to_native(item.pk) for item in queryset.all()] return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
value = data.getlist(field_name)
except:
value = data.get(field_name)
else:
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class HyperlinkedIdentityField(Field):
class NaturalKeyRelatedField(RelatedField):
""" """
Serializes a model related field or related manager to a natural key value. A field that represents the model's identity using a hyperlink.
""" """
is_natural_key = True # XML renderer handles these differently def __init__(self, *args, **kwargs):
pass
def to_native(self, obj): def field_to_native(self, obj, field_name):
if hasattr(obj, 'natural_key'): request = self.context.get('request', None)
return obj.natural_key() view_name = self.parent.opts.view_name
return obj view_kwargs = {'pk': obj.pk}
return reverse(view_name, kwargs=view_kwargs, request=request)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[self.model_field.attname] = self.from_native(value)
def from_native(self, value):
# TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
manager = self.model_field.rel.to._default_manager
manager = manager.db_manager(DEFAULT_DB_ALIAS)
return manager.get_by_natural_key(*value).pk
class BooleanField(Field): ##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."), 'invalid': _(u"'%s' value must be either True or False."),
} }
@ -307,7 +315,9 @@ class BooleanField(Field):
raise ValidationError(self.error_messages['invalid'] % value) raise ValidationError(self.error_messages['invalid'] % value)
class CharField(Field): class CharField(WritableField):
type_name = 'CharField'
def __init__(self, max_length=None, min_length=None, *args, **kwargs): def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length self.max_length, self.min_length = max_length, min_length
super(CharField, self).__init__(*args, **kwargs) super(CharField, self).__init__(*args, **kwargs)
@ -323,6 +333,8 @@ class CharField(Field):
class EmailField(CharField): class EmailField(CharField):
type_name = 'EmailField'
default_error_messages = { default_error_messages = {
'invalid': _('Enter a valid e-mail address.'), 'invalid': _('Enter a valid e-mail address.'),
} }
@ -339,7 +351,9 @@ class EmailField(CharField):
return result return result
class DateField(Field): class DateField(WritableField):
type_name = 'DateField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be " 'invalid': _(u"'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."), u"in YYYY-MM-DD format."),
@ -373,7 +387,9 @@ class DateField(Field):
raise ValidationError(msg) raise ValidationError(msg)
class DateTimeField(Field): class DateTimeField(WritableField):
type_name = 'DateTimeField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in " 'invalid': _(u"'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
@ -424,7 +440,9 @@ class DateTimeField(Field):
raise ValidationError(msg) raise ValidationError(msg)
class IntegerField(Field): class IntegerField(WritableField):
type_name = 'IntegerField'
default_error_messages = { default_error_messages = {
'invalid': _('Enter a whole number.'), 'invalid': _('Enter a whole number.'),
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
@ -450,7 +468,9 @@ class IntegerField(Field):
return value return value
class FloatField(Field): class FloatField(WritableField):
type_name = 'FloatField'
default_error_messages = { default_error_messages = {
'invalid': _("'%s' value must be a float."), 'invalid': _("'%s' value must be a float."),
} }

View File

@ -123,16 +123,8 @@ class BaseSerializer(Field):
# Get the explicitly declared fields # Get the explicitly declared fields
for key, field in self.fields.items(): for key, field in self.fields.items():
ret[key] = field ret[key] = field
# Determine if the declared field corrosponds to a model field.
try:
if key == 'pk':
model_field = obj._meta.pk
else:
model_field = obj._meta.get_field_by_name(key)[0]
except:
model_field = None
# Set up the field # Set up the field
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
# Add in the default fields # Add in the default fields
fields = self.default_fields(serialize, obj, data, nested) fields = self.default_fields(serialize, obj, data, nested)
@ -157,12 +149,12 @@ class BaseSerializer(Field):
##### #####
# Field methods - used when the serializer class is itself used as a field. # Field methods - used when the serializer class is itself used as a field.
def initialize(self, parent, model_field=None): def initialize(self, parent):
""" """
Same behaviour as usual Field, except that we need to keep track Same behaviour as usual Field, except that we need to keep track
of state so that we can deal with handling maximum depth and recursion. of state so that we can deal with handling maximum depth and recursion.
""" """
super(BaseSerializer, self).initialize(parent, model_field) super(BaseSerializer, self).initialize(parent)
self.stack = parent.stack[:] self.stack = parent.stack[:]
if parent.opts.nested and not isinstance(parent.opts.nested, bool): if parent.opts.nested and not isinstance(parent.opts.nested, bool):
self.opts.nested = parent.opts.nested - 1 self.opts.nested = parent.opts.nested - 1
@ -296,12 +288,22 @@ class ModelSerializerOptions(SerializerOptions):
self.model = getattr(meta, 'model', None) self.model = getattr(meta, 'model', None)
class ModelSerializer(RelatedField, Serializer): class ModelSerializer(Serializer):
""" """
A serializer that deals with model instances and querysets. A serializer that deals with model instances and querysets.
""" """
_options_class = ModelSerializerOptions _options_class = ModelSerializerOptions
def field_to_native(self, obj, field_name):
"""
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def default_fields(self, serialize, obj=None, data=None, nested=False): def default_fields(self, serialize, obj=None, data=None, nested=False):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
@ -330,7 +332,7 @@ class ModelSerializer(RelatedField, Serializer):
field = self.get_field(model_field) field = self.get_field(model_field)
if field: if field:
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
ret[model_field.name] = field ret[model_field.name] = field
return ret return ret
@ -339,7 +341,7 @@ class ModelSerializer(RelatedField, Serializer):
""" """
Returns a default instance of the pk field. Returns a default instance of the pk field.
""" """
return Field(readonly=True) return Field()
def get_nested_field(self, model_field): def get_nested_field(self, model_field):
""" """
@ -373,7 +375,7 @@ class ModelSerializer(RelatedField, Serializer):
try: try:
return field_mapping[model_field.__class__]() return field_mapping[model_field.__class__]()
except KeyError: except KeyError:
return Field() return ModelField(model_field=model_field)
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
""" """