From 54923e03bf2abdd98bd97f124e1e14862ae923d0 Mon Sep 17 00:00:00 2001 From: M1ha Date: Fri, 7 Feb 2020 11:07:52 +0500 Subject: [PATCH] Fixed issue https://github.com/carrotquest/django-clickhouse/issues/11 --- src/django_clickhouse/models.py | 23 +++++++++++++++-------- tests/test_models.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/django_clickhouse/models.py b/src/django_clickhouse/models.py index 46e8b47..18b2947 100644 --- a/src/django_clickhouse/models.py +++ b/src/django_clickhouse/models.py @@ -5,6 +5,7 @@ It saves all operations to storage in order to write them to ClickHouse later. from typing import Optional, Any, Type, Set +import functools import six from django.db import transaction from django.db.models import QuerySet as DjangoQuerySet, Model as DjangoModel, Manager as DjangoManager @@ -71,19 +72,25 @@ class ClickHouseSyncBulkUpdateQuerySetMixin(ClickHouseSyncRegisterMixin, BulkUpd return returning - def pg_bulk_update(self, *args, **kwargs): + def _decorate_method(self, name: str, operation: str, args, kwargs): + if not hasattr(super(), name): + raise AttributeError(name) + + func = getattr(super(), name) original_returning = kwargs.pop('returning', None) kwargs['returning'] = self._update_returning_param(original_returning) - result = super().pg_bulk_update(*args, **kwargs) - self._register_ops('update', result) + result = func(*args, **kwargs) + self._register_ops(operation, result) return result.count() if original_returning is None else result + def pg_bulk_update(self, *args, **kwargs): + return self._decorate_method('pg_bulk_update', 'update', args, kwargs) + def pg_bulk_update_or_create(self, *args, **kwargs): - original_returning = kwargs.pop('returning', None) - kwargs['returning'] = self._update_returning_param(original_returning) - result = super().pg_bulk_update_or_create(*args, **kwargs) - self._register_ops('update', result) - return result.count() if original_returning is None else result + return self._decorate_method('pg_bulk_update_or_create', 'update', args, kwargs) + + def pg_bulk_create(self, *args, **kwargs): + return self._decorate_method('pg_bulk_create', 'insert', args, kwargs) class ClickHouseSyncQuerySetMixin(ClickHouseSyncRegisterMixin): diff --git a/tests/test_models.py b/tests/test_models.py index 8a44930..97276ce 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -76,6 +76,24 @@ class TestOperations(TransactionTestCase): self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + def test_pg_bulk_create(self): + now_dt = now() + res = self.django_model.objects.pg_bulk_create([ + {'value': i, 'created': now_dt, 'created_date': now_dt.date()} + for i in range(5) + ]) + self.assertEqual(5, res) + + items = list(self.django_model.objects.filter(value__lt=100).order_by('value')) + self.assertEqual(5, len(items)) + for i, instance in enumerate(items): + self.assertEqual(instance.created, now_dt) + self.assertEqual(instance.created_date, now_dt.date()) + self.assertEqual(i, instance.value) + + self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items}, + set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10))) + def test_pg_bulk_update(self): items = list(self.django_model.objects.filter(pk__in={1, 2}))