Add support for namespaces in APIRoot.

This commit is contained in:
Rocky Meza 2014-12-24 19:37:13 -07:00
parent 35768344db
commit d8a7e65514
4 changed files with 77 additions and 14 deletions

View File

@ -189,6 +189,15 @@ class RequestFactory(DjangoRequestFactory):
return self.request(**r) return self.request(**r)
# request only provides `resolver_match` from 1.5 onwards.
def get_resolver_match(request):
try:
return request.resolver_match
except AttributeError: # Django < 1.5
from django.core.urlresolvers import resolve
return resolve(request.path_info)
# Markdown is optional # Markdown is optional
try: try:
import markdown import markdown

View File

@ -21,7 +21,7 @@ from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch from django.core.urlresolvers import NoReverseMatch
from rest_framework import views from rest_framework import views
from rest_framework.compat import OrderedDict from rest_framework.compat import OrderedDict, get_resolver_match
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
@ -290,9 +290,19 @@ class DefaultRouter(SimpleRouter):
class APIRoot(views.APIView): class APIRoot(views.APIView):
_ignore_model_permissions = True _ignore_model_permissions = True
def get_namespace(self):
"""
Attempt to retrieve the namespace of the current router.
"""
resolver_match = get_resolver_match(self.request)
return resolver_match.namespace
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
ret = OrderedDict() ret = OrderedDict()
namespace = self.get_namespace()
for key, url_name in api_root_dict.items(): for key, url_name in api_root_dict.items():
if namespace:
url_name = namespace + ':' + url_name
try: try:
ret[key] = reverse( ret[key] = reverse(
url_name, url_name,

32
tests/router_test_urls.py Normal file
View File

@ -0,0 +1,32 @@
from django.conf.urls import url, include
from rest_framework import viewsets, mixins, routers
from .test_routers import APIRootTestModel
class APIRootTestViewSet(viewsets.ModelViewSet):
model = APIRootTestModel
class ListlessViewSet(mixins.RetrieveModelMixin,
viewsets.GenericViewSet):
model = APIRootTestModel
router = routers.DefaultRouter()
router.register(r'test-model', APIRootTestViewSet)
listless_router = routers.DefaultRouter()
# Avoid conflict with the api/ route.
listless_router.root_view_name = 'listless-api-root'
listless_router.register(r'full', APIRootTestViewSet, 'full')
listless_router.register(r'listless', ListlessViewSet, 'listless')
urlpatterns = [
url(r'^api/', include(router.urls)),
url(r'^namespaced-api/', include(router.urls, namespace='api-namespace')),
url(r'^listless/', include(listless_router.urls)),
]

View File

@ -3,7 +3,8 @@ from django.conf.urls import patterns, url, include
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, mixins, permissions from django.core import urlresolvers
from rest_framework import serializers, viewsets, permissions
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.routers import SimpleRouter, DefaultRouter
@ -307,17 +308,28 @@ class TestDynamicListAndDetailRouter(TestCase):
self.assertEqual(route.mapping[method_map], method_name) self.assertEqual(route.mapping[method_map], method_name)
class TestRootWithAListlessViewset(TestCase): # APIRoot
def setUp(self): class APIRootTestModel(models.Model):
class NoteViewSet(mixins.RetrieveModelMixin, pass
viewsets.GenericViewSet):
model = RouterTestModel
self.router = DefaultRouter()
self.router.register(r'notes', NoteViewSet)
self.view = self.router.urls[0].callback
def test_api_root(self): class TestAPIRootView(TestCase):
request = factory.get('/') urls = 'tests.router_test_urls'
response = self.view(request)
self.assertEqual(response.data, {}) def test_listless(self):
url = urlresolvers.reverse('listless-api-root')
response = self.client.get(url)
self.assertIn('full', response.data)
self.assertNotIn('listless', response.data)
def test_normal_api_root_contains_routes(self):
url = urlresolvers.reverse('api-root')
response = self.client.get(url)
self.assertIn('test-model', response.data)
self.assertEqual(response.data['test-model'], 'http://testserver/api/test-model/')
def test_namespaced_api_root_contains_routes(self):
url = urlresolvers.reverse('api-namespace:api-root')
response = self.client.get(url)
self.assertIn('test-model', response.data)
self.assertEqual(response.data['test-model'], 'http://testserver/namespaced-api/test-model/')