mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-14 13:46:59 +03:00
0407a0df8a
Thanks to Jon Dufresne (@jdufresne) for review. Co-authored-by: Asif Saif Uddin <auvipy@gmail.com> Co-authored-by: Rizwan Mansuri <Rizwan@webbyfox.com>
511 lines
17 KiB
Python
511 lines
17 KiB
Python
"""
|
|
inspectors.py # Per-endpoint view introspection
|
|
|
|
See schemas.__init__.py for package overview.
|
|
"""
|
|
import re
|
|
import warnings
|
|
from collections import OrderedDict
|
|
from urllib import parse
|
|
from weakref import WeakKeyDictionary
|
|
|
|
from django.db import models
|
|
from django.utils.encoding import force_text, smart_text
|
|
from django.utils.translation import ugettext_lazy as _
|
|
|
|
from rest_framework import exceptions, serializers
|
|
from rest_framework.compat import coreapi, coreschema, uritemplate
|
|
from rest_framework.settings import api_settings
|
|
from rest_framework.utils import formatting
|
|
|
|
from .utils import is_list_view
|
|
|
|
# A docstring line that opens a named section, such as "get: ..." or
# "partial_update: ...": one ASCII letter, then letters/digits/underscores,
# then a colon. `AutoSchema._get_description_section` uses it to split view
# docstrings into per-method / per-action descriptions.
header_regex = re.compile(r'^[a-zA-Z][0-9A-Za-z_]*:')
|
|
|
|
|
|
def field_to_schema(field):
    """
    Translate a serializer field into the matching `coreschema` element.

    The field's `label` and `help_text` become the schema's title and
    description (empty strings when unset). Container fields recurse into
    their children:

    * ListSerializer / ListField -> Array of the child's schema
    * DictField -> Object
    * Serializer -> Object whose properties map each sub-field name to
      its schema
    * ManyRelatedField -> Array of the child relation's schema

    The `isinstance` checks are order-sensitive. More specific classes
    (e.g. MultipleChoiceField) must stay ahead of the broader classes they
    may subclass. Anything unrecognised falls back to a String schema,
    tagged `format='textarea'` when the field's style uses the textarea
    template.
    """
    common = {
        'title': force_text(field.label) if field.label else '',
        'description': force_text(field.help_text) if field.help_text else '',
    }

    if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
        return coreschema.Array(items=field_to_schema(field.child), **common)

    if isinstance(field, serializers.DictField):
        return coreschema.Object(**common)

    if isinstance(field, serializers.Serializer):
        nested = OrderedDict(
            (name, field_to_schema(child))
            for name, child in field.fields.items()
        )
        return coreschema.Object(properties=nested, **common)

    if isinstance(field, serializers.ManyRelatedField):
        return coreschema.Array(items=field_to_schema(field.child_relation), **common)

    if isinstance(field, serializers.PrimaryKeyRelatedField):
        # Integer only when the related model's pk is an AutoField;
        # otherwise (or with no queryset model) describe it as a string.
        related_model = getattr(field.queryset, 'model', None)
        pk_is_auto = (
            related_model is not None
            and isinstance(related_model._meta.pk, models.AutoField)
        )
        schema_cls = coreschema.Integer if pk_is_auto else coreschema.String
        return schema_cls(**common)

    if isinstance(field, serializers.RelatedField):
        return coreschema.String(**common)

    if isinstance(field, serializers.MultipleChoiceField):
        return coreschema.Array(items=coreschema.Enum(enum=list(field.choices)), **common)

    if isinstance(field, serializers.ChoiceField):
        return coreschema.Enum(enum=list(field.choices), **common)

    if isinstance(field, serializers.BooleanField):
        return coreschema.Boolean(**common)

    if isinstance(field, (serializers.DecimalField, serializers.FloatField)):
        return coreschema.Number(**common)

    if isinstance(field, serializers.IntegerField):
        return coreschema.Integer(**common)

    if isinstance(field, serializers.DateField):
        return coreschema.String(format='date', **common)

    if isinstance(field, serializers.DateTimeField):
        return coreschema.String(format='date-time', **common)

    if isinstance(field, serializers.JSONField):
        return coreschema.Object(**common)

    if field.style.get('base_template') == 'textarea.html':
        return coreschema.String(format='textarea', **common)

    return coreschema.String(**common)
|
|
|
|
|
|
def get_pk_description(model, model_field):
    """
    Build a human-readable description for a model's primary-key field.

    The wording depends on the pk type:

    * AutoField -> "unique integer value"
    * UUIDField -> "UUID string"
    * anything else -> "unique value"

    The result reads e.g. "A unique integer value identifying this user.",
    where the last word is the model's `verbose_name`.
    """
    value_type = (
        _('unique integer value') if isinstance(model_field, models.AutoField)
        else _('UUID string') if isinstance(model_field, models.UUIDField)
        else _('unique value')
    )
    template = _('A {value_type} identifying this {name}.')
    return template.format(value_type=value_type, name=model._meta.verbose_name)
