Fix NamespaceVersioning with hyperlinked serializer fields

This commit is contained in:
Tom Christie 2015-02-05 00:58:09 +00:00
parent 83673e8f74
commit e1c4513312
6 changed files with 72 additions and 61 deletions

View File

@ -1,7 +1,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
from django.core.urlresolvers import get_script_prefix, NoReverseMatch, Resolver404 from django.core.urlresolvers import get_script_prefix, resolve, NoReverseMatch, Resolver404
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils import six from django.utils import six
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
@ -9,7 +9,7 @@ from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import OrderedDict from rest_framework.compat import OrderedDict
from rest_framework.fields import get_attribute, empty, Field from rest_framework.fields import get_attribute, empty, Field
from rest_framework.reverse import reverse, resolve from rest_framework.reverse import reverse
from rest_framework.utils import html from rest_framework.utils import html
@ -167,11 +167,10 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None) self.format = kwargs.pop('format', None)
# We include these simply for dependency injection in tests. # We include this simply for dependency injection in tests.
# We can't add them as class attributes or they would expect an # We can't add it as a class attributes or it would expect an
# implicit `self` argument to be passed. # implicit `self` argument to be passed.
self.reverse = reverse self.reverse = reverse
self.resolve = resolve
super(HyperlinkedRelatedField, self).__init__(**kwargs) super(HyperlinkedRelatedField, self).__init__(**kwargs)
@ -219,11 +218,18 @@ class HyperlinkedRelatedField(RelatedField):
data = '/' + data[len(prefix):] data = '/' + data[len(prefix):]
try: try:
match = self.resolve(data, request=request) match = resolve(data)
except Resolver404: except Resolver404:
self.fail('no_match') self.fail('no_match')
if match.view_name != self.view_name: try:
expected_viewname = request.versioning_scheme.get_versioned_viewname(
self.view_name, request
)
except AttributeError:
expected_viewname = self.view_name
if match.view_name != expected_viewname:
self.fail('incorrect_match') self.fail('incorrect_match')
try: try:

View File

@ -3,23 +3,10 @@ Provide urlresolver functions that return fully qualified URLs or view names
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.urlresolvers import reverse as django_reverse from django.core.urlresolvers import reverse as django_reverse
from django.core.urlresolvers import resolve as django_resolve
from django.utils import six from django.utils import six
from django.utils.functional import lazy from django.utils.functional import lazy
def resolve(path, urlconf=None, request=None):
"""
If versioning is being used then we pass any `resolve` calls through
to the versioning scheme instance, so that the resulting view name
can be modified if needed.
"""
scheme = getattr(request, 'versioning_scheme', None)
if scheme is not None:
return scheme.resolve(path, urlconf, request)
return django_resolve(path, urlconf)
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
""" """
If versioning is being used then we pass any `reverse` calls through If versioning is being used then we pass any `reverse` calls through

View File

@ -1,8 +1,6 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.core.urlresolvers import resolve as django_resolve
from django.core.urlresolvers import ResolverMatch
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import unicode_http_header from rest_framework.compat import unicode_http_header
from rest_framework.reverse import _reverse from rest_framework.reverse import _reverse
@ -26,9 +24,6 @@ class BaseVersioning(object):
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
return _reverse(viewname, args, kwargs, request, format, **extra) return _reverse(viewname, args, kwargs, request, format, **extra)
def resolve(self, path, urlconf=None):
return django_resolve(path, urlconf)
def is_allowed_version(self, version): def is_allowed_version(self, version):
if not self.allowed_versions: if not self.allowed_versions:
return True return True
@ -127,21 +122,13 @@ class NamespaceVersioning(BaseVersioning):
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None: if request.version is not None:
viewname = request.version + ':' + viewname viewname = self.get_versioned_viewname(viewname, request)
return super(NamespaceVersioning, self).reverse( return super(NamespaceVersioning, self).reverse(
viewname, args, kwargs, request, format, **extra viewname, args, kwargs, request, format, **extra
) )
def resolve(self, path, urlconf=None, request=None): def get_versioned_viewname(self, viewname, request):
match = django_resolve(path, urlconf) return request.version + ':' + viewname
if match.namespace:
_, view_name = match.view_name.split(':')
return ResolverMatch(func=match.func,
args=match.args,
kwargs=match.kwargs,
url_name=view_name,
app_name=match.app_name)
return match
class HostNameVersioning(BaseVersioning): class HostNameVersioning(BaseVersioning):

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url from django.conf.urls import url
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -14,8 +14,7 @@ request = factory.get('/') # Just to ensure we have a request in the serializer
dummy_view = lambda request, pk: None dummy_view = lambda request, pk: None
urlpatterns = patterns( urlpatterns = [
'',
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
@ -24,7 +23,7 @@ urlpatterns = patterns(
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
) ]
# ManyToMany # ManyToMany

View File

@ -1,4 +1,4 @@
from .utils import MockObject, MockQueryset from .utils import MockObject, MockQueryset, UsingURLPatterns
from django.conf.urls import include, url from django.conf.urls import include, url
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from rest_framework import serializers from rest_framework import serializers
@ -6,8 +6,9 @@ from rest_framework import status, versioning
from rest_framework.decorators import APIView from rest_framework.decorators import APIView
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory, APITestCase, APISimpleTestCase from rest_framework.test import APIRequestFactory, APITestCase
from rest_framework.versioning import NamespaceVersioning from rest_framework.versioning import NamespaceVersioning
import pytest
class RequestVersionView(APIView): class RequestVersionView(APIView):
@ -35,18 +36,6 @@ factory = APIRequestFactory()
mock_view = lambda request: None mock_view = lambda request: None
dummy_view = lambda request, pk: None dummy_view = lambda request, pk: None
included_patterns = [
url(r'^namespaced/$', mock_view, name='another'),
url(r'^example/(?P<pk>\d+)/$', dummy_view, name='example-detail')
]
urlpatterns = [
url(r'^v1/', include(included_patterns, namespace='v1')),
url(r'^another/$', mock_view, name='another'),
url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another'),
url(r'^example/(?P<pk>\d+)/$', dummy_view, name='example-detail')
]
class TestRequestVersion: class TestRequestVersion:
def test_unversioned(self): def test_unversioned(self):
@ -121,8 +110,17 @@ class TestRequestVersion:
assert response.data == {'version': None} assert response.data == {'version': None}
class TestURLReversing(APITestCase): class TestURLReversing(UsingURLPatterns, APITestCase):
urls = 'tests.test_versioning' included = [
url(r'^namespaced/$', mock_view, name='another'),
url(r'^example/(?P<pk>\d+)/$', dummy_view, name='example-detail')
]
urlpatterns = [
url(r'^v1/', include(included, namespace='v1')),
url(r'^another/$', mock_view, name='another'),
url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another'),
]
def test_reverse_unversioned(self): def test_reverse_unversioned(self):
view = ReverseView.as_view() view = ReverseView.as_view()
@ -230,10 +228,18 @@ class TestInvalidVersion:
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
class TestHyperlinkedRelatedField(APISimpleTestCase): class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase):
urls = 'tests.test_versioning' included = [
url(r'^namespaced/(?P<pk>\d+)/$', mock_view, name='namespaced'),
]
urlpatterns = [
url(r'^v1/', include(included, namespace='v1')),
url(r'^v2/', include(included, namespace='v2'))
]
def setUp(self): def setUp(self):
super(TestHyperlinkedRelatedField, self).setUp()
class HyperlinkedMockQueryset(MockQueryset): class HyperlinkedMockQueryset(MockQueryset):
def get(self, **lookup): def get(self, **lookup):
@ -248,13 +254,15 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
MockObject(pk=3, name='baz') MockObject(pk=3, name='baz')
]) ])
self.field = serializers.HyperlinkedRelatedField( self.field = serializers.HyperlinkedRelatedField(
view_name='example-detail', view_name='namespaced',
queryset=self.queryset queryset=self.queryset
) )
request = factory.post('/', urlconf='tests.test_versioning') request = factory.post('/', urlconf='tests.test_versioning')
request.versioning_scheme = NamespaceVersioning() request.versioning_scheme = NamespaceVersioning()
request.version = 'v1'
self.field._context = {'request': request} self.field._context = {'request': request}
def test_bug_2489(self): def test_bug_2489(self):
self.field.to_internal_value('/example/3/') self.field.to_internal_value('/v1/namespaced/3/')
self.field.to_internal_value('/v1/example/3/') with pytest.raises(serializers.ValidationError):
self.field.to_internal_value('/v2/namespaced/3/')

View File

@ -2,6 +2,30 @@ from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import NoReverseMatch from django.core.urlresolvers import NoReverseMatch
class UsingURLPatterns(object):
"""
Isolates URL patterns used during testing on the test class itself.
For example:
class MyTestCase(UsingURLPatterns, TestCase):
urlpatterns = [
...
]
def test_something(self):
...
"""
urls = __name__
def setUp(self):
global urlpatterns
urlpatterns = self.urlpatterns
def tearDown(self):
global urlpatterns
urlpatterns = []
class MockObject(object): class MockObject(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs