diff --git a/src/django_clickhouse/query.py b/src/django_clickhouse/query.py index 16e5612..ee38045 100644 --- a/src/django_clickhouse/query.py +++ b/src/django_clickhouse/query.py @@ -1,4 +1,4 @@ -from typing import Optional, Iterable +from typing import Optional, Iterable, List from copy import copy from infi.clickhouse_orm.database import Database @@ -37,7 +37,7 @@ class QuerySet(InfiQuerySet): """ if not self._db: if self._db_alias: - self._db = connections[self._db_alias] + self._db = connections[self._db_alias] else: self._db = self._model_cls.get_database(for_write=for_write) @@ -70,9 +70,9 @@ class QuerySet(InfiQuerySet): self.get_database(for_write=True).insert([instance]) return instance - def bulk_create(self, model_instances, batch_size=1000): # type: (Iterable[InfiModel], int) + def bulk_create(self, model_instances, batch_size=1000): # type: (Iterable[InfiModel], int) -> List[InfiModel] self.get_database(for_write=True).insert(model_instances=model_instances, batch_size=batch_size) - return model_instances + return list(model_instances) class AggregateQuerySet(QuerySet, InfiAggregateQuerySet): diff --git a/tests/test_query.py b/tests/test_query.py index d935854..95f14cf 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -25,7 +25,7 @@ class TestQuerySet(TestCase): def test_all(self): self.db.insert([ClickHouseTestModel(id=i, created_date=datetime.date.today(), value=i) for i in range(1, 4)]) qs = ClickHouseTestModel.objects.all() - print(qs.get_database(for_write=True).db_name) + self.assertIsInstance(qs, QuerySet) self.assertEqual(3, qs.count()) diff --git a/tests/test_sync.py b/tests/test_sync.py index 712b637..1ceb130 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -28,7 +28,7 @@ class SyncTest(TransactionTestCase): obj = TestModel.objects.create(value=1, created_date=datetime.date.today()) ClickHouseTestModel.sync_batch_from_storage() - synced_data = list(ClickHouseTestModel.objects_in(connections['default'])) + synced_data = list(ClickHouseTestModel.objects.all()) self.assertEqual(1, len(synced_data)) self.assertEqual(obj.created_date, synced_data[0].created_date) self.assertEqual(obj.value, synced_data[0].value) @@ -41,7 +41,7 @@ class SyncTest(TransactionTestCase): ClickHouseCollapseTestModel.sync_batch_from_storage() # sync_batch_from_storage uses FINAL, so data would be collapsed by now - synced_data = list(ClickHouseCollapseTestModel.objects_in(connections['default'])) + synced_data = list(ClickHouseCollapseTestModel.objects.all()) self.assertEqual(1, len(synced_data)) self.assertEqual(obj.created_date, synced_data[0].created_date) self.assertEqual(obj.value, synced_data[0].value) @@ -65,7 +65,7 @@ class SyncTest(TransactionTestCase): ClickHouseCollapseTestModel.sync_batch_from_storage() # sync_batch_from_storage uses FINAL, so data would be collapsed by now - synced_data = list(ClickHouseCollapseTestModel.objects_in(connections['default'])) + synced_data = list(ClickHouseCollapseTestModel.objects.all()) self.assertEqual(0, len(synced_data)) def test_multi_model(self): @@ -74,14 +74,14 @@ class SyncTest(TransactionTestCase): obj.save() ClickHouseMultiTestModel.sync_batch_from_storage() - synced_data = list(ClickHouseTestModel.objects_in(connections['default'])) + synced_data = list(ClickHouseTestModel.objects.all()) self.assertEqual(1, len(synced_data)) self.assertEqual(obj.created_date, synced_data[0].created_date) self.assertEqual(obj.value, synced_data[0].value) self.assertEqual(obj.id, synced_data[0].id) # sync_batch_from_storage uses FINAL, so data would be collapsed by now - synced_data = list(ClickHouseCollapseTestModel.objects_in(connections['default'])) + synced_data = list(ClickHouseCollapseTestModel.objects.all()) self.assertEqual(1, len(synced_data)) self.assertEqual(obj.created_date, synced_data[0].created_date) self.assertEqual(obj.value, synced_data[0].value)