|
|
|
|
|
|
class ViewInspector:
    """
    Descriptor class on APIView.

    Subclass it to provide per-view schema generation. Instances are meant
    to be assigned as a class attribute (e.g. `schema = AutoSchema()`), so
    that reading `view.schema` binds the inspector to that view.
    """

    def __init__(self):
        # Per-instance overrides installed through `__set__`. Weak keys, so
        # an override never keeps its view instance alive.
        self.instance_schemas = WeakKeyDictionary()

    def __get__(self, instance, owner):
        """
        Enables `ViewInspector` as a Python _Descriptor_.

        This is how `view.schema` knows about `view`.

        `__get__` runs whenever the descriptor is read through its owner,
        i.e. on every `view.schema` access.

        `owner` is always the owner class (an APIView, or subclass).
        `instance` is the view instance, or `None` when the attribute is
        read from the class rather than an instance.

        Returns the inspector previously assigned to this instance, if any.
        Otherwise it binds this shared descriptor to `instance` and returns
        it.

        See: https://docs.python.org/3/howto/descriptor.html for info on
        descriptor usage.
        """
        schemas = self.instance_schemas
        if instance not in schemas:
            # No per-instance override: the shared descriptor itself is
            # returned, re-pointed at whichever view accessed it last.
            self.view = instance
            return self
        return schemas[instance]

    def __set__(self, instance, other):
        """
        Store `other` as the inspector for `instance`.

        A non-None inspector is also bound to `instance`.
        """
        self.instance_schemas[instance] = other
        if other is None:
            return
        other.view = instance

    @property
    def view(self):
        """
        The view instance this inspector is currently bound to.

        Raises AssertionError while unbound (None), e.g. after accessing
        `schema` on the view class instead of an instance.
        """
        bound_view = self._view
        assert bound_view is not None, (
            "Schema generation REQUIRES a view instance. (Hint: you accessed "
            "`schema` from the view class rather than an instance.)"
        )
        return bound_view

    @view.setter
    def view(self, value):
        self._view = value

    @view.deleter
    def view(self):
        # Deleting unbinds rather than removing the attribute, so a later
        # read raises the explanatory assertion above.
        self._view = None

    def get_link(self, path, method, base_url):
        """
        Generate `coreapi.Link` for self.view, path and method.

        This is the main _public_ access point; subclasses must override it.

        Parameters:

        * path: Route path for view from URLConf.
        * method: The HTTP request method.
        * base_url: The project "mount point" as given to SchemaGenerator
        """
        raise NotImplementedError(".get_link() must be overridden.")
