diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py index 5c804c0..2be8534 100644 --- a/src/infi/clickhouse_orm/database.py +++ b/src/infi/clickhouse_orm/database.py @@ -62,10 +62,12 @@ class Database(object): def gen(): yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8') + first_instance.set_database(self) yield (first_instance.to_tsv(include_readonly=False) + '\n').encode('utf-8') # Collect lines in batches of batch_size batch = [] for instance in i: + instance.set_database(self) batch.append(instance.to_tsv(include_readonly=False)) if len(batch) >= batch_size: # Return the current batch of lines diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index 4b7b1b2..6b82244 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -116,6 +116,13 @@ class Model(with_metaclass(ModelBase)): assert isinstance(db, Database), "database must be database.Database instance" self._database = db + def get_database(self): + """ + Gets _database attribute for current model instance + :return: database.Database instance, model was inserted or selected from or None + """ + return self._database + def get_field(self, name): ''' Get a Field instance given its name, or None if not found. diff --git a/tests/test_database.py b/tests/test_database.py index a3ad15c..aa7a4df 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -24,6 +24,8 @@ class DatabaseTestCase(unittest.TestCase): def _insert_and_check(self, data, count): self.database.insert(data) self.assertEquals(count, self.database.count(Person)) + for instance in data: + self.assertEquals(self.database, instance.get_database()) def test_insert__generator(self): self._insert_and_check(self._sample_data(), len(data)) @@ -53,6 +55,8 @@ class DatabaseTestCase(unittest.TestCase): self.assertEquals(results[0].height, 1.72) self.assertEquals(results[1].last_name, 'Scott') self.assertEquals(results[1].height, 1.70) + self.assertEqual(results[0].get_database(), self.database) + self.assertEqual(results[1].get_database(), self.database) def test_select_partial_fields(self): self._insert_and_check(self._sample_data(), len(data)) @@ -63,6 +67,8 @@ class DatabaseTestCase(unittest.TestCase): self.assertEquals(results[0].height, 0) # default value self.assertEquals(results[1].last_name, 'Scott') self.assertEquals(results[1].height, 0) # default value + self.assertEqual(results[0].get_database(), self.database) + self.assertEqual(results[1].get_database(), self.database) def test_select_ad_hoc_model(self): self._insert_and_check(self._sample_data(), len(data)) @@ -74,6 +80,8 @@ class DatabaseTestCase(unittest.TestCase): self.assertEquals(results[0].height, 1.72) self.assertEquals(results[1].last_name, 'Scott') self.assertEquals(results[1].height, 1.70) + self.assertEqual(results[0].get_database(), self.database) + self.assertEqual(results[1].get_database(), self.database) def test_pagination(self): self._insert_and_check(self._sample_data(), len(data))