M1ha 2020-02-07 10:49:49 +05:00
parent 818704989b
commit f9e33a6845
4 changed files with 82 additions and 6 deletions

View File

@ -43,7 +43,7 @@ def update_returning_pk(qs: QuerySet, updates: dict) -> Set[Any]:
:return: A set of primary keys
"""
qs._for_write = True
if django_pg_returning_available(qs.db):
if django_pg_returning_available(qs.db) and hasattr(qs, 'update_returning'):
pk_name = qs.model._meta.pk.name
qs = qs.only(pk_name).update_returning(**updates)
pks = set(qs.values_list(pk_name, flat=True))

View File

@ -71,17 +71,17 @@ class ClickHouseSyncBulkUpdateQuerySetMixin(ClickHouseSyncRegisterMixin, BulkUpd
return returning
def bulk_update(self, *args, **kwargs):
def pg_bulk_update(self, *args, **kwargs):
original_returning = kwargs.pop('returning', None)
kwargs['returning'] = self._update_returning_param(original_returning)
result = super().bulk_update(*args, **kwargs)
result = super().pg_bulk_update(*args, **kwargs)
self._register_ops('update', result)
return result.count() if original_returning is None else result
def bulk_update_or_create(self, *args, **kwargs):
def pg_bulk_update_or_create(self, *args, **kwargs):
original_returning = kwargs.pop('returning', None)
kwargs['returning'] = self._update_returning_param(original_returning)
result = super().bulk_update_or_create(*args, **kwargs)
result = super().pg_bulk_update_or_create(*args, **kwargs)
self._register_ops('update', result)
return result.count() if original_returning is None else result
@ -97,6 +97,19 @@ class ClickHouseSyncQuerySetMixin(ClickHouseSyncRegisterMixin):
self._register_ops('insert', objs)
return objs
def bulk_update(self, objs, *args, **kwargs):
objs = list(objs)
# No need to register anything, if there are no objects.
# If objects are not models, django-pg-bulk-update method is called and pg_bulk_update will register items
if len(objs) == 0 or not isinstance(objs[0], DjangoModel):
return super().bulk_update(objs, *args, **kwargs)
# native django bulk_update requires each object to have a primary key
res = super().bulk_update(objs, *args, **kwargs)
self._register_ops('update', objs)
return res
# I add library dependant mixins to base classes only if libraries are installed
qs_bases = [ClickHouseSyncQuerySetMixin]

View File

@ -2,10 +2,15 @@
This file contains sample models to use in tests
"""
from django.db import models
from django.db.models import QuerySet
from django.db.models.manager import BaseManager
from django_pg_returning import UpdateReturningModel
from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet
from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet, ClickHouseSyncQuerySetMixin
class NativeQuerySet(ClickHouseSyncQuerySetMixin, QuerySet):
pass
class TestQuerySet(ClickHouseSyncQuerySet):
@ -16,8 +21,13 @@ class TestManager(BaseManager.from_queryset(TestQuerySet)):
pass
class NativeManager(BaseManager.from_queryset(NativeQuerySet)):
pass
class TestModel(UpdateReturningModel, ClickHouseSyncModel):
objects = TestManager()
native_objects = NativeManager()
value = models.IntegerField()
created_date = models.DateField()
@ -26,6 +36,7 @@ class TestModel(UpdateReturningModel, ClickHouseSyncModel):
class SecondaryTestModel(UpdateReturningModel, ClickHouseSyncModel):
objects = TestManager()
native_objects = NativeManager()
value = models.IntegerField()
created_date = models.DateField()

View File

@ -1,6 +1,7 @@
import datetime
from django.test import TransactionTestCase
from django.utils.timezone import now
from tests.clickhouse_models import ClickHouseTestModel, ClickHouseSecondTestModel, ClickHouseCollapseTestModel, \
ClickHouseMultiTestModel
@ -60,6 +61,57 @@ class TestOperations(TransactionTestCase):
self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_native_bulk_update(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
for instance in items:
instance.value = instance.pk * 10
self.django_model.native_objects.bulk_update(items, ['value'])
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.assertEqual(2, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_pg_bulk_update(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.django_model.objects.pg_bulk_update([
{'id': instance.pk, 'value': instance.pk * 10}
for instance in items
])
items = list(self.django_model.objects.filter(pk__in={1, 2}))
self.assertEqual(2, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_pg_bulk_update_or_create(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))
data = [{
'id': instance.pk,
'value': instance.pk * 10,
'created_date': instance.created_date,
'created': instance.created
} for instance in items] + [{'id': 11, 'value': 110, 'created_date': datetime.date.today(), 'created': now()}]
self.django_model.objects.pg_bulk_update_or_create(data)
items = list(self.django_model.objects.filter(pk__in={1, 2, 11}))
self.assertEqual(3, len(items))
for instance in items:
self.assertEqual(instance.value, instance.pk * 10)
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
def test_get_or_create(self):
instance, created = self.django_model.objects. \
get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(),