This commit is contained in:
Viktor Ershov 2017-06-06 20:17:39 +00:00 committed by GitHub
commit e186265d53
4 changed files with 94 additions and 0 deletions

View File

@ -0,0 +1,12 @@
from django.db import router
class BaseDbRouter(object):
def get_db_alias(self, request, model):
raise NotImplementedError(".get_db_alias() must be overridden.")
class DjangoDbRouter(BaseDbRouter):
def get_db_alias(self, request, model):
if request.method.lower() != "get":
return router.db_for_write(model)

View File

@ -47,6 +47,9 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
# The database router classes to find the correct db alias for queryset
db_router_classes = api_settings.DEFAULT_DB_ROUTER_CLASSES
def get_queryset(self):
"""
Get the list of items for this view.
@ -72,8 +75,19 @@ class GenericAPIView(views.APIView):
if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return self.apply_db_routing(queryset)
def apply_db_routing(self, queryset):
for router in self.get_db_routers():
alias = router.get_db_alias(self.request, queryset.model)
if alias:
return queryset.using(alias)
return queryset
def get_db_routers(self):
return [db_router() for db_router in self.db_router_classes]
def get_object(self):
"""
Returns the object the view is displaying.

View File

@ -53,6 +53,7 @@ DEFAULTS = {
# Generic view behavior
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'DEFAULT_FILTER_BACKENDS': (),
'DEFAULT_DB_ROUTER_CLASSES': [],
# Throttling
'DEFAULT_THROTTLE_RATES': {
@ -138,6 +139,7 @@ IMPORT_STRINGS = (
'DEFAULT_VERSIONING_CLASS',
'DEFAULT_PAGINATION_CLASS',
'DEFAULT_FILTER_BACKENDS',
'DEFAULT_DB_ROUTER_CLASSES',
'EXCEPTION_HANDLER',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',

66
tests/test_db_routers.py Normal file
View File

@ -0,0 +1,66 @@
from django.test import TestCase
from rest_framework import db_routers, generics
from rest_framework.test import APIRequestFactory
from .models import BasicModel
class NoRoutersView(generics.GenericAPIView):
db_router_classes = []
queryset = BasicModel.objects.all()
class DjangoRouterView(generics.GenericAPIView):
db_router_classes = [db_routers.DjangoDbRouter]
queryset = BasicModel.objects.all()
class Router(object):
def db_for_read(self, model, **hints):
return 'db_for_read'
def db_for_write(self, model, **hints):
return 'db_for_write'
class BaseRoutersTest(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
def assertViewUsesAliasForQuerySet(self, view_class, http_method, db_alias):
with self.settings(DATABASE_ROUTERS=['tests.test_db_routers.Router']):
request = getattr(self.factory, http_method)('/')
view = view_class()
view.request = request
queryset = view.get_queryset()
self.assertEqual(queryset._db, db_alias)
class TestNoDbRouters(BaseRoutersTest):
def test_get(self):
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='get', db_alias=None)
def test_post(self):
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='post', db_alias=None)
def test_put(self):
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='put', db_alias=None)
def test_patch(self):
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='patch', db_alias=None)
class TestDjangoDbRouter(BaseRoutersTest):
def test_get(self):
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='get', db_alias=None)
def test_post(self):
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='post', db_alias='db_for_write')
def test_put(self):
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='put', db_alias='db_for_write')
def test_patch(self):
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='patch', db_alias='db_for_write')