Fix format_suffix_patterns behavior with Django 2 path() routes (#5691)

* Add failing test for #5672

* Add get_original_route to complement get_regex_pattern

* [WIP] Fix path handling

* needs more tests
* maybe needs some refactoring

* Add django 2 variant for all tests and fix trailing slash bug

* Add more combinations to mixed path test
This commit is contained in:
Cristi Vîjdea 2017-12-20 13:17:54 +01:00 committed by Carlton Gibson
parent cf3929d88d
commit 6de12e574e
4 changed files with 249 additions and 41 deletions

View File

@ -29,7 +29,11 @@ except ImportError:
) )
def get_regex_pattern(urlpattern): def get_original_route(urlpattern):
"""
Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This
is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path().
"""
if hasattr(urlpattern, 'pattern'): if hasattr(urlpattern, 'pattern'):
# Django 2.0 # Django 2.0
return str(urlpattern.pattern) return str(urlpattern.pattern)
@ -38,6 +42,29 @@ def get_regex_pattern(urlpattern):
return urlpattern.regex.pattern return urlpattern.regex.pattern
def get_regex_pattern(urlpattern):
"""
Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression,
unlike get_original_route above.
"""
if hasattr(urlpattern, 'pattern'):
# Django 2.0
return urlpattern.pattern.regex.pattern
else:
# Django < 2.0
return urlpattern.regex.pattern
def is_route_pattern(urlpattern):
if hasattr(urlpattern, 'pattern'):
# Django 2.0
from django.urls.resolvers import RoutePattern
return isinstance(urlpattern.pattern, RoutePattern)
else:
# Django < 2.0
return False
def make_url_resolver(regex, urlpatterns): def make_url_resolver(regex, urlpatterns):
try: try:
# Django 2.0 # Django 2.0
@ -257,10 +284,11 @@ except ImportError:
# Django 1.x url routing syntax. Remove when dropping Django 1.11 support. # Django 1.x url routing syntax. Remove when dropping Django 1.11 support.
try: try:
from django.urls import include, path, re_path # noqa from django.urls import include, path, re_path, register_converter # noqa
except ImportError: except ImportError:
from django.conf.urls import include, url # noqa from django.conf.urls import include, url # noqa
path = None path = None
register_converter = None
re_path = url re_path = url

View File

@ -16,7 +16,7 @@ from django.utils import six
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import ( from rest_framework.compat import (
URLPattern, URLResolver, coreapi, coreschema, get_regex_pattern URLPattern, URLResolver, coreapi, coreschema, get_original_route
) )
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -170,7 +170,7 @@ class EndpointEnumerator(object):
api_endpoints = [] api_endpoints = []
for pattern in patterns: for pattern in patterns:
path_regex = prefix + get_regex_pattern(pattern) path_regex = prefix + get_original_route(pattern)
if isinstance(pattern, URLPattern): if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex) path = self.get_path_from_regex(path_regex)
callback = pattern.callback callback = pattern.callback

View File

