diff --git a/README.rst b/README.rst index e70e585..7ca6b06 100644 --- a/README.rst +++ b/README.rst @@ -91,6 +91,25 @@ It is possible to select only a subset of the columns, and the rest will receive for person in db.select("SELECT first_name FROM my_test_db.person WHERE last_name='Smith'", model_class=Person): print person.first_name +SQL Placeholders +**************** + +There are a couple of special placeholders that you can use inside the SQL to make it easier to write: +``$db`` and ``$table``. The first one is replaced by the database name, and the second is replaced by +the database name plus table name (but is available only when the model is specified). + +So instead of this:: + + db.select("SELECT * FROM my_test_db.person", model_class=Person) + +you can use:: + + db.select("SELECT * FROM $db.person", model_class=Person) + +or even:: + + db.select("SELECT * FROM $table", model_class=Person) + Ad-Hoc Models ************* diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py index 342c870..1187647 100644 --- a/src/infi/clickhouse_orm/database.py +++ b/src/infi/clickhouse_orm/database.py @@ -5,6 +5,7 @@ from utils import escape, parse_tsv, import_submodules from math import ceil import datetime import logging +from string import Template Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') @@ -41,7 +42,7 @@ class Database(object): return # model_instances is empty model_class = first_instance.__class__ def gen(): - yield 'INSERT INTO `%s`.`%s` FORMAT TabSeparated\n' % (self.db_name, model_class.table_name()) + yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class) yield first_instance.to_tsv() yield '\n' for instance in i: @@ -50,14 +51,16 @@ class Database(object): self._send(gen()) def count(self, model_class, conditions=None): - query = 'SELECT count() FROM `%s`.`%s`' % (self.db_name, model_class.table_name()) + query = 'SELECT count() FROM $table' if conditions: query += ' WHERE ' + conditions + query = self._substitute(query, model_class) r = self._send(query) return int(r.text) if r.text else 0 def select(self, query, model_class=None, settings=None): query += ' FORMAT TabSeparatedWithNamesAndTypes' + query = self._substitute(query, model_class) r = self._send(query, settings, True) lines = r.iter_lines() field_names = parse_tsv(next(lines)) @@ -70,11 +73,12 @@ class Database(object): count = self.count(model_class, conditions) pages_total = int(ceil(count / float(page_size))) offset = (page_num - 1) * page_size - query = 'SELECT * FROM `%s`.`%s`' % (self.db_name, model_class.table_name()) + query = 'SELECT * FROM $table' if conditions: query += ' WHERE ' + conditions query += ' ORDER BY %s' % order_by query += ' LIMIT %d, %d' % (offset, page_size) + query = self._substitute(query, model_class) return Page( objects=list(self.select(query, model_class, settings)), number_of_objects=count, @@ -100,7 +104,8 @@ class Database(object): def _get_applied_migrations(self, migrations_package_name): from migrations import MigrationHistory self.create_table(MigrationHistory) - query = "SELECT module_name from `%s`.`%s` WHERE package_name = '%s'" % (self.db_name, MigrationHistory.table_name(), migrations_package_name) + query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name + query = self._substitute(query, MigrationHistory) return set(obj.module_name for obj in self.select(query)) def _send(self, data, settings=None, stream=False): @@ -117,3 +122,14 @@ class Database(object): if self.password: params['password'] = password return params + + def _substitute(self, query, model_class=None): + ''' + Replaces $db and $table placeholders in the query. + ''' + if '$' in query: + mapping = dict(db="`%s`" % self.db_name) + if model_class: + mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name()) + query = Template(query).substitute(mapping) + return query diff --git a/tests/test_migrations.py b/tests/test_migrations.py index 627e3f6..1a03034 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -22,7 +22,7 @@ class MigrationsTestCase(unittest.TestCase): self.database.drop_table(MigrationHistory) def tableExists(self, model_class): - query = "EXISTS TABLE `%s`.`%s`" % (self.database.db_name, model_class.table_name()) + query = "EXISTS TABLE $db.`%s`" % model_class.table_name() return next(self.database.select(query)).result == 1 def getTableFields(self, model_class):