From a0cdba627767ec489fbc683483a8904bdfe606a9 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Mon, 4 Dec 2017 09:07:43 +0100 Subject: [PATCH] Extract method for `manual_fields` processing (#5633) * Extract method for `manual_fields` processing Allows reuse of logic to replace Field instances in a field list by `Field.name`. Adds a utility function for the logic plus a wrapper method on `AutoSchema`. Closes #5632 * Manual fields suggestions (#2) * Use OrderedDict in inspectors * Move empty check to 'update_fields()' * Make 'update_fields()' an AutoSchema staticmethod * Add 'AutoSchema.get_manual_fields()' * Conform '.get_manual_fields()' to other methods * Add test for update_fields * Make sure `manual_fields` is a list. (As documented to be) * Add docs for new AutoSchema methods. * `get_manual_fields` * `update_fields` * Add release notes for PR. --- docs/api-guide/schemas.md | 25 +++++++++++++++++ docs/topics/release-notes.md | 19 +++++++++++++ rest_framework/schemas/inspectors.py | 35 +++++++++++++++++++----- tests/test_schemas.py | 40 ++++++++++++++++++++++++++-- 4 files changed, 111 insertions(+), 8 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 22894a978..2b83e0671 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -603,6 +603,31 @@ Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fiel Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. +### get_manual_fields(self, path, method) + +Return a list of `coreapi.Field()` instances to be added to or replace generated fields. Defaults to (optional) `manual_fields` passed to `AutoSchema` constructor. + +May be overridden to customise manual fields by `path` or `method`. For example, a per-method adjustment may look like this: + +```python +def get_manual_fields(self, path, method): + """Example adding per-method fields.""" + + extra_fields = [] + if method=='GET': + extra_fields = # ... list of extra fields for GET ... + if method=='POST': + extra_fields = # ... list of extra fields for POST ... + + manual_fields = super().get_manual_fields() + return manual_fields + extra_fields +``` + +### update_fields(fields, update_with) + +Utility `staticmethod`. Encapsulates logic to add or replace fields from a list +by `Field.name`. May be overridden to adjust replacement criteria. + ## ManualSchema diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 44d8c7a12..2f2cdf1a1 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,25 @@ You can determine your currently installed version using `pip freeze`: ## 3.7.x series +### 3.7.4 + +**Date**: UNRELEASED + +* Extract method for `manual_fields` processing [#5633][gh5633] + + Allows for easier customisation of `manual_fields` processing, for example + to provide per-method manual fields. `AutoSchema` adds `get_manual_fields`, + as the intended override point, and a utility method `update_fields`, to + handle by-name field replacement from a list, which, in general, you are not + expected to override. + + Note: `AutoSchema.__init__` now ensures `manual_fields` is a list. + Previously may have been stored internally as `None`. + + +[gh5633]: https://github.com/encode/django-rest-framework/issues/5633 + + ### 3.7.3 diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 47f5b9e13..008d7c091 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -172,7 +172,8 @@ class AutoSchema(ViewInspector): * `manual_fields`: list of `coreapi.Field` instances that will be added to auto-generated fields, overwriting on `Field.name` """ - + if manual_fields is None: + manual_fields = [] self._manual_fields = manual_fields def get_link(self, path, method, base_url): @@ -181,11 +182,8 @@ class AutoSchema(ViewInspector): fields += self.get_pagination_fields(path, method) fields += self.get_filter_fields(path, method) - if self._manual_fields is not None: - by_name = {f.name: f for f in fields} - for f in self._manual_fields: - by_name[f.name] = f - fields = list(by_name.values()) + manual_fields = self.get_manual_fields(path, method) + fields = self.update_fields(fields, manual_fields) if fields and any([field.location in ('form', 'body') for field in fields]): encoding = self.get_encoding(path, method) @@ -379,6 +377,31 @@ class AutoSchema(ViewInspector): fields += filter_backend().get_schema_fields(self.view) return fields + def get_manual_fields(self, path, method): + return self._manual_fields + + @staticmethod + def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + if not update_with: + return fields + + by_name = OrderedDict((f.name, f) for f in fields) + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + def get_encoding(self, path, method): """ Return the 'encoding' parameter to use for a given endpoint. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 56692d4f5..ba561a959 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -516,7 +516,7 @@ class Test4605Regression(TestCase): assert prefix == '/' -class TestDescriptor(TestCase): +class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): view = APIView() @@ -528,7 +528,43 @@ class TestDescriptor(TestCase): with pytest.raises(AssertionError): descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? - def test_manual_fields(self): + def test_update_fields(self): + """ + That updating fields by-name helper is correct + + Recall: `update_fields(fields, update_with)` + """ + schema = AutoSchema() + fields = [] + + # Adds a field... + fields = schema.update_fields(fields, [ + coreapi.Field( + "my_field", + required=True, + location="path", + schema=coreschema.String() + ), + ]) + + assert len(fields) == 1 + assert fields[0].name == "my_field" + + # Replaces a field... + fields = schema.update_fields(fields, [ + coreapi.Field( + "my_field", + required=False, + location="path", + schema=coreschema.String() + ), + ]) + + assert len(fields) == 1 + assert fields[0].required is False + + def test_get_manual_fields(self): + """That get_manual_fields is applied during get_link""" class CustomView(APIView): schema = AutoSchema(manual_fields=[