diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe999aec..4322f2134 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -317,17 +317,17 @@ class ModelSerializerOptions(object): self.depth = getattr(meta, 'depth', 0) -def lookup_class(mapping, obj): +def lookup_class(mapping, instance): """ Takes a dictionary with classes as keys, and an object. Traverses the object's inheritance hierarchy in method resolution order, and returns the first matching value - from the dictionary or None. + from the dictionary or raises a KeyError if nothing matches. """ - return next( - (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), - None - ) + for cls in inspect.getmro(instance.__class__): + if cls in mapping: + return mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) class ModelSerializer(Serializer): @@ -341,6 +341,7 @@ class ModelSerializer(Serializer): models.DateTimeField: DateTimeField, models.DecimalField: DecimalField, models.EmailField: EmailField, + models.Field: ModelField, models.FileField: FileField, models.FloatField: FloatField, models.ImageField: ImageField, @@ -484,6 +485,7 @@ class ModelSerializer(Serializer): """ Creates a default instance of a basic non-relational field. """ + serializer_cls = lookup_class(self.field_mapping, model_field) kwargs = {} validator_kwarg = model_field.validators @@ -602,11 +604,10 @@ class ModelSerializer(Serializer): if validator_kwarg: kwargs['validators'] = validator_kwarg - cls = lookup_class(self.field_mapping, model_field) - if cls is None: - cls = ModelField + if issubclass(serializer_cls, ModelField): kwargs['model_field'] = model_field - return cls(**kwargs) + + return serializer_cls(**kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions):