From 82d4b2083292659358d5df4d03d2115576e8ae4e Mon Sep 17 00:00:00 2001 From: Timo Tuominen Date: Mon, 1 Sep 2014 12:17:36 +0300 Subject: [PATCH 1/4] Add subclass matching to serializer field mapping. --- rest_framework/serializers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f24..6d25161e2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -907,6 +907,9 @@ class ModelSerializer(Serializer): try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: + for model_field_class, serializer_field_class in self.field_mapping.items(): + if isinstance(model_field, model_field_class): + return serializer_field_class(**kwargs) return ModelField(model_field=model_field, **kwargs) def get_validation_exclusions(self, instance=None): From ae84b8b0e8a99261ea2436f77ab5238f21603c0c Mon Sep 17 00:00:00 2001 From: Timo Tuominen Date: Mon, 1 Sep 2014 15:03:39 +0300 Subject: [PATCH 2/4] Traverse the method resolution order when mapping serializer fields. --- rest_framework/serializers.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6d25161e2..f37fbf980 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -904,13 +904,11 @@ class ModelSerializer(Serializer): for attribute in attributes: kwargs.update({attribute: getattr(model_field, attribute)}) - try: - return self.field_mapping[model_field.__class__](**kwargs) - except KeyError: - for model_field_class, serializer_field_class in self.field_mapping.items(): - if isinstance(model_field, model_field_class): - return serializer_field_class(**kwargs) - return ModelField(model_field=model_field, **kwargs) + for model_field_baseclass in inspect.getmro(model_field.__class__): + serializer_field_class = self.field_mapping.get(model_field_baseclass) + if serializer_field_class: + return serializer_field_class(**kwargs) + return ModelField(model_field=model_field, **kwargs) def get_validation_exclusions(self, instance=None): """ From 582f6fdd4b0fb12a7c0d1fefe265499a284c9b79 Mon Sep 17 00:00:00 2001 From: Timo Tuominen Date: Mon, 1 Sep 2014 15:54:33 +0300 Subject: [PATCH 3/4] Add utility function to match classes in dictionary. --- rest_framework/serializers.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f37fbf980..5c33300c4 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -625,6 +625,21 @@ class ModelSerializerOptions(SerializerOptions): self.write_only_fields = getattr(meta, 'write_only_fields', ()) +def _get_class_mapping(mapping, obj): + """ + Takes a dictionary with classes as keys, and an object. + Traverses the object's inheritance hierarchy in method + resolution order, and returns the first matching value + from the dictionary or None. + + """ + for baseclass in inspect.getmro(obj.__class__): + val = mapping.get(baseclass) + if val: + return val + return None + + class ModelSerializer(Serializer): """ A serializer that deals with model instances and querysets. @@ -899,15 +914,16 @@ class ModelSerializer(Serializer): models.URLField: ['max_length'], } - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] + attributes = _get_class_mapping(attribute_dict, model_field) + if attributes: for attribute in attributes: kwargs.update({attribute: getattr(model_field, attribute)}) - for model_field_baseclass in inspect.getmro(model_field.__class__): - serializer_field_class = self.field_mapping.get(model_field_baseclass) - if serializer_field_class: - return serializer_field_class(**kwargs) + serializer_field_class = _get_class_mapping( + self.field_mapping, model_field) + + if serializer_field_class: + return serializer_field_class(**kwargs) return ModelField(model_field=model_field, **kwargs) def get_validation_exclusions(self, instance=None): From e437520217e20d500d641b95482d49484b1f24a7 Mon Sep 17 00:00:00 2001 From: Timo Tuominen Date: Mon, 1 Sep 2014 17:02:48 +0300 Subject: [PATCH 4/4] Generator implementation of class mapping. --- rest_framework/serializers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5c33300c4..b3db35823 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -633,11 +633,10 @@ def _get_class_mapping(mapping, obj): from the dictionary or None. """ - for baseclass in inspect.getmro(obj.__class__): - val = mapping.get(baseclass) - if val: - return val - return None + return next( + (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), + None + ) class ModelSerializer(Serializer):