diff --git a/rest_framework/db_routers.py b/rest_framework/db_routers.py new file mode 100644 index 000000000..a48ebd804 --- /dev/null +++ b/rest_framework/db_routers.py @@ -0,0 +1,12 @@ +from django.db import router + + +class BaseDbRouter(object): + def get_db_alias(self, request, model): + raise NotImplementedError(".get_db_alias() must be overridden.") + + +class DjangoDbRouter(BaseDbRouter): + def get_db_alias(self, request, model): + if request.method.lower() != "get": + return router.db_for_write(model) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8d0bf284a..2b8567e72 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -47,6 +47,9 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # The database router classes to find the correct db alias for queryset + db_router_classes = api_settings.DEFAULT_DB_ROUTER_CLASSES + def get_queryset(self): """ Get the list of items for this view. @@ -72,8 +75,19 @@ class GenericAPIView(views.APIView): if isinstance(queryset, QuerySet): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() + + return self.apply_db_routing(queryset) + + def apply_db_routing(self, queryset): + for router in self.get_db_routers(): + alias = router.get_db_alias(self.request, queryset.model) + if alias: + return queryset.using(alias) return queryset + def get_db_routers(self): + return [db_router() for db_router in self.db_router_classes] + def get_object(self): """ Returns the object the view is displaying. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 3f3c9110a..e51aec282 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -53,6 +53,7 @@ DEFAULTS = { # Generic view behavior 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', 'DEFAULT_FILTER_BACKENDS': (), + 'DEFAULT_DB_ROUTER_CLASSES': [], # Throttling 'DEFAULT_THROTTLE_RATES': { @@ -138,6 +139,7 @@ IMPORT_STRINGS = ( 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', + 'DEFAULT_DB_ROUTER_CLASSES', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', diff --git a/tests/test_db_routers.py b/tests/test_db_routers.py new file mode 100644 index 000000000..c44a80d85 --- /dev/null +++ b/tests/test_db_routers.py @@ -0,0 +1,66 @@ +from django.test import TestCase + +from rest_framework import db_routers, generics +from rest_framework.test import APIRequestFactory + +from .models import BasicModel + + +class NoRoutersView(generics.GenericAPIView): + db_router_classes = [] + queryset = BasicModel.objects.all() + + +class DjangoRouterView(generics.GenericAPIView): + db_router_classes = [db_routers.DjangoDbRouter] + queryset = BasicModel.objects.all() + + +class Router(object): + def db_for_read(self, model, **hints): + return 'db_for_read' + + def db_for_write(self, model, **hints): + return 'db_for_write' + + +class BaseRoutersTest(TestCase): + def setUp(self): + self.factory = APIRequestFactory() + + def assertViewUsesAliasForQuerySet(self, view_class, http_method, db_alias): + with self.settings(DATABASE_ROUTERS=['tests.test_db_routers.Router']): + request = getattr(self.factory, http_method)('/') + view = view_class() + view.request = request + + queryset = view.get_queryset() + self.assertEqual(queryset._db, db_alias) + + +class TestNoDbRouters(BaseRoutersTest): + def test_get(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='get', db_alias=None) + + def test_post(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='post', db_alias=None) + + def test_put(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='put', db_alias=None) + + def test_patch(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='patch', db_alias=None) + + +class TestDjangoDbRouter(BaseRoutersTest): + def test_get(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='get', db_alias=None) + + def test_post(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='post', db_alias='db_for_write') + + def test_put(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='put', db_alias='db_for_write') + + def test_patch(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='patch', db_alias='db_for_write')