From f9e33a6845275fddb0bafef754351cefbe8fea3e Mon Sep 17 00:00:00 2001 From: M1ha Date: Fri, 7 Feb 2020 10:49:49 +0500 Subject: [PATCH] Fixed issue https://github.com/carrotquest/django-clickhouse/issues/9 --- src/django_clickhouse/compatibility.py | 2 +- src/django_clickhouse/models.py | 21 +++++++++-- tests/models.py | 13 ++++++- tests/test_models.py | 52 ++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/django_clickhouse/compatibility.py b/src/django_clickhouse/compatibility.py index ef856bf..2354d7a 100644 --- a/src/django_clickhouse/compatibility.py +++ b/src/django_clickhouse/compatibility.py @@ -43,7 +43,7 @@ def update_returning_pk(qs: QuerySet, updates: dict) -> Set[Any]: :return: A set of primary keys """ qs._for_write = True - if django_pg_returning_available(qs.db): + if django_pg_returning_available(qs.db) and hasattr(qs, 'update_returning'): pk_name = qs.model._meta.pk.name qs = qs.only(pk_name).update_returning(**updates) pks = set(qs.values_list(pk_name, flat=True)) diff --git a/src/django_clickhouse/models.py b/src/django_clickhouse/models.py index f500bcc..46e8b47 100644 --- a/src/django_clickhouse/models.py +++ b/src/django_clickhouse/models.py @@ -71,17 +71,17 @@ class ClickHouseSyncBulkUpdateQuerySetMixin(ClickHouseSyncRegisterMixin, BulkUpd return returning - def bulk_update(self, *args, **kwargs): + def pg_bulk_update(self, *args, **kwargs): original_returning = kwargs.pop('returning', None) kwargs['returning'] = self._update_returning_param(original_returning) - result = super().bulk_update(*args, **kwargs) + result = super().pg_bulk_update(*args, **kwargs) self._register_ops('update', result) return result.count() if original_returning is None else result - def bulk_update_or_create(self, *args, **kwargs): + def pg_bulk_update_or_create(self, *args, **kwargs): original_returning = kwargs.pop('returning', None) kwargs['returning'] = self._update_returning_param(original_returning) - result = super().bulk_update_or_create(*args, **kwargs) + result = super().pg_bulk_update_or_create(*args, **kwargs) self._register_ops('update', result) return result.count() if original_returning is None else result @@ -97,6 +97,19 @@ class ClickHouseSyncQuerySetMixin(ClickHouseSyncRegisterMixin): self._register_ops('insert', objs) return objs + def bulk_update(self, objs, *args, **kwargs): + objs = list(objs) + + # No need to register anything, if there are no objects. + # If objects are not models, django-pg-bulk-update method is called and pg_bulk_update will register items + if len(objs) == 0 or not isinstance(objs[0], DjangoModel): + return super().bulk_update(objs, *args, **kwargs) + + # native django bulk_update requires each object to have a primary key + res = super().bulk_update(objs, *args, **kwargs) + self._register_ops('update', objs) + return res + # I add library dependant mixins to base classes only if libraries are installed qs_bases = [ClickHouseSyncQuerySetMixin] diff --git a/tests/models.py b/tests/models.py index a0de1ec..e915252 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2,10 +2,15 @@ This file contains sample models to use in tests """ from django.db import models +from django.db.models import QuerySet from django.db.models.manager import BaseManager from django_pg_returning import UpdateReturningModel -from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet +from django_clickhouse.models import ClickHouseSyncModel, ClickHouseSyncQuerySet, ClickHouseSyncQuerySetMixin + + +class NativeQuerySet(ClickHouseSyncQuerySetMixin, QuerySet): + pass class TestQuerySet(ClickHouseSyncQuerySet): @@ -16,8 +21,13 @@ class TestManager(BaseManager.from_queryset(TestQuerySet)): pass +class NativeManager(BaseManager.from_queryset(NativeQuerySet)): + pass + + class TestModel(UpdateReturningModel, ClickHouseSyncModel): objects = TestManager() + native_objects = NativeManager() value = models.IntegerField() created_date = models.DateField() @@ -26,6 +36,7 @@ class TestModel(UpdateReturningModel, ClickHouseSyncModel): class SecondaryTestModel(UpdateReturningModel, ClickHouseSyncModel): objects = TestManager() + native_objects = NativeManager() value = models.IntegerField() created_date = models.DateField() diff --git a/tests/test_models.py b/tests/test_models.py index 978cdc2..8a44930 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ import datetime from django.test import TransactionTestCase +from django.utils.timezone import now from tests.clickhouse_models import ClickHouseTestModel, ClickHouseSecondTestModel, ClickHouseCollapseTestModel, \ ClickHouseMultiTestModel @@ -60,6 +61,57 @@ class TestOperations(TransactionTestCase): self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + def test_native_bulk_update(self): + items = list(self.django_model.objects.filter(pk__in={1, 2})) + for instance in items: + instance.value = instance.pk * 10 + + self.django_model.native_objects.bulk_update(items, ['value']) + + items = list(self.django_model.objects.filter(pk__in={1, 2})) + self.assertEqual(2, len(items)) + for instance in items: + self.assertEqual(instance.value, instance.pk * 10) + + self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, + set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + + def test_pg_bulk_update(self): + items = list(self.django_model.objects.filter(pk__in={1, 2})) + + self.django_model.objects.pg_bulk_update([ + {'id': instance.pk, 'value': instance.pk * 10} + for instance in items + ]) + + items = list(self.django_model.objects.filter(pk__in={1, 2})) + self.assertEqual(2, len(items)) + for instance in items: + self.assertEqual(instance.value, instance.pk * 10) + + self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, + set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + + def test_pg_bulk_update_or_create(self): + items = list(self.django_model.objects.filter(pk__in={1, 2})) + + data = [{ + 'id': instance.pk, + 'value': instance.pk * 10, + 'created_date': instance.created_date, + 'created': instance.created + } for instance in items] + [{'id': 11, 'value': 110, 'created_date': datetime.date.today(), 'created': now()}] + + self.django_model.objects.pg_bulk_update_or_create(data) + + items = list(self.django_model.objects.filter(pk__in={1, 2, 11})) + self.assertEqual(3, len(items)) + for instance in items: + self.assertEqual(instance.value, instance.pk * 10) + + self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, + set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + def test_get_or_create(self): instance, created = self.django_model.objects. \ get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(),