Refactoring due to the previous commit.

This commit is contained in:
Xavier Ordoquy 2014-01-30 14:27:09 +01:00
parent 1319da59ce
commit c2ee52239d
2 changed files with 50 additions and 30 deletions

View File

@ -123,7 +123,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@ -131,7 +131,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter works. # Tests that the date filter works.
search_date = datetime.date(2012, 9, 22) search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if f['date'] == search_date]
@ -146,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@ -179,7 +179,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set with 'lt' in the filter class works. # Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal] expected_data = [f for f in self.data if f['decimal'] < search_decimal]
@ -187,7 +187,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if f['date'] > search_date]
@ -195,7 +195,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff' search_text = 'ff'
request = factory.get('/?text=%s' % search_text) request = factory.get('/', {'text': '%s' % search_text})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()] expected_data = [f for f in self.data if search_text in f['text'].lower()]
@ -204,7 +204,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that multiple filters works. # Tests that multiple filters works.
search_decimal = Decimal('5.25') search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) request = factory.get('/', {
'decimal': '%s' % (search_decimal,),
'date': '%s' % (search_date,)
})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and
@ -229,7 +232,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
view = FilterFieldsRootView.as_view() view = FilterFieldsRootView.as_view()
search_integer = 10 search_integer = 10
request = factory.get('/?integer=%s' % search_integer) request = factory.get('/', {'integer': '%s' % search_integer})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -260,14 +263,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set that should fail. # Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0] high_item = self.objects.filter(decimal__gt=search_decimal)[0]
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(high_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed. # Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0] low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item) low_item_data = self._serialize_object(low_item)
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(low_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data) self.assertEqual(response.data, low_item_data)
@ -276,7 +283,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item) valid_item_data = self._serialize_object(valid_item)
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) response = self.client.get(
'{url}'.format(url=self._get_url(valid_item)), {
'decimal': '{decimal}'.format(decimal=search_decimal),
'date': '{date}'.format(date=search_date)
})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data) self.assertEqual(response.data, valid_item_data)
@ -310,7 +321,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -327,7 +338,7 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text') search_fields = ('=title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=zzz') request = factory.get('/', {'search': 'zzz'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -343,7 +354,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text') search_fields = ('title', '^text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -391,7 +402,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=text') request = factory.get('/', {'ordering': 'text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -410,7 +421,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-text') request = factory.get('/', {'ordering': '-text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -429,7 +440,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=foobar') request = factory.get('/', {'ordering': 'foobar'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -498,7 +509,7 @@ class OrderingFilterTests(TestCase):
models.Count("relateds")) models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=relateds__count') request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -561,7 +572,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-username') request = factory.get('/', {'ordering': '-username'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
@ -591,7 +602,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=password') request = factory.get('/', {'ordering': 'password'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:

View File

@ -13,6 +13,15 @@ from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
# Helper function to split arguments out of an url
def split_arguments_from_url(url):
if '?' not in url:
return url
path, args = url.split('?')
args = dict(r.split('=') for r in args.split('&'))
return path, args
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
""" """
@ -79,7 +88,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -88,7 +97,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -141,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
EXPECTED_NUM_QUERIES = 2 EXPECTED_NUM_QUERIES = 2
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -150,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -159,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -186,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = BasicFilterFieldsRootView.as_view() view = BasicFilterFieldsRootView.as_view()
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -195,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -204,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -312,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):
""" """
If paginate_by_param is set, the new kwarg should limit per view requests. If paginate_by_param is set, the new kwarg should limit per view requests.
""" """
request = factory.get('/?page_size=5') request = factory.get('/', {'page_size': 5})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])
@ -340,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):
""" """
If max_paginate_by is set, it should limit page size for the view. If max_paginate_by is set, it should limit page size for the view.
""" """
request = factory.get('/?page_size=10') request = factory.get('/', data={'page_size': 10})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])