Add manual_fields kwarg to AutoSchema

This commit is contained in:
Carlton Gibson 2017-08-30 12:30:10 +02:00
parent 416e57f7b8
commit 398824f96b
3 changed files with 58 additions and 8 deletions

View File

@ -164,8 +164,7 @@ appropriate Core API `Link` object for the view, request method and path:
auto_schema = view.schema auto_schema = view.schema
coreapi_link = auto_schema.get_link(...) coreapi_link = auto_schema.get_link(...)
(In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for
(Aside: In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for
each view, allowed method and path.) each view, allowed method and path.)
To customise the `Link` generation you may: To customise the `Link` generation you may:
@ -178,9 +177,9 @@ To customise the `Link` generation you may:
class CustomView(APIView): class CustomView(APIView):
... ...
schema = AutoSchema( schema = AutoSchema(
manual_fields= { manual_fields=[
"extra_field": coreapi.Field(...) coreapi.Field("extra_field", ...),
} ]
) )
This allows extension for the most common case without subclassing. This allows extension for the most common case without subclassing.
@ -512,9 +511,22 @@ A class that deals with introspection of individual views for schema generation.
`AutoSchema` is attached to `APIView` via the `schema` attribute. `AutoSchema` is attached to `APIView` via the `schema` attribute.
Typically you will subclass `AutoSchema` to customise schema generation The `AutoSchema` constructor takes a single keyword argument `manual_fields`.
and then set your subclass on your view.
**`manual_fields`**: a `list` of `coreapi.Field` instances that will be added to
the generated fields. Generated fields with a matching `name` will be overwritten.
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
For more advanced customisation subclass `AutoSchema` to customise schema generation.
class CustomViewSchema(AutoSchema): class CustomViewSchema(AutoSchema):
""" """
@ -529,10 +541,13 @@ and then set your subclass on your view.
class MyView(APIView): class MyView(APIView):
schema = CustomViewSchema() schema = CustomViewSchema()
The following methods are available to override.
### get_link(self, path, method, base_url) ### get_link(self, path, method, base_url)
Returns a `coreapi.Link` instance corresponding to the given view. Returns a `coreapi.Link` instance corresponding to the given view.
This is the main entry point.
You can override this if you need to provide custom behaviors for particular views. You can override this if you need to provide custom behaviors for particular views.
### get_description(self, path, method) ### get_description(self, path, method)
@ -565,7 +580,7 @@ Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fiel
## ManualSchema ## ManualSchema
`APIViewSchemaDescriptor` subclass for specifying a manual schema. Allows specifying a manual schema for a view:
class MyView(APIView): class MyView(APIView):
schema = ManualSchema(coreapi.Link( schema = ManualSchema(coreapi.Link(

View File

@ -301,12 +301,28 @@ class AutoSchema(ViewInspector):
Responsible for per-view instrospection and schema generation. Responsible for per-view instrospection and schema generation.
""" """
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
self._manual_fields = manual_fields
def get_link(self, path, method, base_url): def get_link(self, path, method, base_url):
fields = self.get_path_fields(path, method) fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method) fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method) fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method) fields += self.get_filter_fields(path, method)
if self._manual_fields is not None:
by_name = {f.name: f for f in fields}
for f in self._manual_fields:
by_name[f.name] = f
fields = list(by_name.values())
if fields and any([field.location in ('form', 'body') for field in fields]): if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method) encoding = self.get_encoding(path, method)
else: else:

View File

@ -513,6 +513,25 @@ class TestDescriptor(TestCase):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
def test_manual_fields(self):
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
view = CustomView()
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
def test_view_with_manual_schema(self): def test_view_with_manual_schema(self):
expected = coreapi.Link( expected = coreapi.Link(