|
|
|
|
|
|
class AutoSchema(ViewInspector):
    """
    Default inspector for APIView.

    Responsible for per-view introspection and schema generation. It builds
    a `coreapi.Link` from four sources: path variables, serializer fields,
    pagination fields and filter fields. It then merges in any manually
    supplied fields.
    """

    def __init__(self, manual_fields=None):
        """
        Parameters:

        * `manual_fields`: list of `coreapi.Field` instances that
            will be added to auto-generated fields, overwriting on `Field.name`
        """
        super().__init__()
        if manual_fields is None:
            manual_fields = []
        self._manual_fields = manual_fields

    def get_link(self, path, method, base_url):
        """
        Generate the `coreapi.Link` for `self.view` at `path` and `method`.

        Fields are collected in the order path -> serializer -> pagination ->
        filter, then merged with `get_manual_fields()`. Manual fields replace
        auto-generated fields with the same name. An encoding is only looked
        up when at least one field is sent in the request body ('form' or
        'body' location).
        """
        fields = self.get_path_fields(path, method)
        fields += self.get_serializer_fields(path, method)
        fields += self.get_pagination_fields(path, method)
        fields += self.get_filter_fields(path, method)

        manual_fields = self.get_manual_fields(path, method)
        fields = self.update_fields(fields, manual_fields)

        if fields and any(field.location in ('form', 'body') for field in fields):
            encoding = self.get_encoding(path, method)
        else:
            encoding = None

        description = self.get_description(path, method)

        # An absolute `path` would replace any path component of `base_url`
        # in `urljoin`, so make it relative before joining.
        if base_url and path.startswith('/'):
            path = path[1:]

        return coreapi.Link(
            url=parse.urljoin(base_url, path),
            action=method.lower(),
            encoding=encoding,
            fields=fields,
            description=description
        )

    def get_description(self, path, method):
        """
        Determine a link description.

        This uses the docstring of the view's handler for the current action
        (or lower-cased HTTP method) if one exists. Otherwise it falls back
        to `view.get_view_description()`. Either text is then narrowed to the
        matching `header:` section by `_get_description_section`.
        """
        view = self.view

        # `action` may be present but None on a viewset; fall back to the
        # HTTP method instead of calling getattr() with a non-string name.
        method_name = getattr(view, 'action', None) or method.lower()
        handler = getattr(view, method_name, None)
        # Only read `__doc__` from a real handler. `None.__doc__` is not
        # reliably falsy (NoneType has a class docstring in recent CPython),
        # so it could otherwise leak into the schema.
        method_docstring = handler.__doc__ if handler is not None else None
        if method_docstring:
            # An explicit docstring on the method or action.
            return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
        return self._get_description_section(view, method_name, view.get_view_description())

    def _get_description_section(self, view, header, description):
        """
        Return the part of `description` that applies to `header`.

        A line matching `header_regex` (e.g. "get: ...") starts a new named
        section; any other line is appended to the current section. Text
        before the first header belongs to the unnamed '' section.

        Lookup order:

        1. the section named exactly `header`;
        2. the section named by `SCHEMA_COERCE_METHOD_NAMES[header]`;
        3. the unnamed section.

        `view` is unused but kept for subclass compatibility.
        """
        current_section = ''
        sections = {'': ''}

        for line in description.splitlines():
            if header_regex.match(line):
                # header_regex guarantees a ':' is present, so this always
                # yields exactly two parts (split on the first colon).
                current_section, lead = line.split(':', 1)
                sections[current_section] = lead.strip()
            else:
                sections[current_section] += '\n' + line

        # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
        coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
        if header in sections:
            return sections[header].strip()
        if header in coerce_method_names:
            if coerce_method_names[header] in sections:
                return sections[coerce_method_names[header]].strip()
        return sections[''].strip()

    def get_path_fields(self, path, method):
        """
        Return a list of `coreapi.Field` instances corresponding to any
        templated path variables.

        Every variable becomes a required 'path' field. When the view's
        queryset exposes a model:

        * the matching model field (if any) supplies the title and the
          description. A primary key without help text gets
          `get_pk_description`.
        * the view's lookup field gets `lookup_value_regex` as its pattern.
        * other AutoField-backed variables are typed as Integer.
        """
        view = self.view
        model = getattr(getattr(view, 'queryset', None), 'model', None)
        fields = []

        for variable in uritemplate.variables(path):
            title = ''
            description = ''
            schema_cls = coreschema.String
            kwargs = {}
            if model is not None:
                # Attempt to infer a field description if possible. This is
                # best-effort: any lookup failure just means "no model field".
                try:
                    model_field = model._meta.get_field(variable)
                except Exception:
                    model_field = None

                if model_field is not None and model_field.verbose_name:
                    title = force_text(model_field.verbose_name)

                if model_field is not None and model_field.help_text:
                    description = force_text(model_field.help_text)
                elif model_field is not None and model_field.primary_key:
                    description = get_pk_description(model, model_field)

                if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
                    kwargs['pattern'] = view.lookup_value_regex
                elif isinstance(model_field, models.AutoField):
                    schema_cls = coreschema.Integer

            fields.append(coreapi.Field(
                name=variable,
                location='path',
                required=True,
                schema=schema_cls(title=title, description=description, **kwargs)
            ))

        return fields

    def get_serializer_fields(self, path, method):
        """
        Return a list of `coreapi.Field` instances corresponding to any
        request body input, as determined by the serializer class.

        Only PUT, PATCH and POST produce fields:

        * A ListSerializer yields a single required 'data' body field
          holding an Array.
        * A regular Serializer yields one 'form' field per writable,
          non-hidden field. Nothing is required for PATCH.

        If `view.get_serializer()` raises an APIException, a warning is
        emitted and no fields are returned.
        """
        view = self.view

        if method not in ('PUT', 'PATCH', 'POST'):
            return []

        if not hasattr(view, 'get_serializer'):
            return []

        try:
            serializer = view.get_serializer()
        except exceptions.APIException:
            serializer = None
            warnings.warn('{}.get_serializer() raised an exception during '
                          'schema generation. Serializer fields will not be '
                          'generated for {} {}.'
                          .format(view.__class__.__name__, method, path))

        if isinstance(serializer, serializers.ListSerializer):
            return [
                coreapi.Field(
                    name='data',
                    location='body',
                    required=True,
                    schema=coreschema.Array()
                )
            ]

        if not isinstance(serializer, serializers.Serializer):
            return []

        fields = []
        for serializer_field in serializer.fields.values():
            if serializer_field.read_only or isinstance(serializer_field, serializers.HiddenField):
                continue

            # PATCH is a partial update, so no individual field is mandatory.
            required = serializer_field.required and method != 'PATCH'
            fields.append(coreapi.Field(
                name=serializer_field.field_name,
                location='form',
                required=required,
                schema=field_to_schema(serializer_field)
            ))

        return fields

    def get_pagination_fields(self, path, method):
        """
        Return the paginator's schema fields for list views.

        Returns an empty list for non-list views and for views without a
        `pagination_class`.
        """
        view = self.view

        if not is_list_view(path, method, view):
            return []

        pagination = getattr(view, 'pagination_class', None)
        if not pagination:
            return []

        paginator = pagination()
        return paginator.get_schema_fields(view)

    def _allows_filters(self, path, method):
        """
        Determine whether to include filter Fields in schema.

        Default implementation looks for ModelViewSet or GenericAPIView
        actions/methods that cause filtering on the default implementation.

        Override to adjust behaviour for your view.

        Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
            to allow changes based on user experience.
        """
        if getattr(self.view, 'filter_backends', None) is None:
            return False

        if hasattr(self.view, 'action'):
            return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]

        return method.lower() in ["get", "put", "patch", "delete"]

    def get_filter_fields(self, path, method):
        """
        Collect schema fields from every backend in `view.filter_backends`.

        Returns an empty list when `_allows_filters` says filtering does not
        apply to this path/method.
        """
        if not self._allows_filters(path, method):
            return []

        fields = []
        for filter_backend in self.view.filter_backends:
            fields += filter_backend().get_schema_fields(self.view)
        return fields

    def get_manual_fields(self, path, method):
        """Return the `manual_fields` given at construction time."""
        return self._manual_fields

    @staticmethod
    def update_fields(fields, update_with):
        """
        Update list of coreapi.Field instances, overwriting on `Field.name`.

        Utility function to handle replacing coreapi.Field fields
        from a list by name. Used to handle `manual_fields`.

        Existing names keep their original position. New names are appended
        in `update_with` order.

        Parameters:

        * `fields`: list of `coreapi.Field` instances to update
        * `update_with`: list of `coreapi.Field` instances to add or replace.
        """
        if not update_with:
            return fields

        by_name = OrderedDict((f.name, f) for f in fields)
        for f in update_with:
            by_name[f.name] = f
        fields = list(by_name.values())
        return fields

    def get_encoding(self, path, method):
        """
        Return the 'encoding' parameter to use for a given endpoint.

        This is the media type of the first parser class that Core API
        supports, or 'application/octet-stream' for a catch-all ('*/*')
        parser. Returns None when no parser qualifies.
        """
        view = self.view

        # Core API supports the following request encodings over HTTP...
        supported_media_types = {
            'application/json',
            'application/x-www-form-urlencoded',
            'multipart/form-data',
        }
        parser_classes = getattr(view, 'parser_classes', [])
        for parser_class in parser_classes:
            media_type = getattr(parser_class, 'media_type', None)
            if media_type in supported_media_types:
                return media_type
            # Raw binary uploads are supported with "application/octet-stream"
            if media_type == '*/*':
                return 'application/octet-stream'

        return None
