This commit is contained in:
Catstyle_Lee 2014-12-09 11:40:28 +00:00
commit 625e8869e2
2 changed files with 14 additions and 10 deletions

View File

@ -48,7 +48,7 @@ def is_simple_callable(obj):
return len_args <= len_defaults return len_args <= len_defaults
def get_attribute(instance, attrs): def get_attribute(instance, attrs, required):
""" """
Similar to Python's built in `getattr(instance, attr)`, Similar to Python's built in `getattr(instance, attr)`,
but takes a list of nested attributes, instead of a single attribute. but takes a list of nested attributes, instead of a single attribute.
@ -59,14 +59,18 @@ def get_attribute(instance, attrs):
if instance is None: if instance is None:
# Break out early if we get `None` at any point in a nested lookup. # Break out early if we get `None` at any point in a nested lookup.
return None return None
if getattr(instance, '__getitem__', None):
try:
return instance[attr]
except (KeyError, TypeError, AttributeError):
pass
try: try:
instance = getattr(instance, attr) instance = getattr(instance, attr)
except ObjectDoesNotExist: except ObjectDoesNotExist:
return None return None
except AttributeError as exc: except AttributeError as exc:
try: if not required:
return instance[attr] return None
except (KeyError, TypeError, AttributeError):
raise exc raise exc
if is_simple_callable(instance): if is_simple_callable(instance):
instance = instance() instance = instance()
@ -275,7 +279,7 @@ class Field(object):
Given the *outgoing* object instance, return the primitive value Given the *outgoing* object instance, return the primitive value
that should be used for this field. that should be used for this field.
""" """
return get_attribute(instance, self.source_attrs) return get_attribute(instance, self.source_attrs, self.required)
def get_default(self): def get_default(self):
""" """

View File

@ -85,7 +85,7 @@ class RelatedField(Field):
return queryset return queryset
def get_iterable(self, instance, source_attrs): def get_iterable(self, instance, source_attrs):
relationship = get_attribute(instance, source_attrs) relationship = get_attribute(instance, source_attrs, self.required)
return relationship.all() if (hasattr(relationship, 'all')) else relationship return relationship.all() if (hasattr(relationship, 'all')) else relationship
@property @property
@ -134,10 +134,10 @@ class PrimaryKeyRelatedField(RelatedField):
# the related object. We return this directly instead of returning the # the related object. We return this directly instead of returning the
# object itself, which would require a database lookup. # object itself, which would require a database lookup.
try: try:
instance = get_attribute(instance, self.source_attrs[:-1]) instance = get_attribute(instance, self.source_attrs[:-1], self.required)
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1])) return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError: except AttributeError:
return get_attribute(instance, self.source_attrs) return get_attribute(instance, self.source_attrs, self.required)
def get_iterable(self, instance, source_attrs): def get_iterable(self, instance, source_attrs):
# For consistency with `get_attribute` we're using `serializable_value()` # For consistency with `get_attribute` we're using `serializable_value()`
@ -349,7 +349,7 @@ class ManyRelatedField(Field):
] ]
def get_attribute(self, instance): def get_attribute(self, instance):
return self.child_relation.get_iterable(instance, self.source_attrs) return self.child_relation.get_iterable(instance, self.source_attrs, self.required)
def to_representation(self, iterable): def to_representation(self, iterable):
return [ return [