M1ha 2020-02-07 10:49:49 +05:00
parent 818704989b
commit f9e33a6845
4 changed files with 82 additions and 6 deletions

View File

@ -43,7 +43,7 @@ def update_returning_pk(qs: QuerySet, updates: dict) -> Set[Any]:
:return: A set of primary keys :return: A set of primary keys
""" """
qs._for_write = True qs._for_write = True
if django_pg_returning_available(qs.db): if django_pg_returning_available(qs.db) and hasattr(qs, 'update_returning'):
pk_name = qs.model._meta.pk.name pk_name = qs.model._meta.pk.name
qs = qs.only(pk_name).update_returning(**updates) qs = qs.only(pk_name).update_returning(**updates)
pks = set(qs.values_list(pk_name, flat=True)) pks = set(qs.values_list(pk_name, flat=True))

View File

@ -71,17 +71,17 @@ class ClickHouseSyncBulkUpdateQuerySetMixin(ClickHouseSyncRegisterMixin, BulkUpd
return returning return returning
def bulk_update(self, *args, **kwargs): def pg_bulk_update(self, *args, **kwargs):
original_returning = kwargs.pop('returning', None) original_returning = kwargs.pop('returning', None)
kwargs['returning'] = self._update_returning_param(original_returning) kwargs['returning'] = self._update_returning_param(original_returning)
result = super().bulk_update(*args, **kwargs) result = super().pg_bulk_update(*args, **kwargs)
self._register_ops('update', result) self._register_ops('update', result)
return result.count() if original_returning is None else result return result.count() if original_returning is None else result
def bulk_update_or_create(self, *args, **kwargs): def pg_bulk_update_or_create(self, *args, **kwargs):
original_returning = kwargs.pop('returning', None) original_returning = kwargs.pop('returning', None)
kwargs['returning'] = self._update_returning_param(original_returning) kwargs['returning'] = self._update_returning_param(original_returning)
result = super().bulk_update_or_create(*args, **kwargs) result = super().pg_bulk_update_or_create(*args, **kwargs)
self._register_ops('update', result) self._register_ops('update', result)
return result.count() if original_returning is None else result return result.count() if original_returning is None else result
@ -97,6 +97,19 @@ class ClickHouseSyncQuerySetMixin(ClickHouseSyncRegisterMixin):
self._register_ops('insert', objs) self._register_ops('insert', objs)
return objs return objs
def bulk_update(self, objs, *args, **kwargs):
objs = list(objs)
# No need to register anything, if there are no objects.
# If objects are not models, django-pg-bulk-update method is called and pg_bulk_update will register items
if len(objs) == 0 or not isinstance(objs[0], DjangoModel):
return super().bulk_update(objs, *args, **kwargs)
# native django bulk_update requires each object to have a primary key
res = super().bulk_update(objs, *args, **kwargs)
self._register_ops('update', objs)
return res
# I add library dependant mixins to base classes only if libraries are installed # I add library dependant mixins to base classes only if libraries are installed
qs_bases = [ClickHouseSyncQuerySetMixin] qs_bases = [ClickHouseSyncQuerySetMixin]

View File

@ -2,10 +2,15 @@
This file contains sample models to use in tests This file contains sample models to use in tests
""" """
from django.db import models from django.db import models
from django.db.models import QuerySet
from django.db.models.manager import BaseManager from django.db.models.manager import BaseManager
from django_pg_returning import UpdateReturningModel from django_pg_returning import UpdateReturningModel
from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet, ClickHouseSyncQuerySetMixin
class NativeQuerySet(ClickHouseSyncQuerySetMixin, QuerySet):
pass
class TestQuerySet(ClickHouseSyncQuerySet): class TestQuerySet(ClickHouseSyncQuerySet):
@ -16,8 +21,13 @@ class TestManager(BaseManager.from_queryset(TestQuerySet)):
pass pass
class NativeManager(BaseManager.from_queryset(NativeQuerySet)):
pass
class TestModel(UpdateReturningModel, ClickHouseSyncModel): class TestModel(UpdateReturningModel, ClickHouseSyncModel):
objects = TestManager() objects = TestManager()
native_objects = NativeManager()
value = models.IntegerField() value = models.IntegerField()
created_date = models.DateField() created_date = models.DateField()
@ -26,6 +36,7 @@ class TestModel(UpdateReturningModel, ClickHouseSyncModel):
class SecondaryTestModel(UpdateReturningModel, ClickHouseSyncModel): class SecondaryTestModel(UpdateReturningModel, ClickHouseSyncModel):
objects = TestManager() objects = TestManager()
native_objects = NativeManager()
value = models.IntegerField() value = models.IntegerField()
created_date = models.DateField() created_date = models.DateField()

View File

@ -1,6 +1,7 @@
import datetime import datetime
from django.test import TransactionTestCase from django.test import TransactionTestCase
from django.utils.timezone import now
from tests.clickhouse_models import ClickHouseTestModel, ClickHouseSecondTestModel, ClickHouseCollapseTestModel, \ from tests.clickhouse_models import ClickHouseTestModel, ClickHouseSecondTestModel, ClickHouseCollapseTestModel, \
ClickHouseMultiTestModel ClickHouseMultiTestModel
@ -60,6 +61,57 @@ class TestOperations(TransactionTestCase):
self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_native_bulk_update(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
for instance in items:
instance.value = instance.pk * 10
self.django_model.native_objects.bulk_update(items, ['value'])
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.assertEqual(2, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_pg_bulk_update(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.django_model.objects.pg_bulk_update([
{'id': instance.pk, 'value': instance.pk * 10}
for instance in items
])
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.assertEqual(2, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_pg_bulk_update_or_create(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
data = [{
'id': instance.pk,
'value': instance.pk * 10,
'created_date': instance.created_date,
'created': instance.created
} for instance in items] + [{'id': 11, 'value': 110, 'created_date': datetime.date.today(), 'created': now()}]
self.django_model.objects.pg_bulk_update_or_create(data)
items = list(self.django_model.objects.filter(pk__in={1, 2, 11}))
self.assertEqual(3, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_get_or_create(self): def test_get_or_create(self):
instance, created = self.django_model.objects. \ instance, created = self.django_model.objects. \
get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(), get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(),