django-rest-framework/rest_framework/routers.py
2019-10-04 13:50:19 -07:00

368 lines
13 KiB
Python

"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.

They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.

For example, you might have a `urls.py` that looks something like this:

    router = routers.DefaultRouter()
    router.register('users', UserViewSet, 'user')
    router.register('accounts', AccountViewSet, 'account')

    urlpatterns = router.urls
"""
import itertools
import warnings
from collections import OrderedDict, namedtuple
from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch
from django.utils.deprecation import RenameMethodsBase
from rest_framework import RemovedInDRF311Warning, views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
# Static route template. `url` and `name` are format strings filled in per
# registered viewset; `mapping` maps HTTP method names (e.g. 'get') to viewset
# action names (e.g. 'list'); `detail` marks routes that address a single
# instance; `initkwargs` are passed through to `viewset.as_view()`.
Route = namedtuple('Route', 'url mapping name detail initkwargs')

# Route template expanded once per matching extra action on a viewset. It has
# no `mapping` field because each action supplies its own method mapping.
DynamicRoute = namedtuple('DynamicRoute', 'url name detail initkwargs')
def escape_curly_brackets(url_path):
    """
    Return `url_path` with every '{' and '}' doubled.

    This lets a regex fragment (which may contain quantifiers such as `{3}`)
    be embedded in a template that is later passed through `str.format()`
    without the braces being treated as replacement fields.
    """
    doubled_braces = str.maketrans({'{': '{{', '}': '}}'})
    return url_path.translate(doubled_braces)
def flatten(list_of_lists):
    """
    Takes an iterable of iterables, returns a single iterable containing all
    items, in order.

    The outer iterable is consumed lazily, one inner iterable at a time, so
    it is not unpacked into an argument tuple up front.
    """
    return itertools.chain.from_iterable(list_of_lists)
class RenameRouterMethods(RenameMethodsBase):
    """
    Metaclass for routers recording that `get_default_base_name` was renamed
    to `get_default_basename`.

    Each entry of `renamed_methods` is `(old_name, new_name, warning_class)`;
    the handling of these entries is done by Django's `RenameMethodsBase`
    (imported from `django.utils.deprecation`).
    """
    renamed_methods = (
        (
            'get_default_base_name',  # old method name
            'get_default_basename',  # new method name
            RemovedInDRF311Warning,  # warning category for the old name
        ),
    )
class BaseRouter(metaclass=RenameRouterMethods):
    """
    Base class for routers.

    Viewsets are recorded with `register()`. Subclasses turn the registry
    into URL patterns by implementing `get_urls()`, and implement
    `get_default_basename()` if a basename should be inferred when none is
    given. The generated patterns are cached behind the `urls` property.
    """

    def __init__(self):
        # List of `(prefix, viewset, basename)` triples, in registration order.
        self.registry = []

    def register(self, prefix, viewset, basename=None, base_name=None):
        """
        Record `viewset` under the URL `prefix`.

        `basename` is used to name the generated routes. When it is None it
        falls back to the deprecated `base_name` argument, and when that is
        also None to `self.get_default_basename(viewset)`. Passing a truthy
        value for both arguments fails an assertion. Any cached `urls` are
        invalidated.
        """
        if base_name is not None:
            warnings.warn(
                "The `base_name` argument is pending deprecation in favor of `basename`.",
                RemovedInDRF311Warning,
                2,
            )

        assert not (basename and base_name), (
            "Do not provide both the `basename` and `base_name` arguments.")

        if basename is None:
            if base_name is not None:
                basename = base_name
            else:
                basename = self.get_default_basename(viewset)

        self.registry.append((prefix, viewset, basename))

        # The registry changed, so drop the cached patterns; the next access
        # to `urls` rebuilds them.
        if hasattr(self, '_urls'):
            del self._urls

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset.
        """
        raise NotImplementedError('get_default_basename must be overridden')

    def get_urls(self):
        """
        Return a list of URL patterns, given the registered viewsets.
        """
        raise NotImplementedError('get_urls must be overridden')

    @property
    def urls(self):
        """URL patterns from `get_urls()`, built on first access and cached in `_urls`."""
        if hasattr(self, '_urls'):
            return self._urls
        self._urls = self.get_urls()
        return self._urls
