diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8ad9e9e..4579012 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Change Log ========== +v0.8.1 +------ +- Add support for ReplacingMergeTree (leenr) +- Fix problem with SELECT WITH TOTALS (pilosus) +- Update serialization format of DateTimeField to 10 digits, zero padded (nikepan) +- Greatly improve performance when inserting large strings (credit to M1hacka for identifying the problem) +- Reduce memory footprint of Database.insert() + v0.8.0 ------ - Always keep datetime fields in UTC internally, and convert server timezone to UTC when parsing query results diff --git a/README.rst b/README.rst index 9db0ab1..6c2c6d9 100644 --- a/README.rst +++ b/README.rst @@ -380,6 +380,10 @@ For a ``SummingMergeTree`` you can optionally specify the summing columns:: engine = engines.SummingMergeTree('EventDate', ('OrderID', 'EventDate', 'BannerID'), summing_cols=('Shows', 'Clicks', 'Cost')) +For a ``ReplacingMergeTree`` you can optionally specify the version column:: + + engine = engines.ReplacingMergeTree('EventDate', ('OrderID', 'EventDate', 'BannerID'), ver_col='Version') + A ``Buffer`` engine is available for BufferModels. (See below how to use BufferModel). You can specify following parameters:: engine = engines.Buffer(Person) # you need to initialize engine with main Model. Other default parameters will be used @@ -425,7 +429,6 @@ After cloning the project, run the following commands:: To run the tests, ensure that the ClickHouse server is running on http://localhost:8123/ (this is the default), and run:: bin/nosetests -======= To see test coverage information run:: diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py index 2be8534..7226f82 100644 --- a/src/infi/clickhouse_orm/database.py +++ b/src/infi/clickhouse_orm/database.py @@ -50,6 +50,7 @@ class Database(object): def insert(self, model_instances, batch_size=1000): from six import next + from cStringIO import StringIO i = iter(model_instances) try: first_instance = next(i) @@ -61,22 +62,27 @@ class Database(object): raise DatabaseException("You can't insert into read only table") def gen(): - yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8') + buf = StringIO() + buf.write(self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')) first_instance.set_database(self) - yield (first_instance.to_tsv(include_readonly=False) + '\n').encode('utf-8') + buf.write(first_instance.to_tsv(include_readonly=False).encode('utf-8')) + buf.write('\n') # Collect lines in batches of batch_size - batch = [] + lines = 2 for instance in i: instance.set_database(self) - batch.append(instance.to_tsv(include_readonly=False)) - if len(batch) >= batch_size: + buf.write(instance.to_tsv(include_readonly=False).encode('utf-8')) + buf.write('\n') + lines += 1 + if lines >= batch_size: # Return the current batch of lines - yield ('\n'.join(batch) + '\n').encode('utf-8') + yield buf.getvalue() # Start a new batch - batch = [] + buf = StringIO() + lines = 0 # Return any remaining lines in partial batch - if batch: - yield ('\n'.join(batch) + '\n').encode('utf-8') + if lines: + yield buf.getvalue() self._send(gen()) def count(self, model_class, conditions=None): @@ -96,7 +102,9 @@ class Database(object): field_types = parse_tsv(next(lines)) model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types)) for line in lines: - yield model_class.from_tsv(line, field_names, self.server_timezone, self) + # skip blank line left by WITH TOTALS modifier + if line: + yield model_class.from_tsv(line, field_names, self.server_timezone, self) def raw(self, query, settings=None, stream=False): """ diff --git a/src/infi/clickhouse_orm/engines.py b/src/infi/clickhouse_orm/engines.py index 7c9a94c..c26b451 100644 --- a/src/infi/clickhouse_orm/engines.py +++ b/src/infi/clickhouse_orm/engines.py @@ -9,6 +9,7 @@ class MergeTree(Engine): def __init__(self, date_col, key_cols, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None): + assert type(key_cols) in (list, tuple), 'key_cols must be a list or tuple' self.date_col = date_col self.key_cols = key_cols self.sampling_expr = sampling_expr @@ -54,6 +55,7 @@ class SummingMergeTree(MergeTree): def __init__(self, date_col, key_cols, summing_cols=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None): super(SummingMergeTree, self).__init__(date_col, key_cols, sampling_expr, index_granularity, replica_table_path, replica_name) + assert type is None or type(summing_cols) in (list, tuple), 'summing_cols must be a list or tuple' self.summing_cols = summing_cols def _build_sql_params(self): @@ -63,6 +65,20 @@ class SummingMergeTree(MergeTree): return params +class ReplacingMergeTree(MergeTree): + + def __init__(self, date_col, key_cols, ver_col=None, sampling_expr=None, + index_granularity=8192, replica_table_path=None, replica_name=None): + super(ReplacingMergeTree, self).__init__(date_col, key_cols, sampling_expr, index_granularity, replica_table_path, replica_name) + self.ver_col = ver_col + + def _build_sql_params(self): + params = super(ReplacingMergeTree, self)._build_sql_params() + if self.ver_col: + params.append(self.ver_col) + return params + + class Buffer(Engine): """Here we define Buffer engine Read more here https://clickhouse.yandex/reference_en.html#Buffer diff --git a/src/infi/clickhouse_orm/fields.py b/src/infi/clickhouse_orm/fields.py index 3fa49ec..f7d3992 100644 --- a/src/infi/clickhouse_orm/fields.py +++ b/src/infi/clickhouse_orm/fields.py @@ -132,12 +132,18 @@ class DateTimeField(Field): if isinstance(value, string_types): if value == '0000-00-00 00:00:00': return self.class_default + if len(value) == 10: + try: + value = int(value) + return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc) + except ValueError: + pass dt = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') return timezone_in_use.localize(dt).astimezone(pytz.utc) raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True): - return escape(timegm(value.utctimetuple()), quote) + return escape('%010d' % timegm(value.utctimetuple()), quote) class BaseIntField(Field): @@ -148,6 +154,11 @@ class BaseIntField(Field): except: raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + def to_db_string(self, value, quote=True): + # There's no need to call escape since numbers do not contain + # special characters, and never need quoting + return text_type(value) + def validate(self, value): self._range_check(value, self.min_value, self.max_value) @@ -216,6 +227,11 @@ class BaseFloatField(Field): except: raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + def to_db_string(self, value, quote=True): + # There's no need to call escape since numbers do not contain + # special characters, and never need quoting + return text_type(value) + class Float32Field(BaseFloatField): diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index a6163f7..44437a9 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -27,6 +27,7 @@ class ModelBase(type): fields = base_fields + [item for item in attrs.items() if isinstance(item[1], Field)] fields.sort(key=lambda item: item[1].creation_counter) setattr(new_cls, '_fields', fields) + setattr(new_cls, '_writable_fields', [f for f in fields if not f[1].readonly]) return new_cls @classmethod @@ -186,7 +187,7 @@ class Model(with_metaclass(ModelBase)): :param bool include_readonly: If False, returns only fields, that can be inserted into database ''' data = self.__dict__ - fields = self._fields if include_readonly else [f for f in self._fields if not f[1].readonly] + fields = self._fields if include_readonly else self._writable_fields return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields) def to_dict(self, include_readonly=True, field_names=None): @@ -195,7 +196,7 @@ class Model(with_metaclass(ModelBase)): :param bool include_readonly: If False, returns only fields, that can be inserted into database :param field_names: An iterable of field names to return ''' - fields = self._fields if include_readonly else [f for f in self._fields if not f[1].readonly] + fields = self._fields if include_readonly else self._writable_fields if field_names is not None: fields = [f for f in fields if f[0] in field_names] diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index c24a93d..83d11e0 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -17,15 +17,18 @@ SPECIAL_CHARS = { SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]") + def escape(value, quote=True): ''' If the value is a string, escapes any special characters and optionally surrounds it with single quotes. If the value is not a string (e.g. a number), converts it to one. ''' + def escape_one(match): + return SPECIAL_CHARS[match.group(0)] + if isinstance(value, string_types): - if SPECIAL_CHARS_REGEX.search(value): - value = "".join(SPECIAL_CHARS.get(c, c) for c in value) + value = SPECIAL_CHARS_REGEX.sub(escape_one, value) if quote: value = "'" + value + "'" return text_type(value) @@ -38,7 +41,7 @@ def unescape(value): def parse_tsv(line): if PY3 and isinstance(line, binary_type): line = line.decode() - if line[-1] == '\n': + if line and line[-1] == '\n': line = line[:-1] return [unescape(value) for value in line.split('\t')] diff --git a/tests/test_database.py b/tests/test_database.py index 62b340b..1897d8f 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -89,6 +89,15 @@ class DatabaseTestCase(unittest.TestCase): self.assertEqual(results[0].get_database(), self.database) self.assertEqual(results[1].get_database(), self.database) + def test_select_with_totals(self): + self._insert_and_check(self._sample_data(), len(data)) + query = "SELECT last_name, sum(height) as height FROM `test-db`.person GROUP BY last_name WITH TOTALS" + results = list(self.database.select(query)) + total = sum(r.height for r in results[:-1]) + # Last line has an empty last name, and total of all heights + self.assertFalse(results[-1].last_name) + self.assertEquals(total, results[-1].height) + def test_pagination(self): self._insert_and_check(self._sample_data(), len(data)) # Try different page sizes diff --git a/tests/test_engines.py b/tests/test_engines.py new file mode 100644 index 0000000..d3d8865 --- /dev/null +++ b/tests/test_engines.py @@ -0,0 +1,64 @@ +import unittest + +from infi.clickhouse_orm.database import Database, DatabaseException +from infi.clickhouse_orm.models import Model +from infi.clickhouse_orm.fields import * +from infi.clickhouse_orm.engines import * + +import logging +logging.getLogger("requests").setLevel(logging.WARNING) + + +class EnginesTestCase(unittest.TestCase): + + def setUp(self): + self.database = Database('test-db') + + def tearDown(self): + self.database.drop_database() + + def _create_and_insert(self, model_class): + self.database.create_table(model_class) + self.database.insert([ + model_class(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1) + ]) + + def test_merge_tree(self): + class TestModel(SampleModel): + engine = MergeTree('date', ('date', 'event_id', 'event_group')) + self._create_and_insert(TestModel) + + def test_merge_tree_with_sampling(self): + class TestModel(SampleModel): + engine = MergeTree('date', ('date', 'event_id', 'event_group'), sampling_expr='intHash32(event_id)') + self._create_and_insert(TestModel) + + def test_merge_tree_with_granularity(self): + class TestModel(SampleModel): + engine = MergeTree('date', ('date', 'event_id', 'event_group'), index_granularity=4096) + self._create_and_insert(TestModel) + + def test_collapsing_merge_tree(self): + class TestModel(SampleModel): + engine = CollapsingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_version') + self._create_and_insert(TestModel) + + def test_summing_merge_tree(self): + class TestModel(SampleModel): + engine = SummingMergeTree('date', ('date', 'event_group'), ('event_count',)) + self._create_and_insert(TestModel) + + def test_replacing_merge_tree(self): + class TestModel(SampleModel): + engine = ReplacingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_uversion') + self._create_and_insert(TestModel) + + +class SampleModel(Model): + + date = DateField() + event_id = UInt32Field() + event_group = UInt32Field() + event_count = UInt16Field() + event_version = Int8Field() + event_uversion = UInt8Field(materialized='abs(event_version)') diff --git a/tests/test_simple_fields.py b/tests/test_simple_fields.py index 2d1f2ab..645d9ed 100644 --- a/tests/test_simple_fields.py +++ b/tests/test_simple_fields.py @@ -6,32 +6,17 @@ import pytz class SimpleFieldsTest(unittest.TestCase): - def test_date_field(self): - f = DateField() - # Valid values - for value in (date(1970, 1, 1), datetime(1970, 1, 1), '1970-01-01', '0000-00-00', 0): - self.assertEquals(f.to_python(value, pytz.utc), date(1970, 1, 1)) - # Invalid values - for value in ('nope', '21/7/1999', 0.5): - with self.assertRaises(ValueError): - f.to_python(value, pytz.utc) - # Range check - for value in (date(1900, 1, 1), date(2900, 1, 1)): - with self.assertRaises(ValueError): - f.validate(value) - def test_datetime_field(self): f = DateTimeField() epoch = datetime(1970, 1, 1, tzinfo=pytz.utc) # Valid values for value in (date(1970, 1, 1), datetime(1970, 1, 1), epoch, epoch.astimezone(pytz.timezone('US/Eastern')), epoch.astimezone(pytz.timezone('Asia/Jerusalem')), - '1970-01-01 00:00:00', '0000-00-00 00:00:00', 0): + '1970-01-01 00:00:00', '1970-01-17 00:00:17', '0000-00-00 00:00:00', 0): dt = f.to_python(value, pytz.utc) self.assertEquals(dt.tzinfo, pytz.utc) - self.assertEquals(dt, epoch) # Verify that conversion to and from db string does not change value - dt2 = f.to_python(int(f.to_db_string(dt)), pytz.utc) + dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc) self.assertEquals(dt, dt2) # Invalid values for value in ('nope', '21/7/1999', 0.5): @@ -52,6 +37,10 @@ class SimpleFieldsTest(unittest.TestCase): for value in ('nope', '21/7/1999', 0.5): with self.assertRaises(ValueError): f.to_python(value, pytz.utc) + # Range check + for value in (date(1900, 1, 1), date(2900, 1, 1)): + with self.assertRaises(ValueError): + f.validate(value) def test_date_field_timezone(self): # Verify that conversion of timezone-aware datetime is correct