From d903b07ccfe8ec931d518be207857094ddb5ff46 Mon Sep 17 00:00:00 2001 From: Viktor Ershov Date: Mon, 21 Sep 2015 18:42:25 +0300 Subject: [PATCH 1/3] Implement database routers --- rest_framework/db_routers.py | 12 +++++++ rest_framework/generics.py | 14 ++++++++ rest_framework/settings.py | 4 ++- tests/test_db_routers.py | 66 ++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 rest_framework/db_routers.py create mode 100644 tests/test_db_routers.py diff --git a/rest_framework/db_routers.py b/rest_framework/db_routers.py new file mode 100644 index 000000000..a48ebd804 --- /dev/null +++ b/rest_framework/db_routers.py @@ -0,0 +1,12 @@ +from django.db import router + + +class BaseDbRouter(object): + def get_db_alias(self, request, model): + raise NotImplementedError(".get_db_alias() must be overridden.") + + +class DjangoDbRouter(BaseDbRouter): + def get_db_alias(self, request, model): + if request.method.lower() != "get": + return router.db_for_write(model) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 88438e8c4..12be4b41e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -46,6 +46,9 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # The database router classes to use as queryset aliases + db_router_classes = api_settings.DEFAULT_DB_ROUTER_CLASSES + def get_queryset(self): """ Get the list of items for this view. @@ -71,8 +74,19 @@ class GenericAPIView(views.APIView): if isinstance(queryset, QuerySet): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() + + return self.apply_db_routing(queryset) + + def apply_db_routing(self, queryset): + for router in self.get_db_routers(): + alias = router.get_db_alias(self.request, queryset.model) + if alias: + return queryset.using(alias) return queryset + def get_db_routers(self): + return [db_router() for db_router in self.db_router_classes] + def get_object(self): """ Returns the object the view is displaying. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index e20e51287..312312c9d 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -54,6 +54,7 @@ DEFAULTS = { # Generic view behavior 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', 'DEFAULT_FILTER_BACKENDS': (), + 'DEFAULT_DB_ROUTER_CLASSES': [], # Throttling 'DEFAULT_THROTTLE_RATES': { @@ -137,7 +138,8 @@ IMPORT_STRINGS = ( 'DEFAULT_METADATA_CLASS', 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_CLASS', - 'DEFAULT_FILTER_BACKENDS', + 'DEFAULT_FILTER_BACKENDS' + 'DEFAULT_DB_ROUTER_CLASSES', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', diff --git a/tests/test_db_routers.py b/tests/test_db_routers.py new file mode 100644 index 000000000..c44a80d85 --- /dev/null +++ b/tests/test_db_routers.py @@ -0,0 +1,66 @@ +from django.test import TestCase + +from rest_framework import db_routers, generics +from rest_framework.test import APIRequestFactory + +from .models import BasicModel + + +class NoRoutersView(generics.GenericAPIView): + db_router_classes = [] + queryset = BasicModel.objects.all() + + +class DjangoRouterView(generics.GenericAPIView): + db_router_classes = [db_routers.DjangoDbRouter] + queryset = BasicModel.objects.all() + + +class Router(object): + def db_for_read(self, model, **hints): + return 'db_for_read' + + def db_for_write(self, model, **hints): + return 'db_for_write' + + +class BaseRoutersTest(TestCase): + def setUp(self): + self.factory = APIRequestFactory() + + def assertViewUsesAliasForQuerySet(self, view_class, http_method, db_alias): + with self.settings(DATABASE_ROUTERS=['tests.test_db_routers.Router']): + request = getattr(self.factory, http_method)('/') + view = view_class() + view.request = request + + queryset = view.get_queryset() + self.assertEqual(queryset._db, db_alias) + + +class TestNoDbRouters(BaseRoutersTest): + def test_get(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='get', db_alias=None) + + def test_post(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='post', db_alias=None) + + def test_put(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='put', db_alias=None) + + def test_patch(self): + self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='patch', db_alias=None) + + +class TestDjangoDbRouter(BaseRoutersTest): + def test_get(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='get', db_alias=None) + + def test_post(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='post', db_alias='db_for_write') + + def test_put(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='put', db_alias='db_for_write') + + def test_patch(self): + self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='patch', db_alias='db_for_write') From 0800931a81473b795352db7900315a96d141547a Mon Sep 17 00:00:00 2001 From: Viktor Ershov Date: Mon, 21 Sep 2015 18:49:53 +0300 Subject: [PATCH 2/3] fix comment --- rest_framework/generics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 12be4b41e..397aae4c9 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -46,7 +46,7 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS - # The database router classes to use as queryset aliases + # The database router classes to find the correct db alias for queryset db_router_classes = api_settings.DEFAULT_DB_ROUTER_CLASSES def get_queryset(self): From d8d8e2fef055a31f40e629a0302c11ed97150109 Mon Sep 17 00:00:00 2001 From: Viktor Ershov Date: Mon, 21 Sep 2015 18:50:52 +0300 Subject: [PATCH 3/3] fix settings --- rest_framework/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 312312c9d..d4dc02a24 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -138,7 +138,7 @@ IMPORT_STRINGS = ( 'DEFAULT_METADATA_CLASS', 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_CLASS', - 'DEFAULT_FILTER_BACKENDS' + 'DEFAULT_FILTER_BACKENDS', 'DEFAULT_DB_ROUTER_CLASSES', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES',