Support dashes and other special characters in database names and table names

This commit is contained in:
Itai Shirav 2016-06-30 11:36:54 +03:00
parent c1053f4c75
commit c6c9f13e51
3 changed files with 10 additions and 10 deletions

View File

@ -14,7 +14,7 @@ class Database(object):
self.db_url = db_url self.db_url = db_url
self.username = username self.username = username
self.password = password self.password = password
self._send('CREATE DATABASE IF NOT EXISTS ' + db_name) self._send('CREATE DATABASE IF NOT EXISTS `%s`' % db_name)
def create_table(self, model_class): def create_table(self, model_class):
# TODO check that model has an engine # TODO check that model has an engine
@ -24,7 +24,7 @@ class Database(object):
self._send(model_class.drop_table_sql(self.db_name)) self._send(model_class.drop_table_sql(self.db_name))
def drop_database(self): def drop_database(self):
self._send('DROP DATABASE ' + self.db_name) self._send('DROP DATABASE `%s`' % self.db_name)
def insert(self, model_instances): def insert(self, model_instances):
i = iter(model_instances) i = iter(model_instances)
@ -34,7 +34,7 @@ class Database(object):
return # model_instances is empty return # model_instances is empty
model_class = first_instance.__class__ model_class = first_instance.__class__
def gen(): def gen():
yield 'INSERT INTO %s.%s FORMAT TabSeparated\n' % (self.db_name, model_class.table_name()) yield 'INSERT INTO `%s`.`%s` FORMAT TabSeparated\n' % (self.db_name, model_class.table_name())
yield first_instance.to_tsv() yield first_instance.to_tsv()
yield '\n' yield '\n'
for instance in i: for instance in i:
@ -43,7 +43,7 @@ class Database(object):
self._send(gen()) self._send(gen())
def count(self, model_class, conditions=None): def count(self, model_class, conditions=None):
query = 'SELECT count() FROM %s.%s' % (self.db_name, model_class.table_name()) query = 'SELECT count() FROM `%s`.`%s`' % (self.db_name, model_class.table_name())
if conditions: if conditions:
query += ' WHERE ' + conditions query += ' WHERE ' + conditions
r = self._send(query) r = self._send(query)

View File

@ -92,7 +92,7 @@ class Model(object):
''' '''
Returns the SQL command for creating a table for this model. Returns the SQL command for creating a table for this model.
''' '''
parts = ['CREATE TABLE IF NOT EXISTS %s.%s (' % (db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
cols = [] cols = []
for name, field in cls._fields: for name, field in cls._fields:
default = field.get_db_prep_value(field.default) default = field.get_db_prep_value(field.default)
@ -107,7 +107,7 @@ class Model(object):
''' '''
Returns the SQL command for deleting this model's table. Returns the SQL command for deleting this model's table.
''' '''
return 'DROP TABLE IF EXISTS %s.%s' % (db_name, cls.table_name()) return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db_name, cls.table_name())
@classmethod @classmethod
def from_tsv(cls, line, field_names=None): def from_tsv(cls, line, field_names=None):

View File

@ -9,7 +9,7 @@ from infi.clickhouse_orm.engines import *
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test_db') self.database = Database('test-db')
self.database.create_table(Person) self.database.create_table(Person)
def tearDown(self): def tearDown(self):
@ -41,7 +41,7 @@ class DatabaseTestCase(unittest.TestCase):
def test_select(self): def test_select(self):
self._insert_and_check(self._sample_data(), len(data)) self._insert_and_check(self._sample_data(), len(data))
query = "SELECT * FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person)) results = list(self.database.select(query, Person))
self.assertEquals(len(results), 2) self.assertEquals(len(results), 2)
self.assertEquals(results[0].last_name, 'Durham') self.assertEquals(results[0].last_name, 'Durham')
@ -51,7 +51,7 @@ class DatabaseTestCase(unittest.TestCase):
def test_select_partial_fields(self): def test_select_partial_fields(self):
self._insert_and_check(self._sample_data(), len(data)) self._insert_and_check(self._sample_data(), len(data))
query = "SELECT first_name, last_name FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person)) results = list(self.database.select(query, Person))
self.assertEquals(len(results), 2) self.assertEquals(len(results), 2)
self.assertEquals(results[0].last_name, 'Durham') self.assertEquals(results[0].last_name, 'Durham')
@ -61,7 +61,7 @@ class DatabaseTestCase(unittest.TestCase):
def test_select_ad_hoc_model(self): def test_select_ad_hoc_model(self):
self._insert_and_check(self._sample_data(), len(data)) self._insert_and_check(self._sample_data(), len(data))
query = "SELECT * FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query)) results = list(self.database.select(query))
self.assertEquals(len(results), 2) self.assertEquals(len(results), 2)
self.assertEquals(results[0].__class__.__name__, 'AdHocModel') self.assertEquals(results[0].__class__.__name__, 'AdHocModel')