From 63ee414e3c6e1dedff572aebcbb504794c39aabb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Wed, 20 Dec 2017 11:04:56 +0100 Subject: [PATCH] Add get_original_route to complement get_regex_pattern --- rest_framework/compat.py | 19 ++++++++++++++++++- rest_framework/schemas/generators.py | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 9502c245f..d39057a52 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -29,7 +29,11 @@ except ImportError: ) -def get_regex_pattern(urlpattern): +def get_original_route(urlpattern): + """ + Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This + is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path(). + """ if hasattr(urlpattern, 'pattern'): # Django 2.0 return str(urlpattern.pattern) @@ -38,6 +42,19 @@ def get_regex_pattern(urlpattern): return urlpattern.regex.pattern +def get_regex_pattern(urlpattern): + """ + Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression, + unlike get_original_route above. + """ + if hasattr(urlpattern, 'pattern'): + # Django 2.0 + return urlpattern.pattern.regex.pattern + else: + # Django < 2.0 + return urlpattern.regex.pattern + + def make_url_resolver(regex, urlpatterns): try: # Django 2.0 diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 6f5c04475..10af6ee04 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -16,7 +16,7 @@ from django.utils import six from rest_framework import exceptions from rest_framework.compat import ( - URLPattern, URLResolver, coreapi, coreschema, get_regex_pattern + URLPattern, URLResolver, coreapi, coreschema, get_original_route ) from rest_framework.request import clone_request from rest_framework.settings import api_settings @@ -170,7 +170,7 @@ class EndpointEnumerator(object): api_endpoints = [] for pattern in patterns: - path_regex = prefix + get_regex_pattern(pattern) + path_regex = prefix + get_original_route(pattern) if isinstance(pattern, URLPattern): path = self.get_path_from_regex(path_regex) callback = pattern.callback