@ -2,11 +2,39 @@ from __future__ import unicode_literals
from django.conf.urls import include, url from django.conf.urls import include, url
from rest_framework.compat import URLResolver, get_regex_pattern from rest_framework.compat import (
URLResolver, get_regex_pattern, is_route_pattern, path, register_converter
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): def _get_format_path_converter(suffix_kwarg, allowed):
if allowed:
if len(allowed) == 1:
allowed_pattern = allowed[0]
else:
allowed_pattern = '(?:%s)' % '|'.join(allowed)
suffix_pattern = r"\.%s/?" % allowed_pattern
else:
suffix_pattern = r"\.[a-z0-9]+/?"
class FormatSuffixConverter:
regex = suffix_pattern
def to_python(self, value):
return value.strip('./')
def to_url(self, value):
return '.' + value + '/'
converter_name = 'drf_format_suffix'
if allowed:
converter_name += '_' + '_'.join(allowed)
return converter_name, FormatSuffixConverter
def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route=None):
ret = [] ret = []
for urlpattern in urlpatterns: for urlpattern in urlpatterns:
if isinstance(urlpattern, URLResolver): if isinstance(urlpattern, URLResolver):
@ -18,8 +46,18 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
# Add in the included patterns, after applying the suffixes # Add in the included patterns, after applying the suffixes
patterns = apply_suffix_patterns(urlpattern.url_patterns, patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern, suffix_pattern,
suffix_required) suffix_required,
ret.append(url(regex, include((patterns, app_name), namespace), kwargs)) suffix_route)
# if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern):
assert path is not None
route = str(urlpattern.pattern)
new_pattern = path(route, include((patterns, app_name), namespace), kwargs)
else:
new_pattern = url(regex, include((patterns, app_name), namespace), kwargs)
ret.append(new_pattern)
else: else:
# Regular URL pattern # Regular URL pattern
regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern
@ -29,7 +67,17 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
# Add in both the existing and the new urlpattern # Add in both the existing and the new urlpattern
if not suffix_required: if not suffix_required:
ret.append(urlpattern) ret.append(urlpattern)
ret.append(url(regex, view, kwargs, name))
# if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern):
assert path is not None
assert suffix_route is not None
route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route
new_pattern = path(route, view, kwargs, name)
else:
new_pattern = url(regex, view, kwargs, name)
ret.append(new_pattern)
return ret return ret
@ -60,4 +108,12 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
else: else:
suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) if path and register_converter:
converter_name, suffix_converter = _get_format_path_converter(suffix_kwarg, allowed)
register_converter(suffix_converter, converter_name)
suffix_route = '<%s:%s>' % (converter_name, suffix_kwarg)
else:
suffix_route = None
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route)

View File

