resolved merge conflict

This commit is contained in:
Dima Knivets 2019-09-15 15:17:14 +03:00
commit a9cda0b734
8 changed files with 280 additions and 17 deletions

View File

@ -2,8 +2,8 @@
# just Django, but for the purposes of development and testing # just Django, but for the purposes of development and testing
# there are a number of packages that are useful to install. # there are a number of packages that are useful to install.
# Laying these out as seperate requirements files, allows us to # Laying these out as separate requirements files, allows us to
# only included the relevent sets when running tox, and ensures # only included the relevant sets when running tox, and ensures
# we are only ever declaring our dependencies in one place. # we are only ever declaring our dependencies in one place.
-r requirements/requirements-optionals.txt -r requirements/requirements-optionals.txt

View File

@ -138,6 +138,9 @@ class BasePagination:
def get_paginated_response(self, data): # pragma: no cover def get_paginated_response(self, data): # pragma: no cover
raise NotImplementedError('get_paginated_response() must be implemented.') raise NotImplementedError('get_paginated_response() must be implemented.')
def get_paginated_response_schema(self, schema):
return schema
def to_html(self): # pragma: no cover def to_html(self): # pragma: no cover
raise NotImplementedError('to_html() must be implemented to display page controls.') raise NotImplementedError('to_html() must be implemented to display page controls.')
@ -222,6 +225,26 @@ class PageNumberPagination(BasePagination):
('results', data) ('results', data)
])) ]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': schema,
},
}
def get_page_size(self, request): def get_page_size(self, request):
if self.page_size_query_param: if self.page_size_query_param:
try: try:
@ -369,6 +392,26 @@ class LimitOffsetPagination(BasePagination):
('results', data) ('results', data)
])) ]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': schema,
},
}
def get_limit(self, request): def get_limit(self, request):
if self.limit_query_param: if self.limit_query_param:
try: try:
@ -840,6 +883,22 @@ class CursorPagination(BasePagination):
('results', data) ('results', data)
])) ]))
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'properties': {
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': schema,
},
}
def get_html_context(self): def get_html_context(self):
return { return {
'previous_url': self.get_previous_link(), 'previous_url': self.get_previous_link(),

View File

@ -209,11 +209,10 @@ class AutoSchema(ViewInspector):
if not is_list_view(path, method, view): if not is_list_view(path, method, view):
return [] return []
pagination = getattr(view, 'pagination_class', None) paginator = self._get_pagninator()
if not pagination: if not paginator:
return [] return []
paginator = view.pagination_class()
return paginator.get_schema_operation_parameters(view) return paginator.get_schema_operation_parameters(view)
def _map_field(self, field): def _map_field(self, field):
@ -387,7 +386,7 @@ class AutoSchema(ViewInspector):
schema['default'] = field.default schema['default'] = field.default
if field.help_text: if field.help_text:
schema['description'] = str(field.help_text) schema['description'] = str(field.help_text)
self._map_field_validators(field.validators, schema) self._map_field_validators(field, schema)
properties[field.field_name] = schema properties[field.field_name] = schema
@ -399,13 +398,11 @@ class AutoSchema(ViewInspector):
return result return result
def _map_field_validators(self, validators, schema): def _map_field_validators(self, field, schema):
""" """
map field validators map field validators
:param list:validators: list of field validators
:param dict:schema: schema that the validators get added to
""" """
for v in validators: for v in field.validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
if isinstance(v, EmailValidator): if isinstance(v, EmailValidator):
@ -415,9 +412,15 @@ class AutoSchema(ViewInspector):
if isinstance(v, RegexValidator): if isinstance(v, RegexValidator):
schema['pattern'] = v.regex.pattern schema['pattern'] = v.regex.pattern
elif isinstance(v, MaxLengthValidator): elif isinstance(v, MaxLengthValidator):
schema['maxLength'] = v.limit_value attr_name = 'maxLength'
if isinstance(field, serializers.ListField):
attr_name = 'maxItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MinLengthValidator): elif isinstance(v, MinLengthValidator):
schema['minLength'] = v.limit_value attr_name = 'minLength'
if isinstance(field, serializers.ListField):
attr_name = 'minItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MaxValueValidator): elif isinstance(v, MaxValueValidator):
schema['maximum'] = v.limit_value schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator): elif isinstance(v, MinValueValidator):
@ -432,15 +435,22 @@ class AutoSchema(ViewInspector):
schema['maximum'] = int(digits * '9') + 1 schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum'] schema['minimum'] = -schema['maximum']
def _get_pagninator(self):
pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class:
return pagination_class()
return None
def map_parsers(self, path, method): def map_parsers(self, path, method):
return list(map(attrgetter('media_type'), self.view.parser_classes)) return list(map(attrgetter('media_type'), self.view.parser_classes))
def map_renderers(self, path, method): def map_renderers(self, path, method):
media_types = [] media_types = []
for renderer in self.view.renderer_classes: for renderer in self.view.renderer_classes:
# I assume this is not relevant to OpenAPI spec # BrowsableAPIRenderer not relevant to OpenAPI spec
if renderer != renderers.BrowsableAPIRenderer: if renderer == renderers.BrowsableAPIRenderer:
media_types.append(renderer.media_type) continue
media_types.append(renderer.media_type)
return media_types return media_types
def _get_serializer(self, method, path): def _get_serializer(self, method, path):
@ -513,6 +523,9 @@ class AutoSchema(ViewInspector):
'type': 'array', 'type': 'array',
'items': item_schema, 'items': item_schema,
} }
paginator = self._get_pagninator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
else: else:
response_schema = item_schema response_schema = item_schema

