Performance improvements when inserting

escape - check first if there are any special chars before replacing
send lines in batches
Use list comprehension in to_tsv
This commit is contained in:
Itai Shirav 2016-09-15 11:32:07 +03:00
parent 6221fb4143
commit 66f8e8a4ae
3 changed files with 21 additions and 14 deletions

View File

@ -35,7 +35,7 @@ class Database(object):
def drop_database(self): def drop_database(self):
self._send('DROP DATABASE `%s`' % self.db_name) self._send('DROP DATABASE `%s`' % self.db_name)
def insert(self, model_instances): def insert(self, model_instances, batch_size=1000):
from six import next from six import next
i = iter(model_instances) i = iter(model_instances)
try: try:
@ -45,11 +45,19 @@ class Database(object):
model_class = first_instance.__class__ model_class = first_instance.__class__
def gen(): def gen():
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8') yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
yield first_instance.to_tsv().encode('utf-8') yield (first_instance.to_tsv() + '\n').encode('utf-8')
yield '\n'.encode('utf-8') # Collect lines in batches of batch_size
batch = []
for instance in i: for instance in i:
yield instance.to_tsv().encode('utf-8') batch.append(instance.to_tsv())
yield '\n'.encode('utf-8') if len(batch) >= batch_size:
# Return the current batch of lines
yield ('\n'.join(batch) + '\n').encode('utf-8')
# Start a new batch
batch = []
# Return any remaining lines in partial batch
if batch:
yield ('\n'.join(batch) + '\n').encode('utf-8')
self._send(gen()) self._send(gen())
def count(self, model_class, conditions=None): def count(self, model_class, conditions=None):

View File

@ -154,10 +154,5 @@ class Model(with_metaclass(ModelBase)):
''' '''
Returns the instance's column values as a tab-separated line. A newline is not included. Returns the instance's column values as a tab-separated line. A newline is not included.
''' '''
parts = [] data = self.__dict__
for name, field in self._fields: return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in self._fields)
value = field.to_db_string(getattr(self, name), quote=False)
parts.append(value)
tsv = '\t'.join(parts)
logger.debug(tsv)
return tsv

View File

@ -14,6 +14,8 @@ SPECIAL_CHARS = {
"'" : "\\'" "'" : "\\'"
} }
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
def escape(value, quote=True): def escape(value, quote=True):
''' '''
@ -22,8 +24,10 @@ def escape(value, quote=True):
converts it to one. converts it to one.
''' '''
if isinstance(value, string_types): if isinstance(value, string_types):
chars = (SPECIAL_CHARS.get(c, c) for c in value) if SPECIAL_CHARS_REGEX.search(value):
value = "'" + "".join(chars) + "'" if quote else "".join(chars) value = "".join(SPECIAL_CHARS.get(c, c) for c in value)
if quote:
value = "'" + value + "'"
return text_type(value) return text_type(value)