@ -1,12 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import unittest
from collections import namedtuple from collections import namedtuple
from django.conf.urls import include, url from django.conf.urls import include, url
from django.test import TestCase from django.test import TestCase
from django.urls import Resolver404 from django.urls import Resolver404
from rest_framework.compat import make_url_resolver from rest_framework.compat import make_url_resolver, path, re_path
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
@ -23,41 +24,29 @@ class FormatSuffixTests(TestCase):
Tests `format_suffix_patterns` against different URLPatterns to ensure the Tests `format_suffix_patterns` against different URLPatterns to ensure the
URLs still resolve properly, including any captured parameters. URLs still resolve properly, including any captured parameters.
""" """
def _resolve_urlpatterns(self, urlpatterns, test_paths): def _resolve_urlpatterns(self, urlpatterns, test_paths, allowed=None):
factory = APIRequestFactory() factory = APIRequestFactory()
try: try:
urlpatterns = format_suffix_patterns(urlpatterns) urlpatterns = format_suffix_patterns(urlpatterns, allowed=allowed)
except Exception: except Exception:
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
resolver = make_url_resolver(r'^/', urlpatterns) resolver = make_url_resolver(r'^/', urlpatterns)
for test_path in test_paths: for test_path in test_paths:
request = factory.get(test_path.path)
try: try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) test_path, expected_resolved = test_path
except Exception: except (TypeError, ValueError):
self.fail("Failed to resolve URL: %s" % request.path_info) expected_resolved = True
assert callback_args == test_path.args
assert callback_kwargs == test_path.kwargs
def test_trailing_slash(self):
factory = APIRequestFactory()
urlpatterns = format_suffix_patterns([
url(r'^test/$', dummy_view),
])
resolver = make_url_resolver(r'^/', urlpatterns)
test_paths = [
(URLTestPath('/test.api', (), {'format': 'api'}), True),
(URLTestPath('/test/.api', (), {'format': 'api'}), False),
(URLTestPath('/test.api/', (), {'format': 'api'}), True),
]
for test_path, expected_resolved in test_paths:
request = factory.get(test_path.path) request = factory.get(test_path.path)
try: try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
except Resolver404: except Resolver404:
callback, callback_args, callback_kwargs = (None, None, None) callback, callback_args, callback_kwargs = (None, None, None)
if expected_resolved:
raise
except Exception:
self.fail("Failed to resolve URL: %s" % request.path_info)
if not expected_resolved: if not expected_resolved:
assert callback is None assert callback is None
continue continue
@ -65,10 +54,28 @@ class FormatSuffixTests(TestCase):
assert callback_args == test_path.args assert callback_args == test_path.args
assert callback_kwargs == test_path.kwargs assert callback_kwargs == test_path.kwargs
def test_format_suffix(self): def _test_trailing_slash(self, urlpatterns):
urlpatterns = [ test_paths = [
url(r'^test$', dummy_view), (URLTestPath('/test.api', (), {'format': 'api'}), True),
(URLTestPath('/test/.api', (), {'format': 'api'}), False),
(URLTestPath('/test.api/', (), {'format': 'api'}), True),
] ]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_trailing_slash(self):
urlpatterns = [
url(r'^test/$', dummy_view),
]
self._test_trailing_slash(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_trailing_slash_django2(self):
urlpatterns = [
path('test/', dummy_view),
]
self._test_trailing_slash(urlpatterns)
def _test_format_suffix(self, urlpatterns):
test_paths = [ test_paths = [
URLTestPath('/test', (), {}), URLTestPath('/test', (), {}),
URLTestPath('/test.api', (), {'format': 'api'}), URLTestPath('/test.api', (), {'format': 'api'}),
@ -76,10 +83,36 @@ class FormatSuffixTests(TestCase):
] ]
self._resolve_urlpatterns(urlpatterns, test_paths) self._resolve_urlpatterns(urlpatterns, test_paths)
def test_default_args(self): def test_format_suffix(self):
urlpatterns = [ urlpatterns = [
url(r'^test$', dummy_view, {'foo': 'bar'}), url(r'^test$', dummy_view),
] ]
self._test_format_suffix(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_format_suffix_django2(self):
urlpatterns = [
path('test', dummy_view),
]
self._test_format_suffix(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_format_suffix_django2_args(self):
urlpatterns = [
path('convtest/<int:pk>', dummy_view),
re_path(r'^retest/(?P<pk>[0-9]+)$', dummy_view),
]
test_paths = [
URLTestPath('/convtest/42', (), {'pk': 42}),
URLTestPath('/convtest/42.api', (), {'pk': 42, 'format': 'api'}),
URLTestPath('/convtest/42.asdf', (), {'pk': 42, 'format': 'asdf'}),
URLTestPath('/retest/42', (), {'pk': '42'}),
URLTestPath('/retest/42.api', (), {'pk': '42', 'format': 'api'}),
URLTestPath('/retest/42.asdf', (), {'pk': '42', 'format': 'asdf'}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def _test_default_args(self, urlpatterns):
test_paths = [ test_paths = [
URLTestPath('/test', (), {'foo': 'bar', }), URLTestPath('/test', (), {'foo': 'bar', }),
URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
@ -87,6 +120,27 @@ class FormatSuffixTests(TestCase):
] ]
self._resolve_urlpatterns(urlpatterns, test_paths) self._resolve_urlpatterns(urlpatterns, test_paths)
def test_default_args(self):
urlpatterns = [
url(r'^test$', dummy_view, {'foo': 'bar'}),
]
self._test_default_args(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_default_args_django2(self):
urlpatterns = [
path('test', dummy_view, {'foo': 'bar'}),
]
self._test_default_args(urlpatterns)
def _test_included_urls(self, urlpatterns):
test_paths = [
URLTestPath('/test/path', (), {'foo': 'bar', }),
URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_included_urls(self): def test_included_urls(self):
nested_patterns = [ nested_patterns = [
url(r'^path$', dummy_view) url(r'^path$', dummy_view)
@ -94,9 +148,79 @@ class FormatSuffixTests(TestCase):
urlpatterns = [ urlpatterns = [
url(r'^test/', include(nested_patterns), {'foo': 'bar'}), url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
] ]
self._test_included_urls(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_included_urls_django2(self):
nested_patterns = [
path('path', dummy_view)
]
urlpatterns = [
path('test/', include(nested_patterns), {'foo': 'bar'}),
]
self._test_included_urls(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_included_urls_django2_mixed(self):
nested_patterns = [
path('path', dummy_view)
]
urlpatterns = [
url('^test/', include(nested_patterns), {'foo': 'bar'}),
]
self._test_included_urls(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_included_urls_django2_mixed_args(self):
nested_patterns = [
path('path/<int:child>', dummy_view),
url('^url/(?P<child>[0-9]+)$', dummy_view)
]
urlpatterns = [
url('^purl/(?P<parent>[0-9]+)/', include(nested_patterns), {'foo': 'bar'}),
path('ppath/<int:parent>/', include(nested_patterns), {'foo': 'bar'}),
]
test_paths = [ test_paths = [
URLTestPath('/test/path', (), {'foo': 'bar', }), # parent url() nesting child path()
URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), URLTestPath('/purl/87/path/42', (), {'parent': '87', 'child': 42, 'foo': 'bar', }),
URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), URLTestPath('/purl/87/path/42.api', (), {'parent': '87', 'child': 42, 'foo': 'bar', 'format': 'api'}),
URLTestPath('/purl/87/path/42.asdf', (), {'parent': '87', 'child': 42, 'foo': 'bar', 'format': 'asdf'}),
# parent path() nesting child url()
URLTestPath('/ppath/87/url/42', (), {'parent': 87, 'child': '42', 'foo': 'bar', }),
URLTestPath('/ppath/87/url/42.api', (), {'parent': 87, 'child': '42', 'foo': 'bar', 'format': 'api'}),
URLTestPath('/ppath/87/url/42.asdf', (), {'parent': 87, 'child': '42', 'foo': 'bar', 'format': 'asdf'}),
# parent path() nesting child path()
URLTestPath('/ppath/87/path/42', (), {'parent': 87, 'child': 42, 'foo': 'bar', }),
URLTestPath('/ppath/87/path/42.api', (), {'parent': 87, 'child': 42, 'foo': 'bar', 'format': 'api'}),
URLTestPath('/ppath/87/path/42.asdf', (), {'parent': 87, 'child': 42, 'foo': 'bar', 'format': 'asdf'}),
# parent url() nesting child url()
URLTestPath('/purl/87/url/42', (), {'parent': '87', 'child': '42', 'foo': 'bar', }),
URLTestPath('/purl/87/url/42.api', (), {'parent': '87', 'child': '42', 'foo': 'bar', 'format': 'api'}),
URLTestPath('/purl/87/url/42.asdf', (), {'parent': '87', 'child': '42', 'foo': 'bar', 'format': 'asdf'}),
] ]
self._resolve_urlpatterns(urlpatterns, test_paths) self._resolve_urlpatterns(urlpatterns, test_paths)
def _test_allowed_formats(self, urlpatterns):
allowed_formats = ['good', 'ugly']
test_paths = [
(URLTestPath('/test.good/', (), {'format': 'good'}), True),
(URLTestPath('/test.bad', (), {}), False),
(URLTestPath('/test.ugly', (), {'format': 'ugly'}), True),
]
self._resolve_urlpatterns(urlpatterns, test_paths, allowed=allowed_formats)
def test_allowed_formats(self):
urlpatterns = [
url('^test$', dummy_view),
]
self._test_allowed_formats(urlpatterns)
@unittest.skipUnless(path, 'needs Django 2')
def test_allowed_formats_django2(self):
urlpatterns = [
path('test', dummy_view),
]
self._test_allowed_formats(urlpatterns)