mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 20:40:14 +03:00
Implement database routers
This commit is contained in:
parent
f2c65512c6
commit
d903b07ccf
12
rest_framework/db_routers.py
Normal file
12
rest_framework/db_routers.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from django.db import router
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDbRouter(object):
|
||||||
|
def get_db_alias(self, request, model):
|
||||||
|
raise NotImplementedError(".get_db_alias() must be overridden.")
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDbRouter(BaseDbRouter):
|
||||||
|
def get_db_alias(self, request, model):
|
||||||
|
if request.method.lower() != "get":
|
||||||
|
return router.db_for_write(model)
|
|
@ -46,6 +46,9 @@ class GenericAPIView(views.APIView):
|
||||||
# The style to use for queryset pagination.
|
# The style to use for queryset pagination.
|
||||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||||
|
|
||||||
|
# The database router classes to use as queryset aliases
|
||||||
|
db_router_classes = api_settings.DEFAULT_DB_ROUTER_CLASSES
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""
|
"""
|
||||||
Get the list of items for this view.
|
Get the list of items for this view.
|
||||||
|
@ -71,8 +74,19 @@ class GenericAPIView(views.APIView):
|
||||||
if isinstance(queryset, QuerySet):
|
if isinstance(queryset, QuerySet):
|
||||||
# Ensure queryset is re-evaluated on each request.
|
# Ensure queryset is re-evaluated on each request.
|
||||||
queryset = queryset.all()
|
queryset = queryset.all()
|
||||||
|
|
||||||
|
return self.apply_db_routing(queryset)
|
||||||
|
|
||||||
|
def apply_db_routing(self, queryset):
|
||||||
|
for router in self.get_db_routers():
|
||||||
|
alias = router.get_db_alias(self.request, queryset.model)
|
||||||
|
if alias:
|
||||||
|
return queryset.using(alias)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
def get_db_routers(self):
|
||||||
|
return [db_router() for db_router in self.db_router_classes]
|
||||||
|
|
||||||
def get_object(self):
|
def get_object(self):
|
||||||
"""
|
"""
|
||||||
Returns the object the view is displaying.
|
Returns the object the view is displaying.
|
||||||
|
|
|
@ -54,6 +54,7 @@ DEFAULTS = {
|
||||||
# Generic view behavior
|
# Generic view behavior
|
||||||
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
||||||
'DEFAULT_FILTER_BACKENDS': (),
|
'DEFAULT_FILTER_BACKENDS': (),
|
||||||
|
'DEFAULT_DB_ROUTER_CLASSES': [],
|
||||||
|
|
||||||
# Throttling
|
# Throttling
|
||||||
'DEFAULT_THROTTLE_RATES': {
|
'DEFAULT_THROTTLE_RATES': {
|
||||||
|
@ -137,7 +138,8 @@ IMPORT_STRINGS = (
|
||||||
'DEFAULT_METADATA_CLASS',
|
'DEFAULT_METADATA_CLASS',
|
||||||
'DEFAULT_VERSIONING_CLASS',
|
'DEFAULT_VERSIONING_CLASS',
|
||||||
'DEFAULT_PAGINATION_CLASS',
|
'DEFAULT_PAGINATION_CLASS',
|
||||||
'DEFAULT_FILTER_BACKENDS',
|
'DEFAULT_FILTER_BACKENDS'
|
||||||
|
'DEFAULT_DB_ROUTER_CLASSES',
|
||||||
'EXCEPTION_HANDLER',
|
'EXCEPTION_HANDLER',
|
||||||
'TEST_REQUEST_RENDERER_CLASSES',
|
'TEST_REQUEST_RENDERER_CLASSES',
|
||||||
'UNAUTHENTICATED_USER',
|
'UNAUTHENTICATED_USER',
|
||||||
|
|
66
tests/test_db_routers.py
Normal file
66
tests/test_db_routers.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from rest_framework import db_routers, generics
|
||||||
|
from rest_framework.test import APIRequestFactory
|
||||||
|
|
||||||
|
from .models import BasicModel
|
||||||
|
|
||||||
|
|
||||||
|
class NoRoutersView(generics.GenericAPIView):
|
||||||
|
db_router_classes = []
|
||||||
|
queryset = BasicModel.objects.all()
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoRouterView(generics.GenericAPIView):
|
||||||
|
db_router_classes = [db_routers.DjangoDbRouter]
|
||||||
|
queryset = BasicModel.objects.all()
|
||||||
|
|
||||||
|
|
||||||
|
class Router(object):
|
||||||
|
def db_for_read(self, model, **hints):
|
||||||
|
return 'db_for_read'
|
||||||
|
|
||||||
|
def db_for_write(self, model, **hints):
|
||||||
|
return 'db_for_write'
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRoutersTest(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = APIRequestFactory()
|
||||||
|
|
||||||
|
def assertViewUsesAliasForQuerySet(self, view_class, http_method, db_alias):
|
||||||
|
with self.settings(DATABASE_ROUTERS=['tests.test_db_routers.Router']):
|
||||||
|
request = getattr(self.factory, http_method)('/')
|
||||||
|
view = view_class()
|
||||||
|
view.request = request
|
||||||
|
|
||||||
|
queryset = view.get_queryset()
|
||||||
|
self.assertEqual(queryset._db, db_alias)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoDbRouters(BaseRoutersTest):
|
||||||
|
def test_get(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='get', db_alias=None)
|
||||||
|
|
||||||
|
def test_post(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='post', db_alias=None)
|
||||||
|
|
||||||
|
def test_put(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='put', db_alias=None)
|
||||||
|
|
||||||
|
def test_patch(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(NoRoutersView, http_method='patch', db_alias=None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDjangoDbRouter(BaseRoutersTest):
|
||||||
|
def test_get(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='get', db_alias=None)
|
||||||
|
|
||||||
|
def test_post(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='post', db_alias='db_for_write')
|
||||||
|
|
||||||
|
def test_put(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='put', db_alias='db_for_write')
|
||||||
|
|
||||||
|
def test_patch(self):
|
||||||
|
self.assertViewUsesAliasForQuerySet(DjangoRouterView, http_method='patch', db_alias='db_for_write')
|
Loading…
Reference in New Issue
Block a user