View File

@ -356,7 +356,15 @@ class APIView(View):
throttle_durations.append(throttle.wait()) throttle_durations.append(throttle.wait())
if throttle_durations: if throttle_durations:
self.throttled(request, max(throttle_durations)) # Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
""" """

View File

@ -264,6 +264,58 @@ class TestOperationIntrospection(TestCase):
}, },
} }
def test_paginated_list_response_body_generation(self):
"""Test that pagination properties are added for a paginated list view."""
path = '/'
method = 'GET'
class Pagination(pagination.BasePagination):
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'item': schema,
}
class ItemSerializer(serializers.Serializer):
text = serializers.CharField()
class View(generics.GenericAPIView):
serializer_class = ItemSerializer
pagination_class = Pagination
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
assert responses == {
'200': {
'description': '',
'content': {
'application/json': {
'schema': {
'type': 'object',
'item': {
'type': 'array',
'items': {
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
},
},
},
},
},
},
}
def test_delete_response_body_generation(self): def test_delete_response_body_generation(self):
"""Test that a view's delete method generates a proper response body schema.""" """Test that a view's delete method generates a proper response body schema."""
path = '/{id}/' path = '/{id}/'
@ -312,15 +364,27 @@ class TestOperationIntrospection(TestCase):
assert attachment['format'] == 'binary' assert attachment['format'] == 'binary'
def test_retrieve_response_body_generation(self): def test_retrieve_response_body_generation(self):
"""Test that a list of properties is returned for retrieve item views.""" """
Test that a list of properties is returned for retrieve item views.
Pagination properties should not be added as the view represents a single item.
"""
path = '/{id}/' path = '/{id}/'
method = 'GET' method = 'GET'
class Pagination(pagination.BasePagination):
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'item': schema,
}
class ItemSerializer(serializers.Serializer): class ItemSerializer(serializers.Serializer):
text = serializers.CharField() text = serializers.CharField()
class View(generics.GenericAPIView): class View(generics.GenericAPIView):
serializer_class = ItemSerializer serializer_class = ItemSerializer
pagination_class = Pagination
view = create_view( view = create_view(
View, View,
@ -419,6 +483,9 @@ class TestOperationIntrospection(TestCase):
assert properties['string']['minLength'] == 2 assert properties['string']['minLength'] == 2
assert properties['string']['maxLength'] == 10 assert properties['string']['maxLength'] == 10
assert properties['lst']['minItems'] == 2
assert properties['lst']['maxItems'] == 10
assert properties['regex']['pattern'] == r'[ABC]12{3}' assert properties['regex']['pattern'] == r'[ABC]12{3}'
assert properties['regex']['description'] == 'must have an A, B, or C followed by 1222' assert properties['regex']['description'] == 'must have an A, B, or C followed by 1222'

View File

@ -85,6 +85,12 @@ class ExampleValidatedSerializer(serializers.Serializer):
), ),
help_text='must have an A, B, or C followed by 1222' help_text='must have an A, B, or C followed by 1222'
) )
lst = serializers.ListField(
validators=(
MaxLengthValidator(limit_value=10),
MinLengthValidator(limit_value=2),
)
)
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2) decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
validators=(DecimalValidator(max_digits=17, decimal_places=4),)) validators=(DecimalValidator(max_digits=17, decimal_places=4),))

