diff --git a/src/django_clickhouse/engines.py b/src/django_clickhouse/engines.py index 476f56a..90beb84 100644 --- a/src/django_clickhouse/engines.py +++ b/src/django_clickhouse/engines.py @@ -1,6 +1,7 @@ """ This file contains wrappers for infi.clckhouse_orm engines to use in django-clickhouse """ +import datetime from typing import List, TypeVar, Type from django.db.models import Model as DjangoModel @@ -8,6 +9,7 @@ from infi.clickhouse_orm import engines as infi_engines from infi.clickhouse_orm.models import Model as InfiModel from statsd.defaults.django import statsd +from django_clickhouse.database import connections from .configuration import config from .utils import format_datetime @@ -46,11 +48,7 @@ class CollapsingMergeTree(InsertOnlyEngineMixin, infi_engines.CollapsingMergeTre self.version_col = kwargs.pop('version_col', None) super(CollapsingMergeTree, self).__init__(*args, **kwargs) - def _get_final_versions_by_version(self, model_cls, min_date, max_date, object_pks): - db = model_cls.get_database() - min_date = format_datetime(min_date, 0, db_alias=db.db_alias) - max_date = format_datetime(min_date, 0, day_end=True, db_alias=db.db_alias) - + def _get_final_versions_by_version(self, db_alias, model_cls, min_date, max_date, object_pks): query = """ SELECT * FROM $table WHERE (`{pk_column}`, `{version_col}`) IN ( SELECT `{pk_column}`, MAX(`{version_col}`) @@ -60,24 +58,20 @@ class CollapsingMergeTree(InsertOnlyEngineMixin, infi_engines.CollapsingMergeTre GROUP BY `{pk_column}` ) """.format(version_col=self.version_col, date_col=self.date_col, pk_column=self.pk_column, - min_date=min_date.isoformat(), max_date=max_date.isoformat(), object_pks=','.join(object_pks)) + min_date=min_date, max_date=max_date, object_pks=','.join(object_pks)) - qs = db.select(query, model_class=model_cls) + qs = connections[db_alias].select(query, model_class=model_cls) return list(qs) - def _get_final_versions_by_final(self, model_cls, min_date, max_date, object_pks): - db = model_cls.get_database() - min_date = format_datetime(min_date, 0, db_alias=db.db_alias) - max_date = format_datetime(min_date, 0, day_end=True, db_alias=db.db_alias) - + def _get_final_versions_by_final(self, db_alias, model_cls, min_date, max_date, object_pks): query = """ SELECT * FROM $table FINAL WHERE `{date_col}` >= '{min_date}' AND `{date_col}` <= '{max_date}' AND `{pk_column}` IN ({object_pks}) """ - query = query.format(date_col=self.date_col, pk_column=self.pk_column, min_date=min_date.isoformat(), - max_date=max_date.isoformat(), object_pks=','.join(object_pks)) - qs = db.select(query, model_class=model_cls) + query = query.format(date_col=self.date_col, pk_column=self.pk_column, min_date=min_date, + max_date=max_date, object_pks=','.join(object_pks)) + qs = connections[db_alias].select(query, model_class=model_cls) return list(qs) def get_final_versions(self, model_cls, objects): @@ -105,10 +99,22 @@ class CollapsingMergeTree(InsertOnlyEngineMixin, infi_engines.CollapsingMergeTre object_pks = [str(getattr(obj, self.pk_column)) for obj in objects] - if self.version_col: - return self._get_final_versions_by_version(model_cls, min_date, max_date, object_pks) + db_alias = model_cls.get_database_alias() + + if isinstance(min_date, datetime.date): + min_date = min_date.isoformat() else: - return self._get_final_versions_by_final(model_cls, min_date, max_date, object_pks) + min_date = format_datetime(min_date, 0, db_alias=db_alias) + + if isinstance(max_date, datetime.date): + max_date = max_date.isoformat() + else: + max_date = format_datetime(max_date, 0, day_end=True, db_alias=db_alias) + + if self.version_col: + return self._get_final_versions_by_version(db_alias, model_cls, min_date, max_date, object_pks) + else: + return self._get_final_versions_by_final(db_alias, model_cls, min_date, max_date, object_pks) def get_insert_batch(self, model_cls, objects): # type: (Type[T], List[DjangoModel]) -> List[T] diff --git a/tests/test_engines.py b/tests/test_engines.py new file mode 100644 index 0000000..555a08f --- /dev/null +++ b/tests/test_engines.py @@ -0,0 +1,79 @@ +import datetime + +from django.test import TestCase +from django_clickhouse.migrations import migrate_app + +from django_clickhouse.database import connections +from tests.clickhouse_models import ClickHouseCollapseTestModel +from tests.models import TestModel + + +class CollapsingMergeTreeTest(TestCase): + fixtures = ['test_model'] + maxDiff = None + + collapse_fixture = [{ + "id": 1, + "created_date": "2018-01-01", + "sign": 1, + "version": 1 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": -1, + "version": 1 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": 1, + "version": 2 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": -1, + "version": 2 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": 1, + "version": 3 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": -1, + "version": 3 + }, { + "id": 1, + "created_date": "2018-01-01", + "sign": 1, + "version": 4 + }] + + def setUp(self): + self.db = connections['default'] + self.db.drop_database() + self.db.create_database() + migrate_app('tests', 'default') + ClickHouseCollapseTestModel.get_storage().flush() + + ClickHouseCollapseTestModel.objects.bulk_create([ + ClickHouseCollapseTestModel(**item) for item in self.collapse_fixture + ]) + self.objects = TestModel.objects.filter(id=1) + + def test_get_final_versions_by_final(self): + final_versions = ClickHouseCollapseTestModel.engine.get_final_versions(ClickHouseCollapseTestModel, + self.objects) + + self.assertEqual(1, len(final_versions)) + self.assertDictEqual({'id': 1, 'sign': 1, 'version': 4, 'value': 0, 'created_date': datetime.date(2018, 1, 1)}, + final_versions[0].to_dict()) + + def test_get_final_versions_by_version(self): + ClickHouseCollapseTestModel.engine.version_col = 'version' + final_versions = ClickHouseCollapseTestModel.engine.get_final_versions(ClickHouseCollapseTestModel, + self.objects) + + self.assertEqual(1, len(final_versions)) + self.assertDictEqual({'id': 1, 'sign': 1, 'version': 4, 'value': 0, 'created_date': datetime.date(2018, 1, 1)}, + final_versions[0].to_dict()) diff --git a/tests/test_sync.py b/tests/test_sync.py index 86ba1a5..4112de6 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -22,7 +22,6 @@ class SyncTest(TransactionTestCase): def setUp(self): self.db = connections['default'] self.db.drop_database() - self.db.db_exists = False self.db.create_database() migrate_app('tests', 'default') ClickHouseTestModel.get_storage().flush()