diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index 5a7ba09a8..3aa4c7a77 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,8 +3,10 @@ Provide urlresolver functions that return fully qualified URLs or view names """ from __future__ import unicode_literals +from django.conf import settings, urls from django.core.urlresolvers import reverse as django_reverse -from django.core.urlresolvers import NoReverseMatch +from django.core.urlresolvers import NoReverseMatch, resolve +from django.http import Http404 from django.utils import six from django.utils.functional import lazy @@ -35,20 +37,60 @@ def preserve_builtin_query_params(url, request=None): def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ If versioning is being used then we pass any `reverse` calls through - to the versioning scheme instance, so that the resulting URL - can be modified if needed. + to the versioning scheme instance, so that the resulting URL can be modified if needed. """ + url = None + + # Substitute reverse function by scheme's one if versioning enabled scheme = getattr(request, 'versioning_scheme', None) if scheme is not None: - try: - url = scheme.reverse(viewname, args, kwargs, request, format, **extra) - except NoReverseMatch: - # In case the versioning scheme reversal fails, fallback to the - # default implementation - url = _reverse(viewname, args, kwargs, request, format, **extra) + def reverse_url(*a, **kw): + try: + return scheme.reverse(*a, **kw) + except NoReverseMatch: + # In case the versioning scheme reversal fails, fallback to the default implementation + return _reverse(*a, **kw) else: - url = _reverse(viewname, args, kwargs, request, format, **extra) + reverse_url = _reverse + try: + # Resolving URL normally + url = reverse_url(viewname, args, kwargs, request, format, **extra) + except NoReverseMatch: + if request and ':' not in viewname: + # Retrieving current namespace through request + try: + current_namespace = request.resolver_match.namespace + except AttributeError: + try: + current_namespace = resolve(request.path).namespace + except Http404: + current_namespace = None + + if current_namespace: + try: + # Trying to resolve URL with current namespace + viewname_to_try = '{namespace}:{viewname}'.format(namespace=current_namespace, viewname=viewname) + url = reverse_url(viewname_to_try, args, kwargs, request, format, **extra) + except NoReverseMatch: + # Trying to resolve URL with other namespaces + # (Could be wrong if views have the same name in different namespaces) + urlpatterns = urls.import_module(settings.ROOT_URLCONF).urlpatterns + namespaces = [urlpattern.namespace for urlpattern in urlpatterns + if getattr(urlpattern, 'namespace', current_namespace) != current_namespace] + + # Remove duplicates but preserve order of elements + from collections import OrderedDict + for namespace in OrderedDict.fromkeys(namespaces): + try: + viewname_to_try = '{namespace}:{viewname}'.format(namespace=namespace, viewname=viewname) + url = reverse_url(viewname_to_try, args, kwargs, request, format, **extra) + break + except NoReverseMatch: + continue + # Raise exception if everything else fails + if not url: + raise return preserve_builtin_query_params(url, request) diff --git a/tests/test_namespaces.py b/tests/test_namespaces.py new file mode 100644 index 000000000..999b410dd --- /dev/null +++ b/tests/test_namespaces.py @@ -0,0 +1,154 @@ +from django.conf.urls import include, url +from django.db import models + +from rest_framework import serializers, status, viewsets +from rest_framework.reverse import reverse +from rest_framework.routers import DefaultRouter +from rest_framework.test import APIRequestFactory, APITestCase + +from .urls import urlpatterns + +factory = APIRequestFactory() + + +# no namesapce: Model, serializer and viewset + + +class NoNamespaceModel(models.Model): + pass + + +class NoNamespaceModelSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = NoNamespaceModel + + +class NoNamespaceModelViewSet(viewsets.ModelViewSet): + queryset = NoNamespaceModel.objects.all() + serializer_class = NoNamespaceModelSerializer + + +no_namespace_router = DefaultRouter() +no_namespace_router.register('no_ns_model', NoNamespaceModelViewSet) + + +# namespace1: Model, serializer and viewset + + +class Namespace1Model(models.Model): + # Reference to NoNamespaceModel + fk_no_ns_model = models.ForeignKey(NoNamespaceModel) + + +class Namespace1ModelSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = Namespace1Model + + +class Namespace1ModelViewSet(viewsets.ModelViewSet): + queryset = Namespace1Model.objects.all() + serializer_class = Namespace1ModelSerializer + + +namespace1_router = DefaultRouter() +namespace1_router.register('ns_1_model', Namespace1ModelViewSet) + + +# namespace2: Models, serializers and viewsets + + +class Namespace2Model1(models.Model): + # Reference to Namespace1Model + fk_ns_1_model = models.ForeignKey(Namespace1Model) + + +class Namespace2Model2(models.Model): + # Reference to Namespace2Model1 + fk_ns_2_model_1 = models.ForeignKey(Namespace2Model1) + + +class Namespace2Model1Serializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = Namespace2Model1 + + +class Namespace2Model2Serializer(serializers.HyperlinkedModelSerializer): + fk_ns_2_model_1 = Namespace2Model1Serializer(read_only=True) + + class Meta: + model = Namespace2Model2 + + +class Namespace2Model1ViewSet(viewsets.ModelViewSet): + queryset = Namespace2Model1.objects.all() + serializer_class = Namespace2Model1Serializer + + +class Namespace2Model2ViewSet(viewsets.ModelViewSet): + queryset = Namespace2Model2.objects.all() + serializer_class = Namespace2Model2Serializer + + +namespace2_router = DefaultRouter() +namespace2_router.register('ns_2_model_1', Namespace2Model1ViewSet) +namespace2_router.register('ns_2_model_2', Namespace2Model2ViewSet) + + +urlpatterns += [ + url(r'^nonamespace/', include(no_namespace_router.urls)), + url(r'^namespace1/', include(namespace1_router.urls, namespace='namespace1')), + url(r'^namespace2/', include(namespace2_router.urls, namespace='namespace2')), +] + + +class NamespaceTestCase(APITestCase): + + def setUp(self): + self.request = factory.request() + self.no_ns_item = NoNamespaceModel.objects.create() + self.ns_1_item = Namespace1Model.objects.create(fk_no_ns_model=self.no_ns_item) + self.ns_2_model_1_item = Namespace2Model1.objects.create(fk_ns_1_model=self.ns_1_item) + self.ns_2_model_2_item = Namespace2Model2.objects.create(fk_ns_2_model_1=self.ns_2_model_1_item) + self.url_no_ns_item = '/nonamespace/no_ns_model/{pk}/'.format(pk=self.no_ns_item.pk) + self.url_ns_1_item = '/namespace1/ns_1_model/{pk}/'.format(pk=self.ns_1_item.pk) + self.url_ns_2_model_1_item = '/namespace2/ns_2_model_1/{pk}/'.format(pk=self.ns_2_model_1_item.pk) + self.url_ns_2_model_2_item = '/namespace2/ns_2_model_2/{pk}/'.format(pk=self.ns_2_model_2_item.pk) + + def test_reverse_with_namespace(self): + # Namespace 1 + reverse_ns_1_item = reverse('namespace1:namespace1model-detail', args=[self.ns_1_item.pk]) + self.assertEquals(reverse_ns_1_item, self.url_ns_1_item) + + # Namespace 2 - Model 1 + reverse_ns_2_model_1_item = reverse('namespace2:namespace2model1-detail', args=[self.ns_2_model_1_item.pk]) + self.assertEquals(reverse_ns_2_model_1_item, self.url_ns_2_model_1_item) + + # Namespace 2 - Model 2 + reverse_ns_2_model_2_item = reverse('namespace2:namespace2model2-detail', args=[self.ns_2_model_2_item.pk]) + self.assertEquals(reverse_ns_2_model_2_item, self.url_ns_2_model_2_item) + + def test_hyperlinked_identity_field_with_no_namespace(self): + response = self.client.get(self.url_ns_1_item) + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data.get('url', None), self.request.build_absolute_uri(self.url_ns_1_item)) + + # Test the hyperlink of the NoNamespaceModel FK + fk_url = response.data.get('fk_no_ns_model', None) + self.assertEquals(fk_url, self.request.build_absolute_uri(self.url_no_ns_item)) + + def test_hyperlinked_identity_field_with_different_namespace(self): + response = self.client.get(self.url_ns_2_model_1_item) + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data.get('url', None), self.request.build_absolute_uri(self.url_ns_2_model_1_item)) + # Test the hyperlink of the NameSpace1Model FK + self.assertEquals(response.data.get('fk_ns_1_model', None), self.request.build_absolute_uri(self.url_ns_1_item)) + + def test_hyperlinked_identity_field_with_same_namespace(self): + response = self.client.get(self.url_ns_2_model_2_item) + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data.get('url', None), self.request.build_absolute_uri(self.url_ns_2_model_2_item)) + response_item = response.data.get('fk_ns_2_model_1', {}) + # Test the hyperlink of the Namespace2Model1 FK + self.assertEquals(response_item.get('url', None), self.request.build_absolute_uri(self.url_ns_2_model_1_item)) + # Test the hyperlink of the NameSpace1Model FK + self.assertEquals(response_item.get('fk_ns_1_model', None), self.request.build_absolute_uri(self.url_ns_1_item))