class SimpleRouter(BaseRouter):
    """
    Router that builds regex URL patterns for every registered viewset from
    the templates in `routes`: a list route, a detail route, and one route
    per extra action declared on the viewset with `@action`.
    """

    routes = [
        # List route.
        Route(
            url=r'^{prefix}{trailing_slash}$',
            mapping={
                'get': 'list',
                'post': 'create'
            },
            name='{basename}-list',
            detail=False,
            initkwargs={'suffix': 'List'}
        ),
        # Dynamically generated list routes. Generated using
        # @action(detail=False) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=False,
            initkwargs={}
        ),
        # Detail route.
        Route(
            url=r'^{prefix}/{lookup}{trailing_slash}$',
            mapping={
                'get': 'retrieve',
                'put': 'update',
                'patch': 'partial_update',
                'delete': 'destroy'
            },
            name='{basename}-detail',
            detail=True,
            initkwargs={'suffix': 'Instance'}
        ),
        # Dynamically generated detail routes. Generated using
        # @action(detail=True) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=True,
            initkwargs={}
        ),
    ]

    def __init__(self, trailing_slash=True):
        """
        `trailing_slash` controls whether generated URLs end with '/'.
        """
        # Stored as the literal text substituted for `{trailing_slash}`.
        self.trailing_slash = '/' if trailing_slash else ''
        super().__init__()

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset: the lowercased model name of its `.queryset`.
        """
        queryset = getattr(viewset, 'queryset', None)
        assert queryset is not None, (
            '`basename` argument not specified, and could '
            'not automatically determine the name from the viewset, as '
            'it does not have a `.queryset` attribute.')
        return queryset.model._meta.object_name.lower()

    def get_routes(self, viewset):
        """
        Augment `self.routes` with any dynamically generated routes.

        Returns a list of the Route namedtuple: every static `Route` as-is,
        and every `DynamicRoute` expanded once per extra action whose
        `detail` flag matches the template's.

        Raises ImproperlyConfigured if an extra action reuses the name of an
        action already bound by a static route (e.g. `list`, `retrieve`).
        """
        # A set, because it is probed once per extra action below.
        known_actions = set(flatten(
            [route.mapping.values() for route in self.routes if isinstance(route, Route)]
        ))

        extra_actions = viewset.get_extra_actions()

        # Extra actions must not shadow an action a static route already maps.
        not_allowed = [
            action.__name__ for action in extra_actions
            if action.__name__ in known_actions
        ]
        if not_allowed:
            msg = ('Cannot use the @action decorator on the following '
                   'methods, as they are existing routes: %s')
            raise ImproperlyConfigured(msg % ', '.join(not_allowed))

        # partition detail and list actions
        detail_actions = [action for action in extra_actions if action.detail]
        list_actions = [action for action in extra_actions if not action.detail]

        routes = []
        for route in self.routes:
            if isinstance(route, DynamicRoute):
                actions = detail_actions if route.detail else list_actions
                routes.extend(self._get_dynamic_route(route, action) for action in actions)
            else:
                routes.append(route)

        return routes

    def _get_dynamic_route(self, route, action):
        """
        Build a concrete `Route` from the `DynamicRoute` template `route` for
        the extra action `action`.
        """
        # The action's kwargs override the template's initkwargs.
        initkwargs = route.initkwargs.copy()
        initkwargs.update(action.kwargs)

        # Braces in the action's url_path must survive the later `.format()`
        # call in `get_urls`.
        url_path = escape_curly_brackets(action.url_path)

        return Route(
            url=route.url.replace('{url_path}', url_path),
            mapping=action.mapping,
            name=route.name.replace('{url_name}', action.url_name),
            detail=route.detail,
            initkwargs=initkwargs,
        )

    def get_method_map(self, viewset, method_map):
        """
        Given a viewset, and a mapping of http methods to actions,
        return a new mapping which only includes any mappings that
        are actually implemented by the viewset.
        """
        return {
            method: action
            for method, action in method_map.items()
            if hasattr(viewset, action)
        }

    def get_lookup_regex(self, viewset, lookup_prefix=''):
        """
        Given a viewset, return the portion of URL regex that is used
        to match against a single instance.

        Note that lookup_prefix is not used directly inside REST framework
        itself, but is required in order to nicely support nested router
        implementations, such as drf-nested-routers.

        https://github.com/alanjds/drf-nested-routers
        """
        base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
        # Use `pk` as default field, unless set. Default regex should not
        # consume `.json` style suffixes and should break at '/' boundaries.
        lookup_field = getattr(viewset, 'lookup_field', 'pk')
        lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
        lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
        return base_regex.format(
            lookup_prefix=lookup_prefix,
            lookup_url_kwarg=lookup_url_kwarg,
            lookup_value=lookup_value
        )

    def get_urls(self):
        """
        Use the registered viewsets to generate a list of URL patterns.

        Routes whose mapping binds no method the viewset implements are
        skipped.
        """
        ret = []

        for prefix, viewset, basename in self.registry:
            lookup = self.get_lookup_regex(viewset)
            routes = self.get_routes(viewset)

            for route in routes:

                # Only actions which actually exist on the viewset will be bound
                mapping = self.get_method_map(viewset, route.mapping)
                if not mapping:
                    continue

                # Build the url pattern
                regex = route.url.format(
                    prefix=prefix,
                    lookup=lookup,
                    trailing_slash=self.trailing_slash
                )

                # If there is no prefix, the first part of the url is probably
                # controlled by project's urls.py and the router is in an app,
                # so a slash in the beginning will (A) cause Django to give
                # warnings and (B) generate URLS that will require using '//'.
                if not prefix and regex.startswith('^/'):
                    regex = '^' + regex[2:]

                initkwargs = route.initkwargs.copy()
                initkwargs.update({
                    'basename': basename,
                    'detail': route.detail,
                })

                view = viewset.as_view(mapping, **initkwargs)
                name = route.name.format(basename=basename)
                ret.append(url(regex, view, name=name))

        return ret
class APIRootView(views.APIView):
    """
    The default basic root view for DefaultRouter
    """
    _ignore_model_permissions = True
    schema = None  # exclude from schema
    # Maps each response key to the URL name that is reversed for it; the
    # router supplies it through `as_view(api_root_dict=...)`.
    api_root_dict = None

    def get(self, request, *args, **kwargs):
        """
        Return a plain {"name": "hyperlink"} response.

        Entries whose URL name cannot be reversed are left out rather than
        failing the whole response.
        """
        namespace = request.resolver_match.namespace
        links = OrderedDict()
        for key, url_name in self.api_root_dict.items():
            qualified_name = namespace + ':' + url_name if namespace else url_name
            try:
                links[key] = reverse(
                    qualified_name,
                    args=args,
                    kwargs=kwargs,
                    request=request,
                    format=kwargs.get('format', None),
                )
            except NoReverseMatch:
                # Don't bail out if eg. no list routes exist, only detail routes.
                pass
        return Response(links)
class DefaultRouter(SimpleRouter):
    """
    The default router extends the SimpleRouter, but also adds in a default
    API root view, and adds format suffix patterns to the URLs.

    Both additions can be switched off with `include_root_view` and
    `include_format_suffixes`.
    """
    include_root_view = True
    include_format_suffixes = True
    root_view_name = 'api-root'
    default_schema_renderers = None
    APIRootView = APIRootView
    APISchemaView = SchemaView
    SchemaGenerator = SchemaGenerator

    def __init__(self, *args, **kwargs):
        # `root_renderers` is removed from kwargs so it is not forwarded to
        # SimpleRouter.__init__; it defaults to a copy of the configured
        # DEFAULT_RENDERER_CLASSES.
        # NOTE(review): nothing in this class reads `self.root_renderers`
        # afterwards — confirm whether subclasses or external code rely on it.
        try:
            self.root_renderers = kwargs.pop('root_renderers')
        except KeyError:
            self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
        super().__init__(*args, **kwargs)

    def get_api_root_view(self, api_urls=None):
        """
        Return a basic root view with one link per registered prefix.

        `api_urls` is accepted for subclasses that override this method;
        this implementation does not use it.
        """
        # Assumes `self.routes[0]` is the list route (true for
        # SimpleRouter.routes); each root link reverses to its list URL name.
        list_name_template = self.routes[0].name
        api_root_dict = OrderedDict(
            (prefix, list_name_template.format(basename=basename))
            for prefix, _viewset, basename in self.registry
        )
        return self.APIRootView.as_view(api_root_dict=api_root_dict)

    def get_urls(self):
        """
        Generate the list of URL patterns, including a default root view
        for the API, and appending `.json` style format suffixes.
        """
        urls = super().get_urls()

        if self.include_root_view:
            urls.append(url(
                r'^$',
                self.get_api_root_view(api_urls=urls),
                name=self.root_view_name,
            ))

        if self.include_format_suffixes:
            urls = format_suffix_patterns(urls)

        return urls