1) Added get_database method to Model

2) Added some assertions in tests for adding _database attribute in selects and inserts
3) database.insert() method sets _database
This commit is contained in:
M1ha 2017-02-09 17:10:48 +05:00 committed by Itai Shirav
parent adff766246
commit 6f975a801c
3 changed files with 17 additions and 0 deletions

View File

@ -62,10 +62,12 @@ class Database(object):
def gen():
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
first_instance.set_database(self)
yield (first_instance.to_tsv(include_readonly=False) + '\n').encode('utf-8')
# Collect lines in batches of batch_size
batch = []
for instance in i:
instance.set_database(self)
batch.append(instance.to_tsv(include_readonly=False))
if len(batch) >= batch_size:
# Return the current batch of lines

View File

@ -116,6 +116,13 @@ class Model(with_metaclass(ModelBase)):
assert isinstance(db, Database), "database must be database.Database instance"
self._database = db
def get_database(self):
"""
Gets _database attribute for current model instance
:return: database.Database instance, model was inserted or selected from or None
"""
return self._database
def get_field(self, name):
'''
Get a Field instance given its name, or None if not found.

View File

@ -24,6 +24,8 @@ class DatabaseTestCase(unittest.TestCase):
def _insert_and_check(self, data, count):
self.database.insert(data)
self.assertEquals(count, self.database.count(Person))
for instance in data:
self.assertEquals(self.database, instance.get_database())
def test_insert__generator(self):
self._insert_and_check(self._sample_data(), len(data))
@ -53,6 +55,8 @@ class DatabaseTestCase(unittest.TestCase):
self.assertEquals(results[0].height, 1.72)
self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
def test_select_partial_fields(self):
self._insert_and_check(self._sample_data(), len(data))
@ -63,6 +67,8 @@ class DatabaseTestCase(unittest.TestCase):
self.assertEquals(results[0].height, 0) # default value
self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 0) # default value
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
def test_select_ad_hoc_model(self):
self._insert_and_check(self._sample_data(), len(data))
@ -74,6 +80,8 @@ class DatabaseTestCase(unittest.TestCase):
self.assertEquals(results[0].height, 1.72)
self.assertEquals(results[1].last_name, 'Scott')
self.assertEquals(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
def test_pagination(self):
self._insert_and_check(self._sample_data(), len(data))