Extract method for manual_fields processing (#5633)

* Extract method for `manual_fields` processing

Allows reuse of logic to replace Field instances in a field list by `Field.name`.

Adds a utility function for the logic plus a wrapper method on `AutoSchema`.

Closes #5632

* Manual fields suggestions (#2)

* Use OrderedDict in inspectors

* Move empty check to 'update_fields()'

* Make 'update_fields()' an AutoSchema staticmethod

* Add 'AutoSchema.get_manual_fields()'

* Conform '.get_manual_fields()' to other methods

* Add test for update_fields

* Make sure `manual_fields` is a list.

(As documented to be)

* Add docs for new AutoSchema methods.

* `get_manual_fields`
* `update_fields`

* Add release notes for PR.
This commit is contained in:
Carlton Gibson 2017-12-04 09:07:43 +01:00 committed by GitHub
parent daba5e9ba5
commit a0cdba6277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 8 deletions

View File

@ -603,6 +603,31 @@ Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fiel
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view.
### get_manual_fields(self, path, method)
Return a list of `coreapi.Field()` instances to be added to or replace generated fields. Defaults to (optional) `manual_fields` passed to `AutoSchema` constructor.
May be overridden to customise manual fields by `path` or `method`. For example, a per-method adjustment may look like this:
```python
def get_manual_fields(self, path, method):
"""Example adding per-method fields."""
extra_fields = []
if method=='GET':
extra_fields = # ... list of extra fields for GET ...
if method=='POST':
extra_fields = # ... list of extra fields for POST ...
manual_fields = super().get_manual_fields()
return manual_fields + extra_fields
```
### update_fields(fields, update_with)
Utility `staticmethod`. Encapsulates logic to add or replace fields from a list
by `Field.name`. May be overridden to adjust replacement criteria.
## ManualSchema ## ManualSchema

View File

@ -40,6 +40,25 @@ You can determine your currently installed version using `pip freeze`:
## 3.7.x series ## 3.7.x series
### 3.7.4
**Date**: UNRELEASED
* Extract method for `manual_fields` processing [#5633][gh5633]
Allows for easier customisation of `manual_fields` processing, for example
to provide per-method manual fields. `AutoSchema` adds `get_manual_fields`,
as the intended override point, and a utility method `update_fields`, to
handle by-name field replacement from a list, which, in general, you are not
expected to override.
Note: `AutoSchema.__init__` now ensures `manual_fields` is a list.
Previously may have been stored internally as `None`.
[gh5633]: https://github.com/encode/django-rest-framework/issues/5633
### 3.7.3 ### 3.7.3

View File

@ -172,7 +172,8 @@ class AutoSchema(ViewInspector):
* `manual_fields`: list of `coreapi.Field` instances that * `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name` will be added to auto-generated fields, overwriting on `Field.name`
""" """
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields self._manual_fields = manual_fields
def get_link(self, path, method, base_url): def get_link(self, path, method, base_url):
@ -181,11 +182,8 @@ class AutoSchema(ViewInspector):
fields += self.get_pagination_fields(path, method) fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method) fields += self.get_filter_fields(path, method)
if self._manual_fields is not None: manual_fields = self.get_manual_fields(path, method)
by_name = {f.name: f for f in fields} fields = self.update_fields(fields, manual_fields)
for f in self._manual_fields:
by_name[f.name] = f
fields = list(by_name.values())
if fields and any([field.location in ('form', 'body') for field in fields]): if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method) encoding = self.get_encoding(path, method)
@ -379,6 +377,31 @@ class AutoSchema(ViewInspector):
fields += filter_backend().get_schema_fields(self.view) fields += filter_backend().get_schema_fields(self.view)
return fields return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields
def get_encoding(self, path, method): def get_encoding(self, path, method):
""" """
Return the 'encoding' parameter to use for a given endpoint. Return the 'encoding' parameter to use for a given endpoint.

View File

@ -516,7 +516,7 @@ class Test4605Regression(TestCase):
assert prefix == '/' assert prefix == '/'
class TestDescriptor(TestCase): class TestAutoSchema(TestCase):
def test_apiview_schema_descriptor(self): def test_apiview_schema_descriptor(self):
view = APIView() view = APIView()
@ -528,7 +528,43 @@ class TestDescriptor(TestCase):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
def test_manual_fields(self): def test_update_fields(self):
"""
That updating fields by-name helper is correct
Recall: `update_fields(fields, update_with)`
"""
schema = AutoSchema()
fields = []
# Adds a field...
fields = schema.update_fields(fields, [
coreapi.Field(
"my_field",
required=True,
location="path",
schema=coreschema.String()
),
])
assert len(fields) == 1
assert fields[0].name == "my_field"
# Replaces a field...
fields = schema.update_fields(fields, [
coreapi.Field(
"my_field",
required=False,
location="path",
schema=coreschema.String()
),
])
assert len(fields) == 1
assert fields[0].required is False
def test_get_manual_fields(self):
"""That get_manual_fields is applied during get_link"""
class CustomView(APIView): class CustomView(APIView):
schema = AutoSchema(manual_fields=[ schema = AutoSchema(manual_fields=[