mirror of
https://github.com/carrotquest/django-clickhouse.git
synced 2025-02-11 16:30:33 +03:00
This commit is contained in:
parent
818704989b
commit
f9e33a6845
|
@ -43,7 +43,7 @@ def update_returning_pk(qs: QuerySet, updates: dict) -> Set[Any]:
|
|||
:return: A set of primary keys
|
||||
"""
|
||||
qs._for_write = True
|
||||
if django_pg_returning_available(qs.db):
|
||||
if django_pg_returning_available(qs.db) and hasattr(qs, 'update_returning'):
|
||||
pk_name = qs.model._meta.pk.name
|
||||
qs = qs.only(pk_name).update_returning(**updates)
|
||||
pks = set(qs.values_list(pk_name, flat=True))
|
||||
|
|
|
@ -71,17 +71,17 @@ class ClickHouseSyncBulkUpdateQuerySetMixin(ClickHouseSyncRegisterMixin, BulkUpd
|
|||
|
||||
return returning
|
||||
|
||||
def bulk_update(self, *args, **kwargs):
|
||||
def pg_bulk_update(self, *args, **kwargs):
|
||||
original_returning = kwargs.pop('returning', None)
|
||||
kwargs['returning'] = self._update_returning_param(original_returning)
|
||||
result = super().bulk_update(*args, **kwargs)
|
||||
result = super().pg_bulk_update(*args, **kwargs)
|
||||
self._register_ops('update', result)
|
||||
return result.count() if original_returning is None else result
|
||||
|
||||
def bulk_update_or_create(self, *args, **kwargs):
|
||||
def pg_bulk_update_or_create(self, *args, **kwargs):
|
||||
original_returning = kwargs.pop('returning', None)
|
||||
kwargs['returning'] = self._update_returning_param(original_returning)
|
||||
result = super().bulk_update_or_create(*args, **kwargs)
|
||||
result = super().pg_bulk_update_or_create(*args, **kwargs)
|
||||
self._register_ops('update', result)
|
||||
return result.count() if original_returning is None else result
|
||||
|
||||
|
@ -97,6 +97,19 @@ class ClickHouseSyncQuerySetMixin(ClickHouseSyncRegisterMixin):
|
|||
self._register_ops('insert', objs)
|
||||
return objs
|
||||
|
||||
def bulk_update(self, objs, *args, **kwargs):
|
||||
objs = list(objs)
|
||||
|
||||
# No need to register anything, if there are no objects.
|
||||
# If objects are not models, django-pg-bulk-update method is called and pg_bulk_update will register items
|
||||
if len(objs) == 0 or not isinstance(objs[0], DjangoModel):
|
||||
return super().bulk_update(objs, *args, **kwargs)
|
||||
|
||||
# native django bulk_update requires each object to have a primary key
|
||||
res = super().bulk_update(objs, *args, **kwargs)
|
||||
self._register_ops('update', objs)
|
||||
return res
|
||||
|
||||
|
||||
# I add library dependant mixins to base classes only if libraries are installed
|
||||
qs_bases = [ClickHouseSyncQuerySetMixin]
|
||||
|
|
|
@ -2,10 +2,15 @@
|
|||
This file contains sample models to use in tests
|
||||
"""
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.manager import BaseManager
|
||||
from django_pg_returning import UpdateReturningModel
|
||||
|
||||
from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet
|
||||
from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet, ClickHouseSyncQuerySetMixin
|
||||
|
||||
|
||||
class NativeQuerySet(ClickHouseSyncQuerySetMixin, QuerySet):
|
||||
pass
|
||||
|
||||
|
||||
class TestQuerySet(ClickHouseSyncQuerySet):
|
||||
|
@ -16,8 +21,13 @@ class TestManager(BaseManager.from_queryset(TestQuerySet)):
|
|||
pass
|
||||
|
||||
|
||||
class NativeManager(BaseManager.from_queryset(NativeQuerySet)):
|
||||
pass
|
||||
|
||||
|
||||
class TestModel(UpdateReturningModel, ClickHouseSyncModel):
|
||||
objects = TestManager()
|
||||
native_objects = NativeManager()
|
||||
|
||||
value = models.IntegerField()
|
||||
created_date = models.DateField()
|
||||
|
@ -26,6 +36,7 @@ class TestModel(UpdateReturningModel, ClickHouseSyncModel):
|
|||
|
||||
class SecondaryTestModel(UpdateReturningModel, ClickHouseSyncModel):
|
||||
objects = TestManager()
|
||||
native_objects = NativeManager()
|
||||
|
||||
value = models.IntegerField()
|
||||
created_date = models.DateField()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
from django.utils.timezone import now
|
||||
|
||||
from tests.clickhouse_models import ClickHouseTestModel, ClickHouseSecondTestModel, ClickHouseCollapseTestModel, \
|
||||
ClickHouseMultiTestModel
|
||||
|
@ -60,6 +61,57 @@ class TestOperations(TransactionTestCase):
|
|||
self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
|
||||
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
|
||||
|
||||
def test_native_bulk_update(self):
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2}))
|
||||
for instance in items:
|
||||
instance.value = instance.pk * 10
|
||||
|
||||
self.django_model.native_objects.bulk_update(items, ['value'])
|
||||
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2}))
|
||||
self.assertEqual(2, len(items))
|
||||
for instance in items:
|
||||
self.assertEqual(instance.value, instance.pk * 10)
|
||||
|
||||
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
|
||||
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
|
||||
|
||||
def test_pg_bulk_update(self):
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2}))
|
||||
|
||||
self.django_model.objects.pg_bulk_update([
|
||||
{'id': instance.pk, 'value': instance.pk * 10}
|
||||
for instance in items
|
||||
])
|
||||
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2}))
|
||||
self.assertEqual(2, len(items))
|
||||
for instance in items:
|
||||
self.assertEqual(instance.value, instance.pk * 10)
|
||||
|
||||
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
|
||||
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
|
||||
|
||||
def test_pg_bulk_update_or_create(self):
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2}))
|
||||
|
||||
data = [{
|
||||
'id': instance.pk,
|
||||
'value': instance.pk * 10,
|
||||
'created_date': instance.created_date,
|
||||
'created': instance.created
|
||||
} for instance in items] + [{'id': 11, 'value': 110, 'created_date': datetime.date.today(), 'created': now()}]
|
||||
|
||||
self.django_model.objects.pg_bulk_update_or_create(data)
|
||||
|
||||
items = list(self.django_model.objects.filter(pk__in={1, 2, 11}))
|
||||
self.assertEqual(3, len(items))
|
||||
for instance in items:
|
||||
self.assertEqual(instance.value, instance.pk * 10)
|
||||
|
||||
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
|
||||
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))
|
||||
|
||||
def test_get_or_create(self):
|
||||
instance, created = self.django_model.objects. \
|
||||
get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(),
|
||||
|
|
Loading…
Reference in New Issue
Block a user