""" Routers provide a convenient and consistent way of automatically determining the URL conf for your API. They are used by simply instantiating a Router class, and then registering all the required ViewSets with that router. For example, you might have a `urls.py` that looks something like this: router = routers.DefaultRouter() router.register('users', UserViewSet, 'user') router.register('accounts', AccountViewSet, 'account') urlpatterns = router.urls """ from __future__ import unicode_literals import itertools from collections import namedtuple from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import NoReverseMatch from rest_framework import views from rest_framework.compat import OrderedDict, get_resolver_match from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.urlpatterns import format_suffix_patterns Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) DynamicDetailRoute = namedtuple('DynamicDetailRoute', ['url', 'name', 'initkwargs']) DynamicListRoute = namedtuple('DynamicListRoute', ['url', 'name', 'initkwargs']) def replace_methodname(format_string, methodname): """ Partially format a format_string, swapping out any '{methodname}' or '{methodnamehyphen}' components. """ methodnamehyphen = methodname.replace('_', '-') ret = format_string ret = ret.replace('{methodname}', methodname) ret = ret.replace('{methodnamehyphen}', methodnamehyphen) return ret def flatten(list_of_lists): """ Takes an iterable of iterables, returns a single iterable containing all items """ return itertools.chain(*list_of_lists) class BaseRouter(object): def __init__(self): self.registry = [] def register(self, prefix, viewset, base_name=None): if base_name is None: base_name = self.get_default_base_name(viewset) self.registry.append((prefix, viewset, base_name)) def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine it from the viewset. """ raise NotImplementedError('get_default_base_name must be overridden') def get_urls(self): """ Return a list of URL patterns, given the registered viewsets. """ raise NotImplementedError('get_urls must be overridden') @property def urls(self): if not hasattr(self, '_urls'): self._urls = self.get_urls() return self._urls class SimpleRouter(BaseRouter): routes = [ # List route. Route( url=r'^{prefix}{trailing_slash}$', mapping={ 'get': 'list', 'post': 'create' }, name='{basename}-list', initkwargs={'suffix': 'List'} ), # Dynamically generated list routes. # Generated using @list_route decorator # on methods of the viewset. DynamicListRoute( url=r'^{prefix}/{methodname}{trailing_slash}$', name='{basename}-{methodnamehyphen}', initkwargs={} ), # Detail route. Route( url=r'^{prefix}/{lookup}{trailing_slash}$', mapping={ 'get': 'retrieve', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy' }, name='{basename}-detail', initkwargs={'suffix': 'Instance'} ), # Dynamically generated detail routes. # Generated using @detail_route decorator on methods of the viewset. DynamicDetailRoute( url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$', name='{basename}-{methodnamehyphen}', initkwargs={} ), ] def __init__(self, trailing_slash=True): self.trailing_slash = trailing_slash and '/' or '' super(SimpleRouter, self).__init__() def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine it from the viewset. """ queryset = getattr(viewset, 'queryset', None) assert queryset is not None, '`base_name` argument not specified, and could ' \ 'not automatically determine the name from the viewset, as ' \ 'it does not have a `.queryset` attribute.' return queryset.model._meta.object_name.lower() def get_routes(self, viewset): """ Augment `self.routes` with any dynamically generated routes. Returns a list of the Route namedtuple. """ known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]) # Determine any `@detail_route` or `@list_route` decorated methods on the viewset detail_routes = [] list_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) httpmethods = getattr(attr, 'bind_to_methods', None) detail = getattr(attr, 'detail', True) if httpmethods: if methodname in known_actions: raise ImproperlyConfigured('Cannot use @detail_route or @list_route ' 'decorators on method "%s" ' 'as it is an existing route' % methodname) httpmethods = [method.lower() for method in httpmethods] if detail: detail_routes.append((httpmethods, methodname)) else: list_routes.append((httpmethods, methodname)) def _get_dynamic_routes(route, dynamic_routes): ret = [] for httpmethods, methodname in dynamic_routes: method_kwargs = getattr(viewset, methodname).kwargs initkwargs = route.initkwargs.copy() initkwargs.update(method_kwargs) url_path = initkwargs.pop("url_path", None) or methodname ret.append(Route( url=replace_methodname(route.url, url_path), mapping=dict((httpmethod, methodname) for httpmethod in httpmethods), name=replace_methodname(route.name, url_path), initkwargs=initkwargs, )) return ret ret = [] for route in self.routes: if isinstance(route, DynamicDetailRoute): # Dynamic detail routes (@detail_route decorator) ret += _get_dynamic_routes(route, detail_routes) elif isinstance(route, DynamicListRoute): # Dynamic list routes (@list_route decorator) ret += _get_dynamic_routes(route, list_routes) else: # Standard route ret.append(route) return ret def get_method_map(self, viewset, method_map): """ Given a viewset, and a mapping of http methods to actions, return a new mapping which only includes any mappings that are actually implemented by the viewset. """ bound_methods = {} for method, action in method_map.items(): if hasattr(viewset, action): bound_methods[method] = action return bound_methods def get_lookup_regex(self, viewset, lookup_prefix=''): """ Given a viewset, return the portion of URL regex that is used to match against a single instance. Note that lookup_prefix is not used directly inside REST rest_framework itself, but is required in order to nicely support nested router implementations, such as drf-nested-routers. https://github.com/alanjds/drf-nested-routers """ base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})' # Use `pk` as default field, unset set. Default regex should not # consume `.json` style suffixes and should break at '/' boundaries. lookup_field = getattr(viewset, 'lookup_field', 'pk') lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+') return base_regex.format( lookup_prefix=lookup_prefix, lookup_url_kwarg=lookup_url_kwarg, lookup_value=lookup_value ) def get_urls(self): """ Use the registered viewsets to generate a list of URL patterns. """ ret = [] for prefix, viewset, basename in self.registry: lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) for route in routes: # Only actions which actually exist on the viewset will be bound mapping = self.get_method_map(viewset, route.mapping) if not mapping: continue # Build the url pattern regex = route.url.format( prefix=prefix, lookup=lookup, trailing_slash=self.trailing_slash ) view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) return ret class DefaultRouter(SimpleRouter): """ The default router extends the SimpleRouter, but also adds in a default API root view, and adds format suffix patterns to the URLs. """ include_root_view = True include_format_suffixes = True root_view_name = 'api-root' def get_api_root_view(self): """ Return a view to use as the API root. """ api_root_dict = OrderedDict() list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) class APIRoot(views.APIView): _ignore_model_permissions = True def get(self, request, *args, **kwargs): ret = OrderedDict() namespace = get_resolver_match(request).namespace for key, url_name in api_root_dict.items(): if namespace: url_name = namespace + ':' + url_name try: ret[key] = reverse( url_name, request=request, format=kwargs.get('format', None) ) except NoReverseMatch: # Don't bail out if eg. no list routes exist, only detail routes. continue return Response(ret) return APIRoot.as_view() def get_urls(self): """ Generate the list of URL patterns, including a default root view for the API, and appending `.json` style format suffixes. """ urls = [] if self.include_root_view: root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) urls.append(root_url) default_urls = super(DefaultRouter, self).get_urls() urls.extend(default_urls) if self.include_format_suffixes: urls = format_suffix_patterns(urls) return urls