diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 47f5b9e13..b1e461c54 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -105,6 +105,25 @@ def get_pk_description(model, model_field): ) +def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + by_name = {f.name: f for f in fields} + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + + class ViewInspector(object): """ Descriptor class on APIView. @@ -181,11 +200,7 @@ class AutoSchema(ViewInspector): fields += self.get_pagination_fields(path, method) fields += self.get_filter_fields(path, method) - if self._manual_fields is not None: - by_name = {f.name: f for f in fields} - for f in self._manual_fields: - by_name[f.name] = f - fields = list(by_name.values()) + fields = self.update_manual_fields(fields) if fields and any([field.location in ('form', 'body') for field in fields]): encoding = self.get_encoding(path, method) @@ -379,6 +394,14 @@ class AutoSchema(ViewInspector): fields += filter_backend().get_schema_fields(self.view) return fields + def update_manual_fields(self, fields): + """ + Adjust `fields` with `manual_fields` + """ + if self._manual_fields is not None: + fields = update_fields(fields, self._manual_fields) + return fields + def get_encoding(self, path, method): """ Return the 'encoding' parameter to use for a given endpoint.