Move get_serializer_fields to descriptor

This commit is contained in:
Carlton Gibson 2017-08-23 10:59:25 +02:00
parent 238c5994a2
commit 0a173efac2

View File

@ -278,8 +278,9 @@ class APIViewSchemaDescriptor(object):
view = self.view
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
# TEMP: now we proxy back to the generator
fields += generator.get_serializer_fields(path, method, view)
fields += generator.get_pagination_fields(path, method, view)
fields += generator.get_filter_fields(path, method, view)
@ -384,6 +385,50 @@ class APIViewSchemaDescriptor(object):
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
# TODO: Where should this live?
# - We import APIView here. So we can't import the descriptor into `views`
@ -592,48 +637,6 @@ class SchemaGenerator(object):
return None
def get_serializer_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method, view):
if not is_list_view(path, method, view):
return []