Tidy up lookup_class

This commit is contained in:
Tom Christie 2014-09-11 20:22:32 +01:00
parent 3318f75a71
commit ab40780dc2

View File

@ -317,17 +317,17 @@ class ModelSerializerOptions(object):
self.depth = getattr(meta, 'depth', 0) self.depth = getattr(meta, 'depth', 0)
def lookup_class(mapping, obj): def lookup_class(mapping, instance):
""" """
Takes a dictionary with classes as keys, and an object. Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value resolution order, and returns the first matching value
from the dictionary or None. from the dictionary or raises a KeyError if nothing matches.
""" """
return next( for cls in inspect.getmro(instance.__class__):
(mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), if cls in mapping:
None return mapping[cls]
) raise KeyError('Class %s not found in lookup.', cls.__name__)
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
@ -341,6 +341,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField, models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField, models.DecimalField: DecimalField,
models.EmailField: EmailField, models.EmailField: EmailField,
models.Field: ModelField,
models.FileField: FileField, models.FileField: FileField,
models.FloatField: FloatField, models.FloatField: FloatField,
models.ImageField: ImageField, models.ImageField: ImageField,
@ -484,6 +485,7 @@ class ModelSerializer(Serializer):
""" """
Creates a default instance of a basic non-relational field. Creates a default instance of a basic non-relational field.
""" """
serializer_cls = lookup_class(self.field_mapping, model_field)
kwargs = {} kwargs = {}
validator_kwarg = model_field.validators validator_kwarg = model_field.validators
@ -602,11 +604,10 @@ class ModelSerializer(Serializer):
if validator_kwarg: if validator_kwarg:
kwargs['validators'] = validator_kwarg kwargs['validators'] = validator_kwarg
cls = lookup_class(self.field_mapping, model_field) if issubclass(serializer_cls, ModelField):
if cls is None:
cls = ModelField
kwargs['model_field'] = model_field kwargs['model_field'] = model_field
return cls(**kwargs)
return serializer_cls(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):