"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.

They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.

For example, you might have a `urls.py` that looks something like this:

    router = routers.DefaultRouter()
    router.register('users', UserViewSet, 'user')
    router.register('accounts', AccountViewSet, 'account')

    urlpatterns = router.urls
"""
import itertools
from collections import namedtuple

from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, path, re_path

from rest_framework import views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns

# A fully specified route: URL template, HTTP-method -> action mapping,
# URL name template, whether it targets a single instance, and extra
# keyword arguments handed to `viewset.as_view()`.
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])

# A route template that `SimpleRouter.get_routes()` expands into one `Route`
# per extra action of the viewset with a matching `detail` flag.
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])


def escape_curly_brackets(url_path):
    """
    Double brackets in regex of url_path for escape string formatting.

    The returned string survives a later `str.format()` call with its
    original braces intact.
    """
    return url_path.replace('{', '{{').replace('}', '}}')


def flatten(list_of_lists):
    """
    Takes an iterable of iterables, returns a single iterable containing all items
    """
    return itertools.chain.from_iterable(list_of_lists)


class BaseRouter:
    """
    Keeps the registry of `(prefix, viewset, basename)` entries and caches
    the URL patterns produced by `get_urls()`.

    Subclasses must implement `get_default_basename()` and `get_urls()`.
    """

    def __init__(self):
        self.registry = []

    def register(self, prefix, viewset, basename=None):
        """
        Add a viewset to the registry under the given URL prefix.

        If `basename` is omitted it is derived with `get_default_basename()`.
        Raises `ImproperlyConfigured` when the basename is already in use.
        """
        if basename is None:
            basename = self.get_default_basename(viewset)

        if self.is_already_registered(basename):
            msg = (f'Router with basename "{basename}" is already registered. '
                   f'Please provide a unique basename for viewset "{viewset}"')
            raise ImproperlyConfigured(msg)

        self.registry.append((prefix, viewset, basename))

        # invalidate the urls cache
        if hasattr(self, '_urls'):
            del self._urls

    def is_already_registered(self, new_basename):
        """
        Check if `basename` is already registered
        """
        return any(basename == new_basename for _prefix, _viewset, basename in self.registry)

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset.
        """
        raise NotImplementedError('get_default_basename must be overridden')

    def get_urls(self):
        """
        Return a list of URL patterns, given the registered viewsets.
        """
        raise NotImplementedError('get_urls must be overridden')

    @property
    def urls(self):
        """
        The URL patterns for this router, computed by `get_urls()` on first
        access and cached in `_urls` until the next `register()` call.
        """
        if not hasattr(self, '_urls'):
            self._urls = self.get_urls()
        return self._urls


