diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 0c6c2d390..70b71666a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -48,7 +48,7 @@ def is_simple_callable(obj): return len_args <= len_defaults -def get_attribute(instance, attrs): +def get_attribute(instance, attrs, required): """ Similar to Python's built in `getattr(instance, attr)`, but takes a list of nested attributes, instead of a single attribute. @@ -59,15 +59,19 @@ def get_attribute(instance, attrs): if instance is None: # Break out early if we get `None` at any point in a nested lookup. return None + if getattr(instance, '__getitem__', None): + try: + return instance[attr] + except (KeyError, TypeError, AttributeError): + pass try: instance = getattr(instance, attr) except ObjectDoesNotExist: return None except AttributeError as exc: - try: - return instance[attr] - except (KeyError, TypeError, AttributeError): - raise exc + if not required: + return None + raise exc if is_simple_callable(instance): instance = instance() return instance @@ -275,7 +279,7 @@ class Field(object): Given the *outgoing* object instance, return the primitive value that should be used for this field. """ - return get_attribute(instance, self.source_attrs) + return get_attribute(instance, self.source_attrs, self.required) def get_default(self): """ diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 75d68204b..000ad6566 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -85,7 +85,7 @@ class RelatedField(Field): return queryset def get_iterable(self, instance, source_attrs): - relationship = get_attribute(instance, source_attrs) + relationship = get_attribute(instance, source_attrs, self.required) return relationship.all() if (hasattr(relationship, 'all')) else relationship @property @@ -134,10 +134,10 @@ class PrimaryKeyRelatedField(RelatedField): # the related object. We return this directly instead of returning the # object itself, which would require a database lookup. try: - instance = get_attribute(instance, self.source_attrs[:-1]) + instance = get_attribute(instance, self.source_attrs[:-1], self.required) return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1])) except AttributeError: - return get_attribute(instance, self.source_attrs) + return get_attribute(instance, self.source_attrs, self.required) def get_iterable(self, instance, source_attrs): # For consistency with `get_attribute` we're using `serializable_value()` @@ -349,7 +349,7 @@ class ManyRelatedField(Field): ] def get_attribute(self, instance): - return self.child_relation.get_iterable(instance, self.source_attrs) + return self.child_relation.get_iterable(instance, self.source_attrs, self.required) def to_representation(self, iterable): return [