mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-25 10:13:45 +03:00
1) Added get_database method to Model
2) Added some assertions in tests for adding _database attribute in selects and inserts 3) database.insert() method sets _database
This commit is contained in:
parent
adff766246
commit
6f975a801c
|
@ -62,10 +62,12 @@ class Database(object):
|
|||
|
||||
def gen():
|
||||
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
|
||||
first_instance.set_database(self)
|
||||
yield (first_instance.to_tsv(include_readonly=False) + '\n').encode('utf-8')
|
||||
# Collect lines in batches of batch_size
|
||||
batch = []
|
||||
for instance in i:
|
||||
instance.set_database(self)
|
||||
batch.append(instance.to_tsv(include_readonly=False))
|
||||
if len(batch) >= batch_size:
|
||||
# Return the current batch of lines
|
||||
|
|
|
@ -116,6 +116,13 @@ class Model(with_metaclass(ModelBase)):
|
|||
assert isinstance(db, Database), "database must be database.Database instance"
|
||||
self._database = db
|
||||
|
||||
def get_database(self):
|
||||
"""
|
||||
Gets _database attribute for current model instance
|
||||
:return: database.Database instance, model was inserted or selected from or None
|
||||
"""
|
||||
return self._database
|
||||
|
||||
def get_field(self, name):
|
||||
'''
|
||||
Get a Field instance given its name, or None if not found.
|
||||
|
|
|
@ -24,6 +24,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||
def _insert_and_check(self, data, count):
|
||||
self.database.insert(data)
|
||||
self.assertEquals(count, self.database.count(Person))
|
||||
for instance in data:
|
||||
self.assertEquals(self.database, instance.get_database())
|
||||
|
||||
def test_insert__generator(self):
|
||||
self._insert_and_check(self._sample_data(), len(data))
|
||||
|
@ -53,6 +55,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||
self.assertEquals(results[0].height, 1.72)
|
||||
self.assertEquals(results[1].last_name, 'Scott')
|
||||
self.assertEquals(results[1].height, 1.70)
|
||||
self.assertEqual(results[0].get_database(), self.database)
|
||||
self.assertEqual(results[1].get_database(), self.database)
|
||||
|
||||
def test_select_partial_fields(self):
|
||||
self._insert_and_check(self._sample_data(), len(data))
|
||||
|
@ -63,6 +67,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||
self.assertEquals(results[0].height, 0) # default value
|
||||
self.assertEquals(results[1].last_name, 'Scott')
|
||||
self.assertEquals(results[1].height, 0) # default value
|
||||
self.assertEqual(results[0].get_database(), self.database)
|
||||
self.assertEqual(results[1].get_database(), self.database)
|
||||
|
||||
def test_select_ad_hoc_model(self):
|
||||
self._insert_and_check(self._sample_data(), len(data))
|
||||
|
@ -74,6 +80,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||
self.assertEquals(results[0].height, 1.72)
|
||||
self.assertEquals(results[1].last_name, 'Scott')
|
||||
self.assertEquals(results[1].height, 1.70)
|
||||
self.assertEqual(results[0].get_database(), self.database)
|
||||
self.assertEqual(results[1].get_database(), self.database)
|
||||
|
||||
def test_pagination(self):
|
||||
self._insert_and_check(self._sample_data(), len(data))
|
||||
|
|
Loading…
Reference in New Issue
Block a user