From 34a93e6a50365dc8901fa493b1bf6d4eb3d09bf7 Mon Sep 17 00:00:00 2001 From: Ivan Gonzalez Date: Mon, 4 Jan 2021 12:41:34 -0500 Subject: [PATCH] feat(routers): add .extend method to BaseRouter and include function the .extend method allow to extend the routes of the current router with router from other router and a optional new prefix the include function allow to include a router without manually importing it. This is useful when importing multiple routers just for using the .extend method but nothing else --- rest_framework/routers.py | 52 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index e2afa573f..6a3b9b56b 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -15,6 +15,7 @@ For example, you might have a `urls.py` that looks something like this: """ import itertools from collections import OrderedDict, namedtuple +from importlib import import_module from django.core.exceptions import ImproperlyConfigured from django.urls import NoReverseMatch, re_path @@ -49,6 +50,20 @@ class BaseRouter: def __init__(self): self.registry = [] + @classmethod + def _include(cls, module, router_name="router"): + """ + Allow to import router object from another app + """ + router_module = import_module(module) + + router = getattr(router_module, router_name) + + if not isinstance(router, cls): + raise ValueError("The router should be an instance (direct or indirect) of BaseRouter") + + return router + def register(self, prefix, viewset, basename=None): if basename is None: basename = self.get_default_basename(viewset) @@ -58,6 +73,40 @@ class BaseRouter: if hasattr(self, '_urls'): del self._urls + def extend(self, prefix, router): + """ + Extend the routes with url routes from the router of the module passed. + + Example: + from django.urls import path, include + from rest_framework import routers + + router = routers.DefaultRouter() + + router.extend('products', routers.include('project.products.urls')) # using include + + from .users.urls import router as users_router + router.extend('users', users_router) # manually importing the router + + + urlpatterns = [ + path("api/", include(router.urls)) + ] + + You can avoid naming collisions with `django.urls` include function using named imports + or importing the whole routers module: + >>> from rest_framework.routers import include as router_include + >>> from rest_framework import routers + """ + if not prefix.endswith("/"): + # TODO: warn or not the user to put an ending forward slash + prefix += "/" + + for old_prefix, viewset, basename in router.registry: + new_prefix = prefix + old_prefix + + self.register(new_prefix, viewset, basename) + def get_default_basename(self, viewset): """ If `basename` is not specified, attempt to automatically determine @@ -78,6 +127,9 @@ class BaseRouter: return self._urls +include = BaseRouter._include + + class SimpleRouter(BaseRouter): routes = [