Tests for validating custom_method_name router attribute

This commit is contained in:
tanwanirahul 2014-11-03 14:44:47 +01:00
parent d972df7c9c
commit ea8c405201

View File

@ -8,6 +8,7 @@ from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from collections import namedtuple
factory = APIRequestFactory() factory = APIRequestFactory()
@ -260,6 +261,14 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet):
def detail_route_get(self, request, *args, **kwargs): def detail_route_get(self, request, *args, **kwargs):
return Response({'method': 'link2'}) return Response({'method': 'link2'})
@list_route(custom_method_name="list_custom-route")
def list_custom_route_get(self, request, *args, **kwargs):
return Response({'method': 'link1'})
@detail_route(custom_method_name="detail_custom-route")
def detail_custom_route_get(self, request, *args, **kwargs):
return Response({'method': 'link2'})
class TestDynamicListAndDetailRouter(TestCase): class TestDynamicListAndDetailRouter(TestCase):
def setUp(self): def setUp(self):
@ -268,22 +277,33 @@ class TestDynamicListAndDetailRouter(TestCase):
def test_list_and_detail_route_decorators(self): def test_list_and_detail_route_decorators(self):
routes = self.router.get_routes(DynamicListAndDetailViewSet) routes = self.router.get_routes(DynamicListAndDetailViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name custom_method_name')
# Make sure all these endpoints exist and none have been clobbered # Make sure all these endpoints exist and none have been clobbered
for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']): for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
MethodNamesMap('list_route_get', 'list_route_get'),
MethodNamesMap('list_route_post', 'list_route_post'),
MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
MethodNamesMap('detail_route_get', 'detail_route_get'),
MethodNamesMap('detail_route_post', 'detail_route_post')
]):
route = decorator_routes[i] route = decorator_routes[i]
# check url listing # check url listing
if endpoint.startswith('list_'): method_name = endpoint.method_name
custom_method_name = endpoint.custom_method_name
if method_name.startswith('list_'):
self.assertEqual(route.url, self.assertEqual(route.url,
'^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint)) '^{{prefix}}/{0}{{trailing_slash}}$'.format(custom_method_name))
else: else:
self.assertEqual(route.url, self.assertEqual(route.url,
'^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(custom_method_name))
# check method to function mapping # check method to function mapping
if endpoint.endswith('_post'): if method_name.endswith('_post'):
method_map = 'post' method_map = 'post'
else: else:
method_map = 'get' method_map = 'get'
self.assertEqual(route.mapping[method_map], endpoint) self.assertEqual(route.mapping[method_map], method_name)
class TestRootWithAListlessViewset(TestCase): class TestRootWithAListlessViewset(TestCase):