From 633d4eae8b8b91c66f023c87e4339147213039c2 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Tue, 28 Nov 2017 09:14:16 +0100 Subject: [PATCH] Extract method for `manual_fields` processing Allows reuse of logic to replace Field instances in a field list by `Field.name`. Adds a utility function for the logic plus a wrapper method on `AutoSchema`. Closes #5632 --- rest_framework/schemas/inspectors.py | 33 +++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 47f5b9e13..b1e461c54 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -105,6 +105,25 @@ def get_pk_description(model, model_field): ) +def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + by_name = {f.name: f for f in fields} + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + + class ViewInspector(object): """ Descriptor class on APIView. @@ -181,11 +200,7 @@ class AutoSchema(ViewInspector): fields += self.get_pagination_fields(path, method) fields += self.get_filter_fields(path, method) - if self._manual_fields is not None: - by_name = {f.name: f for f in fields} - for f in self._manual_fields: - by_name[f.name] = f - fields = list(by_name.values()) + fields = self.update_manual_fields(fields) if fields and any([field.location in ('form', 'body') for field in fields]): encoding = self.get_encoding(path, method) @@ -379,6 +394,14 @@ class AutoSchema(ViewInspector): fields += filter_backend().get_schema_fields(self.view) return fields + def update_manual_fields(self, fields): + """ + Adjust `fields` with `manual_fields` + """ + if self._manual_fields is not None: + fields = update_fields(fields, self._manual_fields) + return fields + def get_encoding(self, path, method): """ Return the 'encoding' parameter to use for a given endpoint.