class SimpleRouter(BaseRouter):
    """
    Router that generates list/detail routes plus routes for any extra
    actions of each registered viewset.
    """

    routes = [
        # List route.
        Route(
            url=r'^{prefix}{trailing_slash}$',
            mapping={
                'get': 'list',
                'post': 'create'
            },
            name='{basename}-list',
            detail=False,
            initkwargs={'suffix': 'List'}
        ),
        # Dynamically generated list routes. Generated using
        # @action(detail=False) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=False,
            initkwargs={}
        ),
        # Detail route.
        Route(
            url=r'^{prefix}/{lookup}{trailing_slash}$',
            mapping={
                'get': 'retrieve',
                'put': 'update',
                'patch': 'partial_update',
                'delete': 'destroy'
            },
            name='{basename}-detail',
            detail=True,
            initkwargs={'suffix': 'Instance'}
        ),
        # Dynamically generated detail routes. Generated using
        # @action(detail=True) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=True,
            initkwargs={}
        ),
    ]

    def __init__(self, trailing_slash=True, use_regex_path=True):
        """
        `trailing_slash` controls whether generated URLs end with '/'.
        `use_regex_path` selects `re_path()` with regex lookups (True) or
        `path()` with path converters (False).
        """
        self.trailing_slash = '/' if trailing_slash else ''
        self._use_regex = use_regex_path
        if use_regex_path:
            self._base_pattern = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
            self._default_value_pattern = '[^/.]+'
            self._url_conf = re_path
        else:
            self._base_pattern = '<{lookup_value}:{lookup_prefix}{lookup_url_kwarg}>'
            # 'str' stops at '/', like the regex default above. The 'path'
            # converter also matches '/', which lets the detail route swallow
            # the URLs of the dynamic detail routes that follow it.
            self._default_value_pattern = 'str'
            self._url_conf = path
            # remove regex characters from routes
            _routes = []
            for route in self.routes:
                url_param = route.url
                # startswith/endswith are safe on an empty url template.
                if url_param.startswith('^'):
                    url_param = url_param[1:]
                if url_param.endswith('$'):
                    url_param = url_param[:-1]

                _routes.append(route._replace(url=url_param))
            # Instance attribute; the class-level `routes` is left untouched.
            self.routes = _routes

        super().__init__()

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset.

        Returns the lower-cased model name of the viewset's `queryset`.
        Raises `AssertionError` when the viewset has no `queryset`.
        """
        queryset = getattr(viewset, 'queryset', None)

        # Raised explicitly (not via `assert`) so the check also runs
        # under `python -O`.
        if queryset is None:
            raise AssertionError(
                '`basename` argument not specified, and could '
                'not automatically determine the name from the viewset, as '
                'it does not have a `.queryset` attribute.'
            )

        return queryset.model._meta.object_name.lower()

    def get_routes(self, viewset):
        """
        Augment `self.routes` with any dynamically generated routes.

        Returns a list of the Route namedtuple.
        Raises `ImproperlyConfigured` if an extra action reuses the name of
        an action already mapped by one of the standard routes.
        """
        # Collected into a set because the names are tested for membership
        # once per extra action.
        known_actions = set(flatten(
            route.mapping.values() for route in self.routes if isinstance(route, Route)
        ))
        extra_actions = viewset.get_extra_actions()

        # checking action names against the known actions list
        not_allowed = [
            action.__name__ for action in extra_actions
            if action.__name__ in known_actions
        ]
        if not_allowed:
            msg = ('Cannot use the @action decorator on the following '
                   'methods, as they are existing routes: %s')
            raise ImproperlyConfigured(msg % ', '.join(not_allowed))

        # partition detail and list actions
        detail_actions = [action for action in extra_actions if action.detail]
        list_actions = [action for action in extra_actions if not action.detail]

        routes = []
        for route in self.routes:
            if isinstance(route, DynamicRoute) and route.detail:
                routes += [self._get_dynamic_route(route, action) for action in detail_actions]
            elif isinstance(route, DynamicRoute) and not route.detail:
                routes += [self._get_dynamic_route(route, action) for action in list_actions]
            else:
                routes.append(route)

        return routes

    def _get_dynamic_route(self, route, action):
        """
        Build a concrete `Route` for `action` from the `DynamicRoute`
        template `route`.
        """
        initkwargs = route.initkwargs.copy()
        initkwargs.update(action.kwargs)

        # Escaped because the resulting url is passed through `str.format()`
        # in `get_urls()`.
        url_path = escape_curly_brackets(action.url_path)

        return Route(
            url=route.url.replace('{url_path}', url_path),
            mapping=action.mapping,
            name=route.name.replace('{url_name}', action.url_name),
            detail=route.detail,
            initkwargs=initkwargs,
        )

    def get_method_map(self, viewset, method_map):
        """
        Given a viewset, and a mapping of http methods to actions,
        return a new mapping which only includes any mappings that
        are actually implemented by the viewset.
        """
        bound_methods = {}
        for method, action in method_map.items():
            if hasattr(viewset, action):
                bound_methods[method] = action
        return bound_methods

    def get_lookup_regex(self, viewset, lookup_prefix=''):
        """
        Given a viewset, return the portion of URL regex that is used
        to match against a single instance.

        Note that lookup_prefix is not used directly inside REST framework
        itself, but is required in order to nicely support nested router
        implementations, such as drf-nested-routers.

        https://github.com/alanjds/drf-nested-routers
        """
        # Use `pk` as the default lookup field unless the viewset sets one.
        # The default value pattern should break at '/' boundaries.
        lookup_field = getattr(viewset, 'lookup_field', 'pk')
        lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
        lookup_value = None
        if not self._use_regex:
            # try to get a more appropriate attribute when not using regex
            lookup_value = getattr(viewset, 'lookup_value_converter', None)
        if lookup_value is None:
            # fallback to legacy
            lookup_value = getattr(viewset, 'lookup_value_regex', self._default_value_pattern)
        return self._base_pattern.format(
            lookup_prefix=lookup_prefix,
            lookup_url_kwarg=lookup_url_kwarg,
            lookup_value=lookup_value
        )

    def get_urls(self):
        """
        Use the registered viewsets to generate a list of URL patterns.
        """
        ret = []

        for prefix, viewset, basename in self.registry:
            lookup = self.get_lookup_regex(viewset)
            routes = self.get_routes(viewset)

            for route in routes:

                # Only actions which actually exist on the viewset will be bound
                mapping = self.get_method_map(viewset, route.mapping)
                if not mapping:
                    continue

                # Build the url pattern
                regex = route.url.format(
                    prefix=prefix,
                    lookup=lookup,
                    trailing_slash=self.trailing_slash
                )

                # If there is no prefix, the first part of the url is probably
                #   controlled by project's urls.py and the router is in an app,
                #   so a slash in the beginning will (A) cause Django to give
                #   warnings and (B) generate URLS that will require using '//'.
                if not prefix:
                    if self._url_conf is path:
                        # `startswith` because the pattern may be empty here
                        # (empty prefix and no trailing slash).
                        if regex.startswith('/'):
                            regex = regex[1:]
                    elif regex.startswith('^/'):
                        regex = '^' + regex[2:]

                initkwargs = route.initkwargs.copy()
                initkwargs.update({
                    'basename': basename,
                    'detail': route.detail,
                })

                view = viewset.as_view(mapping, **initkwargs)
                name = route.name.format(basename=basename)
                ret.append(self._url_conf(regex, view, name=name))

        return ret


class APIRootView(views.APIView):
    """
    The default basic root view for DefaultRouter
    """
    _ignore_model_permissions = True
    schema = None  # exclude from schema
    api_root_dict = None

    def get(self, request, *args, **kwargs):
        """
        Return a plain {"name": "hyperlink"} response built from
        `api_root_dict`, skipping entries whose URL cannot be reversed.
        """
        ret = {}
        namespace = request.resolver_match.namespace
        for key, url_name in self.api_root_dict.items():
            if namespace:
                url_name = namespace + ':' + url_name
            try:
                ret[key] = reverse(
                    url_name,
                    args=args,
                    kwargs=kwargs,
                    request=request,
                    format=kwargs.get('format')
                )
            except NoReverseMatch:
                # Don't bail out if eg. no list routes exist, only detail routes.
                continue

        return Response(ret)


class DefaultRouter(SimpleRouter):
    """
    The default router extends the SimpleRouter, but also adds in a default
    API root view, and adds format suffix patterns to the URLs.
    """
    include_root_view = True
    include_format_suffixes = True
    root_view_name = 'api-root'
    default_schema_renderers = None
    APIRootView = APIRootView
    APISchemaView = SchemaView
    SchemaGenerator = SchemaGenerator

    def __init__(self, *args, **kwargs):
        """
        Accepts the `SimpleRouter` arguments plus an optional
        `root_renderers` keyword, which defaults to a copy of
        `api_settings.DEFAULT_RENDERER_CLASSES`.
        """
        if 'root_renderers' in kwargs:
            self.root_renderers = kwargs.pop('root_renderers')
        else:
            self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
        super().__init__(*args, **kwargs)

    def get_api_root_view(self, api_urls=None):
        """
        Return a basic root view.

        The view maps every registered prefix to the name of its list
        route. `api_urls` is accepted but not used here.
        """
        api_root_dict = {}
        # The first entry of `routes` is the list route.
        list_name = self.routes[0].name
        for prefix, viewset, basename in self.registry:
            api_root_dict[prefix] = list_name.format(basename=basename)

        return self.APIRootView.as_view(api_root_dict=api_root_dict)

    def get_urls(self):
        """
        Generate the list of URL patterns, including a default root view
        for the API, and appending `.json` style format suffixes.
        """
        urls = super().get_urls()

        if self.include_root_view:
            view = self.get_api_root_view(api_urls=urls)
            root_url = path('', view, name=self.root_view_name)
            urls.append(root_url)

        if self.include_format_suffixes:
            urls = format_suffix_patterns(urls)

        return urls