# Mirror of https://github.com/encode/django-rest-framework.git (synced 2025-11-04 09:57:55 +03:00).
# File listing: 391 lines, 14 KiB, Python.
"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.

They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.

For example, you might have a `urls.py` that looks something like this:

    router = routers.DefaultRouter()
    router.register('users', UserViewSet, 'user')
    router.register('accounts', AccountViewSet, 'account')

    urlpatterns = router.urls
"""
 | 
						|
import itertools
 | 
						|
from collections import namedtuple
 | 
						|
 | 
						|
from django.core.exceptions import ImproperlyConfigured
 | 
						|
from django.urls import NoReverseMatch, path, re_path
 | 
						|
 | 
						|
from rest_framework import views
 | 
						|
from rest_framework.response import Response
 | 
						|
from rest_framework.reverse import reverse
 | 
						|
from rest_framework.schemas import SchemaGenerator
 | 
						|
from rest_framework.schemas.views import SchemaView
 | 
						|
from rest_framework.settings import api_settings
 | 
						|
from rest_framework.urlpatterns import format_suffix_patterns
 | 
						|
 | 
						|
# A fixed route template: `mapping` ties HTTP method names to viewset
# action names, `url` and `name` are format templates filled in by the router.
Route = namedtuple(
    'Route',
    field_names=('url', 'mapping', 'name', 'detail', 'initkwargs'),
)

# A route template without its own `mapping`; the router expands it into one
# `Route` per extra action declared on the viewset.
DynamicRoute = namedtuple(
    'DynamicRoute',
    field_names=('url', 'name', 'detail', 'initkwargs'),
)
 | 
						|
 | 
						|
 | 
						|
def escape_curly_brackets(url_path):
    """
    Return `url_path` with every `{` and `}` doubled, so that the
    brackets survive a later `str.format()` call as literal characters.
    """
    doubled = str.maketrans({'{': '{{', '}': '}}'})
    return url_path.translate(doubled)
 | 
						|
 | 
						|
 | 
						|
def flatten(list_of_lists):
    """
    Takes an iterable of iterables, returns a single iterable containing
    all items.

    The outer iterable is consumed lazily (one inner iterable at a time),
    so generators can be passed without being fully materialised up front.
    """
    return itertools.chain.from_iterable(list_of_lists)
 | 
						|
 | 
						|
 | 
						|
class BaseRouter:
    """
    Base class for routers: keeps the registry of viewsets and caches the
    generated URL patterns. Subclasses supply `get_default_basename()`
    and `get_urls()`.
    """

    def __init__(self):
        # Each entry is a `(prefix, viewset, basename)` tuple.
        self.registry = []

    def register(self, prefix, viewset, basename=None):
        """
        Add `viewset` to the registry under `prefix`.

        When `basename` is omitted it is taken from
        `get_default_basename()`. Raises `ImproperlyConfigured` if the
        basename is already in the registry.
        """
        if basename is None:
            basename = self.get_default_basename(viewset)

        if self.is_already_registered(basename):
            msg = (f'Router with basename "{basename}" is already registered. '
                   f'Please provide a unique basename for viewset "{viewset}"')
            raise ImproperlyConfigured(msg)

        self.registry.append((prefix, viewset, basename))

        # Drop any cached patterns so `urls` is rebuilt on next access.
        try:
            del self._urls
        except AttributeError:
            pass

    def is_already_registered(self, new_basename):
        """
        Return True if `new_basename` matches the basename of any
        registry entry.
        """
        for _prefix, _viewset, registered in self.registry:
            if registered == new_basename:
                return True
        return False

    def get_default_basename(self, viewset):
        """
        Hook for deriving a basename from the viewset when none was
        passed to `register()`. Must be overridden.
        """
        raise NotImplementedError('get_default_basename must be overridden')

    def get_urls(self):
        """
        Hook returning the list of URL patterns for the registered
        viewsets. Must be overridden.
        """
        raise NotImplementedError('get_urls must be overridden')

    @property
    def urls(self):
        """
        The URL patterns from `get_urls()`, computed on first access and
        cached on the instance until `register()` is called again.
        """
        try:
            return self._urls
        except AttributeError:
            pass
        self._urls = self.get_urls()
        return self._urls
 | 
						|
 | 
						|
 | 
						|
class SimpleRouter(BaseRouter):
    """
    Router that builds list, detail and extra-action URL patterns for
    every registered viewset from the templates in `routes`.
    """

    routes = [
        # List route.
        Route(
            url=r'^{prefix}{trailing_slash}$',
            mapping={
                'get': 'list',
                'post': 'create'
            },
            name='{basename}-list',
            detail=False,
            initkwargs={'suffix': 'List'}
        ),
        # Dynamically generated list routes. Generated using
        # @action(detail=False) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=False,
            initkwargs={}
        ),
        # Detail route.
        Route(
            url=r'^{prefix}/{lookup}{trailing_slash}$',
            mapping={
                'get': 'retrieve',
                'put': 'update',
                'patch': 'partial_update',
                'delete': 'destroy'
            },
            name='{basename}-detail',
            detail=True,
            initkwargs={'suffix': 'Instance'}
        ),
        # Dynamically generated detail routes. Generated using
        # @action(detail=True) decorator on methods of the viewset.
        DynamicRoute(
            url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
            name='{basename}-{url_name}',
            detail=True,
            initkwargs={}
        ),
    ]

    def __init__(self, trailing_slash=True, use_regex_path=True):
        """
        `trailing_slash` selects whether generated URLs end in '/'.
        `use_regex_path` selects `re_path` (regex patterns) when true,
        or `path` (converter-style patterns) when false.
        """
        self.trailing_slash = '/' if trailing_slash else ''
        self._use_regex = use_regex_path
        if use_regex_path:
            self._base_pattern = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
            self._default_value_pattern = '[^/.]+'
            self._url_conf = re_path
        else:
            self._base_pattern = '<{lookup_value}:{lookup_prefix}{lookup_url_kwarg}>'
            self._default_value_pattern = 'str'
            self._url_conf = path
            # remove regex characters from routes
            _routes = []
            for route in self.routes:
                url_param = route.url
                # startswith/endswith rather than indexing, so an empty
                # (or single-character) route URL cannot raise IndexError.
                if url_param.startswith('^'):
                    url_param = url_param[1:]
                if url_param.endswith('$'):
                    url_param = url_param[:-1]

                _routes.append(route._replace(url=url_param))
            # Shadows the class-level `routes` on this instance only.
            self.routes = _routes

        super().__init__()

    def get_default_basename(self, viewset):
        """
        If `basename` is not specified, attempt to automatically determine
        it from the viewset.

        Returns the lower-cased model object name of `viewset.queryset`.
        Raises `AssertionError` if the viewset has no `queryset`.
        """
        queryset = getattr(viewset, 'queryset', None)

        # Raised explicitly (not via `assert`) so the check still runs
        # when Python is started with -O.
        if queryset is None:
            raise AssertionError(
                '`basename` argument not specified, and could '
                'not automatically determine the name from the viewset, as '
                'it does not have a `.queryset` attribute.'
            )

        return queryset.model._meta.object_name.lower()

    def get_routes(self, viewset):
        """
        Augment `self.routes` with any dynamically generated routes.

        Returns a list of the Route namedtuple. Raises
        `ImproperlyConfigured` if an extra action is named like one of
        the actions already mapped by the fixed routes.
        """
        # Materialised once because it is consulted for every extra action.
        known_actions = set(flatten(
            route.mapping.values() for route in self.routes if isinstance(route, Route)
        ))
        extra_actions = viewset.get_extra_actions()

        # checking action names against the known actions
        not_allowed = [
            action.__name__ for action in extra_actions
            if action.__name__ in known_actions
        ]
        if not_allowed:
            msg = ('Cannot use the @action decorator on the following '
                   'methods, as they are existing routes: %s')
            raise ImproperlyConfigured(msg % ', '.join(not_allowed))

        # partition detail and list actions
        detail_actions = [action for action in extra_actions if action.detail]
        list_actions = [action for action in extra_actions if not action.detail]

        routes = []
        for route in self.routes:
            if isinstance(route, DynamicRoute) and route.detail:
                routes += [self._get_dynamic_route(route, action) for action in detail_actions]
            elif isinstance(route, DynamicRoute) and not route.detail:
                routes += [self._get_dynamic_route(route, action) for action in list_actions]
            else:
                routes.append(route)

        return routes

    def _get_dynamic_route(self, route, action):
        """
        Build a concrete `Route` for `action` from the `DynamicRoute`
        template `route`.
        """
        initkwargs = route.initkwargs.copy()
        initkwargs.update(action.kwargs)

        # The resulting url is passed through `str.format()` in
        # `get_urls()`, so literal brackets in the action path are doubled.
        url_path = escape_curly_brackets(action.url_path)

        return Route(
            url=route.url.replace('{url_path}', url_path),
            mapping=action.mapping,
            name=route.name.replace('{url_name}', action.url_name),
            detail=route.detail,
            initkwargs=initkwargs,
        )

    def get_method_map(self, viewset, method_map):
        """
        Given a viewset, and a mapping of http methods to actions,
        return a new mapping which only includes any mappings that
        are actually implemented by the viewset.
        """
        return {
            method: action
            for method, action in method_map.items()
            if hasattr(viewset, action)
        }

    def get_lookup_regex(self, viewset, lookup_prefix=''):
        """
        Given a viewset, return the portion of URL regex that is used
        to match against a single instance.

        Note that lookup_prefix is not used directly inside REST framework
        itself, but is required in order to nicely support nested router
        implementations, such as drf-nested-routers.

        https://github.com/alanjds/drf-nested-routers
        """
        # Use `pk` as default field, unless set.  Default regex should not
        # consume `.json` style suffixes and should break at '/' boundaries.
        lookup_field = getattr(viewset, 'lookup_field', 'pk')
        lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
        lookup_value = None
        if not self._use_regex:
            # try to get a more appropriate attribute when not using regex
            lookup_value = getattr(viewset, 'lookup_value_converter', None)
        if lookup_value is None:
            # fallback to legacy
            lookup_value = getattr(viewset, 'lookup_value_regex', self._default_value_pattern)
        return self._base_pattern.format(
            lookup_prefix=lookup_prefix,
            lookup_url_kwarg=lookup_url_kwarg,
            lookup_value=lookup_value
        )

    def get_urls(self):
        """
        Use the registered viewsets to generate a list of URL patterns.
        """
        ret = []

        for prefix, viewset, basename in self.registry:
            lookup = self.get_lookup_regex(viewset)
            routes = self.get_routes(viewset)

            for route in routes:

                # Only actions which actually exist on the viewset will be bound
                mapping = self.get_method_map(viewset, route.mapping)
                if not mapping:
                    continue

                # Build the url pattern
                regex = route.url.format(
                    prefix=prefix,
                    lookup=lookup,
                    trailing_slash=self.trailing_slash
                )

                # If there is no prefix, the first part of the url is probably
                #   controlled by project's urls.py and the router is in an app,
                #   so a slash in the beginning will (A) cause Django to give
                #   warnings and (B) generate URLS that will require using '//'.
                if not prefix:
                    if self._url_conf is path:
                        # `regex` is empty for the list route when both the
                        # prefix and the trailing slash are empty, so test
                        # with startswith() instead of indexing.
                        if regex.startswith('/'):
                            regex = regex[1:]
                    elif regex[:2] == '^/':
                        regex = '^' + regex[2:]

                initkwargs = route.initkwargs.copy()
                initkwargs.update({
                    'basename': basename,
                    'detail': route.detail,
                })

                view = viewset.as_view(mapping, **initkwargs)
                name = route.name.format(basename=basename)
                ret.append(self._url_conf(regex, view, name=name))

        return ret
 | 
						|
 | 
						|
 | 
						|
class APIRootView(views.APIView):
    """
    The default basic root view for DefaultRouter.

    Responds to GET with a mapping of each `api_root_dict` key to the
    reversed URL of the route name stored under it.
    """
    _ignore_model_permissions = True
    schema = None  # exclude from schema
    # Mapping of response key -> URL name; supplied through `as_view()`.
    api_root_dict = None

    def get(self, request, *args, **kwargs):
        """
        Return a plain {"name": "hyperlink"} response.
        """
        namespace = request.resolver_match.namespace
        links = {}
        for key, url_name in self.api_root_dict.items():
            # Qualify the name when the view was resolved inside a namespace.
            qualified_name = ':'.join((namespace, url_name)) if namespace else url_name
            try:
                link = reverse(
                    qualified_name,
                    args=args,
                    kwargs=kwargs,
                    request=request,
                    format=kwargs.get('format')
                )
            except NoReverseMatch:
                # Don't bail out if eg. no list routes exist, only detail routes.
                continue
            links[key] = link

        return Response(links)
 | 
						|
 | 
						|
 | 
						|
class DefaultRouter(SimpleRouter):
    """
    SimpleRouter plus a default API root view and format-suffixed
    variants of the generated URL patterns.
    """
    # Switches consulted by `get_urls()`.
    include_root_view = True
    include_format_suffixes = True
    root_view_name = 'api-root'
    default_schema_renderers = None
    # Exposed as class attributes so subclasses can swap them out.
    APIRootView = APIRootView
    APISchemaView = SchemaView
    SchemaGenerator = SchemaGenerator

    def __init__(self, *args, **kwargs):
        """
        Accepts an optional `root_renderers` keyword; when absent, a copy
        of `api_settings.DEFAULT_RENDERER_CLASSES` is stored instead.
        Remaining arguments are forwarded to `SimpleRouter.__init__`.
        """
        try:
            renderers = kwargs.pop('root_renderers')
        except KeyError:
            renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
        self.root_renderers = renderers
        super().__init__(*args, **kwargs)

    def get_api_root_view(self, api_urls=None):
        """
        Return a basic root view.

        The view maps each registered prefix to the name of that
        registration's list route (the name template of `routes[0]`).
        `api_urls` is accepted but not used here.
        """
        list_name = self.routes[0].name
        api_root_dict = {
            prefix: list_name.format(basename=basename)
            for prefix, _viewset, basename in self.registry
        }
        return self.APIRootView.as_view(api_root_dict=api_root_dict)

    def get_urls(self):
        """
        Generate the list of URL patterns, including a default root view
        for the API, and appending `.json` style format suffixes.
        """
        urls = super().get_urls()

        if self.include_root_view:
            root_view = self.get_api_root_view(api_urls=urls)
            urls.append(path('', root_view, name=self.root_view_name))

        if not self.include_format_suffixes:
            return urls
        return format_suffix_patterns(urls)
 |