mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-01 00:17:40 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			391 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			391 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Routers provide a convenient and consistent way of automatically
 | |
| determining the URL conf for your API.
 | |
| 
 | |
| They are used by simply instantiating a Router class, and then registering
 | |
| all the required ViewSets with that router.
 | |
| 
 | |
| For example, you might have a `urls.py` that looks something like this:
 | |
| 
 | |
|     router = routers.DefaultRouter()
 | |
|     router.register('users', UserViewSet, 'user')
 | |
|     router.register('accounts', AccountViewSet, 'account')
 | |
| 
 | |
|     urlpatterns = router.urls
 | |
| """
 | |
| import itertools
 | |
| from collections import namedtuple
 | |
| 
 | |
| from django.core.exceptions import ImproperlyConfigured
 | |
| from django.urls import NoReverseMatch, path, re_path
 | |
| 
 | |
| from rest_framework import views
 | |
| from rest_framework.response import Response
 | |
| from rest_framework.reverse import reverse
 | |
| from rest_framework.schemas import SchemaGenerator
 | |
| from rest_framework.schemas.views import SchemaView
 | |
| from rest_framework.settings import api_settings
 | |
| from rest_framework.urlpatterns import format_suffix_patterns
 | |
| 
 | |
# Static route template. `url` and `name` are format strings filled in by the
# router; `mapping` maps HTTP method names to viewset action names.
Route = namedtuple('Route', 'url mapping name detail initkwargs')
# Template for routes generated from extra actions; the HTTP-method mapping is
# taken from each action, so this template has no `mapping` field.
DynamicRoute = namedtuple('DynamicRoute', 'url name detail initkwargs')
 | |
| 
 | |
| 
 | |
def escape_curly_brackets(url_path):
    """
    Return `url_path` with every `{` and `}` doubled.

    The router later runs route URL templates through `str.format`; doubling
    the braces makes literal braces in an action's url_path (for example a
    regex quantifier such as `{2}`) survive that formatting step unchanged.
    """
    doubled = {ord('{'): '{{', ord('}'): '}}'}
    return url_path.translate(doubled)
 | |
| 
 | |
| 
 | |
def flatten(list_of_lists):
    """
    Takes an iterable of iterables, returns a single iterable containing all items.

    The result is a lazy iterator over the items of each inner iterable, in
    order. `chain.from_iterable` is used instead of `chain(*list_of_lists)`
    so the outer iterable is consumed lazily rather than first being unpacked
    into an argument tuple.
    """
    return itertools.chain.from_iterable(list_of_lists)
 | |
| 
 | |
| 
 | |
class BaseRouter:
    """
    Base class for routers.

    Keeps a registry of `(prefix, viewset, basename)` entries and exposes the
    URL patterns built from them through the cached `urls` property.
    Subclasses must implement `get_default_basename` and `get_urls`.
    """

    def __init__(self):
        # (prefix, viewset, basename) tuples, in registration order.
        self.registry = []

    def register(self, prefix, viewset, basename=None):
        """
        Add `viewset` to the registry under the URL `prefix`.

        When `basename` is omitted it is derived via `get_default_basename`.
        Raises ImproperlyConfigured if the resulting basename is already taken.
        """
        name = self.get_default_basename(viewset) if basename is None else basename

        if self.is_already_registered(name):
            raise ImproperlyConfigured(
                f'Router with basename "{name}" is already registered. '
                f'Please provide a unique basename for viewset "{viewset}"'
            )

        self.registry.append((prefix, viewset, name))

        # The registry changed, so any previously built URL patterns are stale.
        if hasattr(self, '_urls'):
            del self._urls

    def is_already_registered(self, new_basename):
        """
        Return True if some registry entry already uses `new_basename`.
        """
        for _prefix, _viewset, registered in self.registry:
            if registered == new_basename:
                return True
        return False

    def get_default_basename(self, viewset):
        """
        Derive a basename from `viewset` when none was passed to `register`.
        Must be implemented by subclasses.
        """
        raise NotImplementedError('get_default_basename must be overridden')

    def get_urls(self):
        """
        Build the list of URL patterns for the registered viewsets.
        Must be implemented by subclasses.
        """
        raise NotImplementedError('get_urls must be overridden')

    @property
    def urls(self):
        """
        URL patterns from `get_urls`, built on first access and cached until
        the next `register` call.
        """
        try:
            return self._urls
        except AttributeError:
            self._urls = self.get_urls()
            return self._urls
 | |
| 
 | |
| 
 | |
class SimpleRouter(BaseRouter):
    """
    Router that builds list, detail and `@action` URL patterns for every
    registered viewset from the `routes` templates below.

    With `use_regex_path=True` (the default), patterns are regexes registered
    via `re_path`. Otherwise the `^` / `$` anchors are stripped from the
    templates and patterns are registered via `path`, with a path converter
    for the lookup segment.
    """

    routes = [
        # List route.
        Route(
            url=r'^{prefix}{trailing_slash}$',
            mapping={
                'get': 'list',
                'post': 'create'
            },
            name='{basename}-list',
            detail=False,
            initkwargs={'suffix': 'List'}
        ),
        # Dynamically generated list routes. Generated using
        # @action(detail=False) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=False,
            initkwargs={}
        ),
        # Detail route.
        Route(
            url=r'^{prefix}/{lookup}{trailing_slash}$',
            mapping={
                'get': 'retrieve',
                'put': 'update',
                'patch': 'partial_update',
                'delete': 'destroy'
            },
            name='{basename}-detail',
            detail=True,
            initkwargs={'suffix': 'Instance'}
        ),
        # Dynamically generated detail routes. Generated using
        # @action(detail=True) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=True,
            initkwargs={}
        ),
    ]

    def __init__(self, trailing_slash=True, use_regex_path=True):
        self.trailing_slash = '/' if trailing_slash else ''
        self._use_regex = use_regex_path
        if use_regex_path:
            self._base_pattern = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
            self._default_value_pattern = '[^/.]+'
            self._url_conf = re_path
        else:
            self._base_pattern = '<{lookup_value}:{lookup_prefix}{lookup_url_kwarg}>'
            self._default_value_pattern = 'str'
            self._url_conf = path
            # `path()` patterns are not regexes, so strip the leading `^` and
            # trailing `$` anchors from every template. The stripped list is
            # stored on the instance; the class-level `routes` is untouched.
            # startswith/endswith (rather than indexing) tolerate empty templates.
            _routes = []
            for route in self.routes:
                url_param = route.url
                if url_param.startswith('^'):
                    url_param = url_param[1:]
                if url_param.endswith('$'):
                    url_param = url_param[:-1]
                _routes.append(route._replace(url=url_param))
            self.routes = _routes

        super().__init__()

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset: the lower-cased model name of `viewset.queryset`.

        Raises AssertionError if the viewset has no `queryset` attribute.
        """
        queryset = getattr(viewset, 'queryset', None)

        # Raise explicitly rather than `assert`, so the check is not stripped
        # under `python -O`. AssertionError is kept for existing callers.
        if queryset is None:
            raise AssertionError(
                '`basename` argument not specified, and could '
                'not automatically determine the name from the viewset, as '
                'it does not have a `.queryset` attribute.'
            )

        return queryset.model._meta.object_name.lower()

    def get_routes(self, viewset):
        """
        Augment `self.routes` with any dynamically generated routes.

        Each `DynamicRoute` template is expanded once per extra action of
        `viewset` with a matching `detail` flag. Static `Route` entries are
        passed through unchanged.

        Returns a list of the Route namedtuple.

        Raises ImproperlyConfigured if an extra action's name matches an
        action already mapped by a static route.
        """
        # Action names already bound by the static routes. A set, because
        # it is probed once per extra action below.
        known_actions = set(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
        extra_actions = viewset.get_extra_actions()

        # checking action names against the known actions
        not_allowed = [
            action.__name__ for action in extra_actions
            if action.__name__ in known_actions
        ]
        if not_allowed:
            msg = ('Cannot use the @action decorator on the following '
                   'methods, as they are existing routes: %s')
            raise ImproperlyConfigured(msg % ', '.join(not_allowed))

        # partition detail and list actions
        detail_actions = [action for action in extra_actions if action.detail]
        list_actions = [action for action in extra_actions if not action.detail]

        routes = []
        for route in self.routes:
            if isinstance(route, DynamicRoute) and route.detail:
                routes += [self._get_dynamic_route(route, action) for action in detail_actions]
            elif isinstance(route, DynamicRoute) and not route.detail:
                routes += [self._get_dynamic_route(route, action) for action in list_actions]
            else:
                routes.append(route)

        return routes

    def _get_dynamic_route(self, route, action):
        """
        Expand a `DynamicRoute` template for one extra `action` into a `Route`.

        The action's `url_path` and `url_name` are substituted into the
        template, and the action's `kwargs` are merged over the template's
        initkwargs.
        """
        initkwargs = route.initkwargs.copy()
        initkwargs.update(action.kwargs)

        # Literal braces in the action's url_path must survive the later
        # `str.format` call in `get_urls`.
        url_path = escape_curly_brackets(action.url_path)

        return Route(
            url=route.url.replace('{url_path}', url_path),
            mapping=action.mapping,
            name=route.name.replace('{url_name}', action.url_name),
            detail=route.detail,
            initkwargs=initkwargs,
        )

    def get_method_map(self, viewset, method_map):
        """
        Given a viewset, and a mapping of http methods to actions,
        return a new mapping which only includes any mappings that
        are actually implemented by the viewset.
        """
        return {
            method: action
            for method, action in method_map.items()
            if hasattr(viewset, action)
        }

    def get_lookup_regex(self, viewset, lookup_prefix=''):
        """
        Given a viewset, return the portion of URL regex that is used
        to match against a single instance.

        Note that lookup_prefix is not used directly inside REST framework
        itself, but is required in order to nicely support nested router
        implementations, such as drf-nested-routers.

        https://github.com/alanjds/drf-nested-routers
        """
        # Use `pk` as default field, unless set.  Default regex should not
        # consume `.json` style suffixes and should break at '/' boundaries.
        lookup_field = getattr(viewset, 'lookup_field', 'pk')
        lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
        lookup_value = None
        if not self._use_regex:
            # try to get a more appropriate attribute when not using regex
            lookup_value = getattr(viewset, 'lookup_value_converter', None)
        if lookup_value is None:
            # fallback to legacy
            lookup_value = getattr(viewset, 'lookup_value_regex', self._default_value_pattern)
        return self._base_pattern.format(
            lookup_prefix=lookup_prefix,
            lookup_url_kwarg=lookup_url_kwarg,
            lookup_value=lookup_value
        )

    def get_urls(self):
        """
        Use the registered viewsets to generate a list of URL patterns.
        """
        ret = []

        for prefix, viewset, basename in self.registry:
            lookup = self.get_lookup_regex(viewset)
            routes = self.get_routes(viewset)

            for route in routes:

                # Only actions which actually exist on the viewset will be bound
                mapping = self.get_method_map(viewset, route.mapping)
                if not mapping:
                    continue

                # Build the url pattern
                regex = route.url.format(
                    prefix=prefix,
                    lookup=lookup,
                    trailing_slash=self.trailing_slash
                )

                # If there is no prefix, the first part of the url is probably
                #   controlled by project's urls.py and the router is in an app,
                #   so a slash in the beginning will (A) cause Django to give
                #   warnings and (B) generate URLS that will require using '//'.
                # startswith() is used because the pattern can be empty
                #   (empty prefix, path() mode, no trailing slash).
                if not prefix:
                    if self._url_conf is path:
                        if regex.startswith('/'):
                            regex = regex[1:]
                    elif regex.startswith('^/'):
                        regex = '^' + regex[2:]

                initkwargs = route.initkwargs.copy()
                initkwargs.update({
                    'basename': basename,
                    'detail': route.detail,
                })

                view = viewset.as_view(mapping, **initkwargs)
                name = route.name.format(basename=basename)
                ret.append(self._url_conf(regex, view, name=name))

        return ret
 | |
| 
 | |
| 
 | |
class APIRootView(views.APIView):
    """
    The default basic root view for DefaultRouter.

    Responds with a mapping from each key of `api_root_dict` to the URL
    reversed from its URL name. Names that cannot be reversed are omitted.
    """
    _ignore_model_permissions = True
    schema = None  # exclude from schema
    api_root_dict = None

    def get(self, request, *args, **kwargs):
        # Qualify URL names with the current namespace (if any) so reversing
        # works when the router's URLs are included under a namespace.
        namespace = request.resolver_match.namespace
        links = {}

        for key, url_name in self.api_root_dict.items():
            qualified_name = namespace + ':' + url_name if namespace else url_name
            try:
                links[key] = reverse(
                    qualified_name,
                    args=args,
                    kwargs=kwargs,
                    request=request,
                    format=kwargs.get('format'),
                )
            except NoReverseMatch:
                # Don't bail out if eg. no list routes exist, only detail routes.
                pass

        return Response(links)
 | |
| 
 | |
| 
 | |
class DefaultRouter(SimpleRouter):
    """
    The default router extends the SimpleRouter, but also adds in a default
    API root view, and adds format suffix patterns to the URLs.
    """
    include_root_view = True
    include_format_suffixes = True
    root_view_name = 'api-root'
    default_schema_renderers = None
    APIRootView = APIRootView
    APISchemaView = SchemaView
    SchemaGenerator = SchemaGenerator

    def __init__(self, *args, **kwargs):
        # `root_renderers` is consumed here; every other argument is forwarded
        # to SimpleRouter. Without it, the DEFAULT_RENDERER_CLASSES setting is
        # copied into a new list.
        self.root_renderers = (
            kwargs.pop('root_renderers')
            if 'root_renderers' in kwargs
            else list(api_settings.DEFAULT_RENDERER_CLASSES)
        )
        super().__init__(*args, **kwargs)

    def get_api_root_view(self, api_urls=None):
        """
        Return a basic root view linking each registered prefix to its list route.

        `api_urls` is accepted for API compatibility but is not used here.
        """
        # Assumes self.routes[0] is the list-route template, which holds for
        # SimpleRouter.routes; its name pattern is filled in per basename.
        list_route_name = self.routes[0].name
        api_root_dict = {
            prefix: list_route_name.format(basename=basename)
            for prefix, _viewset, basename in self.registry
        }
        return self.APIRootView.as_view(api_root_dict=api_root_dict)

    def get_urls(self):
        """
        Generate the list of URL patterns, including a default root view
        for the API, and appending `.json` style format suffixes.
        """
        urls = super().get_urls()

        if self.include_root_view:
            root_view = self.get_api_root_view(api_urls=urls)
            urls.append(path('', root_view, name=self.root_view_name))

        return format_suffix_patterns(urls) if self.include_format_suffixes else urls
 |