|
|
|
|
|
|
class ManualSchema(ViewInspector):
    """
    Inspector for views whose schema is declared by hand.

    Takes an explicit list of coreapi.Fields, plus an optional description
    and encoding. Nothing is introspected from the view.
    """

    def __init__(self, fields, description='', encoding=None):
        """
        Parameters:

        * `fields`: list of `coreapi.Field` instances.
        * `description`: String description for view. Optional.
        * `encoding`: Request encoding (media type string) to put on the
          generated link. Optional; defaults to None.
        """
        super().__init__()
        assert all(isinstance(candidate, coreapi.Field) for candidate in fields), "`fields` must be a list of coreapi.Field instances"
        self._fields = fields
        self._description = description
        self._encoding = encoding

    def get_link(self, path, method, base_url):
        """
        Build a `coreapi.Link` from the stored fields, description and
        encoding.

        A leading '/' on `path` is dropped when a `base_url` is given, so
        that `urljoin` appends `path` to `base_url` instead of replacing its
        path.
        """
        relative_path = path[1:] if base_url and path.startswith('/') else path
        return coreapi.Link(
            url=parse.urljoin(base_url, relative_path),
            action=method.lower(),
            encoding=self._encoding,
            fields=self._fields,
            description=self._description
        )
|
|
|
|
|
|
class DefaultSchema(ViewInspector):
    """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

    def __get__(self, instance, owner):
        """
        Resolve to the configured default inspector for `instance`.

        A per-instance override that is not itself a DefaultSchema is
        returned unchanged. Otherwise a fresh `DEFAULT_SCHEMA_CLASS`
        inspector is created on every access and bound to `instance`.
        """
        resolved = super().__get__(instance, owner)
        if isinstance(resolved, DefaultSchema):
            schema_class = api_settings.DEFAULT_SCHEMA_CLASS
            assert issubclass(schema_class, ViewInspector), (
                "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
            )
            resolved = schema_class()
            resolved.view = instance
        return resolved
|