Add $table and $db placeholders

This commit is contained in:
Itai Shirav 2016-07-11 16:17:49 +03:00
parent 993705e97b
commit f29b3ea696
3 changed files with 40 additions and 5 deletions

View File

@ -91,6 +91,25 @@ It is possible to select only a subset of the columns, and the rest will receive
for person in db.select("SELECT first_name FROM my_test_db.person WHERE last_name='Smith'", model_class=Person): for person in db.select("SELECT first_name FROM my_test_db.person WHERE last_name='Smith'", model_class=Person):
print person.first_name print person.first_name
SQL Placeholders
****************
There are a couple of special placeholders that you can use inside the SQL to make it easier to write:
``$db`` and ``$table``. The first one is replaced by the database name, and the second is replaced by
the database name plus table name (but is available only when the model is specified).
So instead of this::
db.select("SELECT * FROM my_test_db.person", model_class=Person)
you can use::
db.select("SELECT * FROM $db.person", model_class=Person)
or even::
db.select("SELECT * FROM $table", model_class=Person)
Ad-Hoc Models Ad-Hoc Models
************* *************

View File

@ -5,6 +5,7 @@ from utils import escape, parse_tsv, import_submodules
from math import ceil from math import ceil
import datetime import datetime
import logging import logging
from string import Template
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
@ -41,7 +42,7 @@ class Database(object):
return # model_instances is empty return # model_instances is empty
model_class = first_instance.__class__ model_class = first_instance.__class__
def gen(): def gen():
yield 'INSERT INTO `%s`.`%s` FORMAT TabSeparated\n' % (self.db_name, model_class.table_name()) yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class)
yield first_instance.to_tsv() yield first_instance.to_tsv()
yield '\n' yield '\n'
for instance in i: for instance in i:
@ -50,14 +51,16 @@ class Database(object):
self._send(gen()) self._send(gen())
def count(self, model_class, conditions=None): def count(self, model_class, conditions=None):
query = 'SELECT count() FROM `%s`.`%s`' % (self.db_name, model_class.table_name()) query = 'SELECT count() FROM $table'
if conditions: if conditions:
query += ' WHERE ' + conditions query += ' WHERE ' + conditions
query = self._substitute(query, model_class)
r = self._send(query) r = self._send(query)
return int(r.text) if r.text else 0 return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None): def select(self, query, model_class=None, settings=None):
query += ' FORMAT TabSeparatedWithNamesAndTypes' query += ' FORMAT TabSeparatedWithNamesAndTypes'
query = self._substitute(query, model_class)
r = self._send(query, settings, True) r = self._send(query, settings, True)
lines = r.iter_lines() lines = r.iter_lines()
field_names = parse_tsv(next(lines)) field_names = parse_tsv(next(lines))
@ -70,11 +73,12 @@ class Database(object):
count = self.count(model_class, conditions) count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size))) pages_total = int(ceil(count / float(page_size)))
offset = (page_num - 1) * page_size offset = (page_num - 1) * page_size
query = 'SELECT * FROM `%s`.`%s`' % (self.db_name, model_class.table_name()) query = 'SELECT * FROM $table'
if conditions: if conditions:
query += ' WHERE ' + conditions query += ' WHERE ' + conditions
query += ' ORDER BY %s' % order_by query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size) query += ' LIMIT %d, %d' % (offset, page_size)
query = self._substitute(query, model_class)
return Page( return Page(
objects=list(self.select(query, model_class, settings)), objects=list(self.select(query, model_class, settings)),
number_of_objects=count, number_of_objects=count,
@ -100,7 +104,8 @@ class Database(object):
def _get_applied_migrations(self, migrations_package_name): def _get_applied_migrations(self, migrations_package_name):
from migrations import MigrationHistory from migrations import MigrationHistory
self.create_table(MigrationHistory) self.create_table(MigrationHistory)
query = "SELECT module_name from `%s`.`%s` WHERE package_name = '%s'" % (self.db_name, MigrationHistory.table_name(), migrations_package_name) query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query)) return set(obj.module_name for obj in self.select(query))
def _send(self, data, settings=None, stream=False): def _send(self, data, settings=None, stream=False):
@ -117,3 +122,14 @@ class Database(object):
if self.password: if self.password:
params['password'] = password params['password'] = password
return params return params
def _substitute(self, query, model_class=None):
'''
Replaces $db and $table placeholders in the query.
'''
if '$' in query:
mapping = dict(db="`%s`" % self.db_name)
if model_class:
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
query = Template(query).substitute(mapping)
return query

View File

@ -22,7 +22,7 @@ class MigrationsTestCase(unittest.TestCase):
self.database.drop_table(MigrationHistory) self.database.drop_table(MigrationHistory)
def tableExists(self, model_class): def tableExists(self, model_class):
query = "EXISTS TABLE `%s`.`%s`" % (self.database.db_name, model_class.table_name()) query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
return next(self.database.select(query)).result == 1 return next(self.database.select(query)).result == 1
def getTableFields(self, model_class): def getTableFields(self, model_class):