View File

@ -259,6 +259,37 @@ class TestPageNumberPagination:
with pytest.raises(exceptions.NotFound): with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request) self.paginate_queryset(request)
def test_get_paginated_response_schema(self):
unpaginated_schema = {
'type': 'object',
'item': {
'properties': {
'test-property': {
'type': 'integer',
},
},
},
}
assert self.pagination.get_paginated_response_schema(unpaginated_schema) == {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': unpaginated_schema,
},
}
class TestPageNumberPaginationOverride: class TestPageNumberPaginationOverride:
""" """
@ -535,6 +566,37 @@ class TestLimitOffset:
assert content.get('next') == next_url assert content.get('next') == next_url
assert content.get('previous') == prev_url assert content.get('previous') == prev_url
def test_get_paginated_response_schema(self):
unpaginated_schema = {
'type': 'object',
'item': {
'properties': {
'test-property': {
'type': 'integer',
},
},
},
}
assert self.pagination.get_paginated_response_schema(unpaginated_schema) == {
'type': 'object',
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': unpaginated_schema,
},
}
class CursorPaginationTestsMixin: class CursorPaginationTestsMixin:
@ -834,6 +896,33 @@ class CursorPaginationTestsMixin:
assert current == [1, 1, 1, 1, 1] assert current == [1, 1, 1, 1, 1]
assert next == [1, 2, 3, 4, 4] assert next == [1, 2, 3, 4, 4]
def test_get_paginated_response_schema(self):
unpaginated_schema = {
'type': 'object',
'item': {
'properties': {
'test-property': {
'type': 'integer',
},
},
},
}
assert self.pagination.get_paginated_response_schema(unpaginated_schema) == {
'type': 'object',
'properties': {
'next': {
'type': 'string',
'nullable': True,
},
'previous': {
'type': 'string',
'nullable': True,
},
'results': unpaginated_schema,
},
}
class TestCursorPagination(CursorPaginationTestsMixin): class TestCursorPagination(CursorPaginationTestsMixin):
""" """

View File

@ -159,6 +159,27 @@ class ThrottlingTests(TestCase):
assert response.status_code == 429 assert response.status_code == 429
assert int(response['retry-after']) == 58 assert int(response['retry-after']) == 58
def test_throttle_rate_change_negative(self):
self.set_throttle_timer(MockView_DoubleThrottling, 0)
request = self.factory.get('/')
for dummy in range(24):
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
assert int(response['retry-after']) == 60
previous_rate = User3SecRateThrottle.rate
try:
User3SecRateThrottle.rate = '1/sec'
for dummy in range(24):
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
assert int(response['retry-after']) == 60
finally:
# reset
User3SecRateThrottle.rate = previous_rate
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
""" """
Ensure the response returns an Retry-After field with status and next attributes Ensure the response returns an Retry-After field with status and next attributes