Add $table and $db placeholders

This commit is contained in:
Itai Shirav 2016-07-11 16:17:49 +03:00
parent 993705e97b
commit f29b3ea696
3 changed files with 40 additions and 5 deletions

View File

@ -91,6 +91,25 @@ It is possible to select only a subset of the columns, and the rest will receive
for person in db.select("SELECT first_name FROM my_test_db.person WHERE last_name='Smith'", model_class=Person):
print person.first_name
SQL Placeholders
****************
There are a couple of special placeholders that you can use inside the SQL to make it easier to write:
``$db`` and ``$table``. The first one is replaced by the database name, and the second is replaced by
the database name plus table name (but is available only when the model is specified).
So instead of this::
db.select("SELECT * FROM my_test_db.person", model_class=Person)
you can use::
db.select("SELECT * FROM $db.person", model_class=Person)
or even::
db.select("SELECT * FROM $table", model_class=Person)
Ad-Hoc Models
*************

View File

@ -5,6 +5,7 @@ from utils import escape, parse_tsv, import_submodules
from math import ceil
import datetime
import logging
from string import Template
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
@ -41,7 +42,7 @@ class Database(object):
return # model_instances is empty
model_class = first_instance.__class__
def gen():
yield 'INSERT INTO `%s`.`%s` FORMAT TabSeparated\n' % (self.db_name, model_class.table_name())
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class)
yield first_instance.to_tsv()
yield '\n'
for instance in i:
@ -50,14 +51,16 @@ class Database(object):
self._send(gen())
def count(self, model_class, conditions=None):
query = 'SELECT count() FROM `%s`.`%s`' % (self.db_name, model_class.table_name())
query = 'SELECT count() FROM $table'
if conditions:
query += ' WHERE ' + conditions
query = self._substitute(query, model_class)
r = self._send(query)
return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None):
query += ' FORMAT TabSeparatedWithNamesAndTypes'
query = self._substitute(query, model_class)
r = self._send(query, settings, True)
lines = r.iter_lines()
field_names = parse_tsv(next(lines))
@ -70,11 +73,12 @@ class Database(object):
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
offset = (page_num - 1) * page_size
query = 'SELECT * FROM `%s`.`%s`' % (self.db_name, model_class.table_name())
query = 'SELECT * FROM $table'
if conditions:
query += ' WHERE ' + conditions
query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size)
query = self._substitute(query, model_class)
return Page(
objects=list(self.select(query, model_class, settings)),
number_of_objects=count,
@ -100,7 +104,8 @@ class Database(object):
def _get_applied_migrations(self, migrations_package_name):
from migrations import MigrationHistory
self.create_table(MigrationHistory)
query = "SELECT module_name from `%s`.`%s` WHERE package_name = '%s'" % (self.db_name, MigrationHistory.table_name(), migrations_package_name)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query))
def _send(self, data, settings=None, stream=False):
@ -117,3 +122,14 @@ class Database(object):
if self.password:
params['password'] = password
return params
def _substitute(self, query, model_class=None):
'''
Replaces $db and $table placeholders in the query.
'''
if '$' in query:
mapping = dict(db="`%s`" % self.db_name)
if model_class:
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
query = Template(query).substitute(mapping)
return query

View File

@ -22,7 +22,7 @@ class MigrationsTestCase(unittest.TestCase):
self.database.drop_table(MigrationHistory)
def tableExists(self, model_class):
query = "EXISTS TABLE `%s`.`%s`" % (self.database.db_name, model_class.table_name())
query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
return next(self.database.select(query)).result == 1
def getTableFields(self, model_class):