diff --git a/README.rst b/README.rst index 7ca6b06..27f2d07 100644 --- a/README.rst +++ b/README.rst @@ -191,8 +191,50 @@ UInt32Field UInt32 int Range 0 to 4294967295 UInt64Field UInt64 int/long Range 0 to 18446744073709551615 Float32Field Float32 float Float64Field Float64 float +Enum8Field Enum8 Enum See below +Enum16Field Enum16 Enum See below +ArrayField Array list See below ============= ======== ================= =================================================== +Working with enum fields +************************ + +``Enum8Field`` and ``Enum16Field`` provide support for working with ClickHouse enum columns. They accept +strings or integers as values, and convert them to the matching Pythonic Enum member. + +Python 3.4 and higher supports Enums natively. When using previous Python versions you +need to install the `enum34` library. + +Example of a model with an enum field:: + + Gender = Enum('Gender', 'male female unspecified') + + class Person(models.Model): + + first_name = fields.StringField() + last_name = fields.StringField() + birthday = fields.DateField() + gender = fields.Enum32Field(Gender) + + engine = engines.MergeTree('birthday', ('first_name', 'last_name', 'birthday')) + + suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female) + +Working with array fields +************************* + +You can create array fields containing any data type, for example:: + + class SensorData(models.Model): + + date = fields.DateField() + temperatures = fields.ArrayField(fields.Float32Field) + humidity_levels = fields.ArrayField(fields.UInt8Field) + + engine = engines.MergeTree('date', ('date',)) + + data = SensorData(date=date.today(), temperatures=[25.5, 31.2, 28.7], humidity_levels=[41, 39, 66]) + Table Engines ------------- diff --git a/buildout.cfg b/buildout.cfg index ddc5939..ad10a3e 100644 --- a/buildout.cfg +++ b/buildout.cfg @@ -45,6 +45,8 @@ recipe = infi.recipe.console_scripts eggs = ${project:name} ipython nose + coverage + enum34 infi.unittest infi.traceback zc.buildout diff --git a/src/infi/clickhouse_orm/fields.py b/src/infi/clickhouse_orm/fields.py index 390c303..0406b01 100644 --- a/src/infi/clickhouse_orm/fields.py +++ b/src/infi/clickhouse_orm/fields.py @@ -3,6 +3,8 @@ import datetime import pytz import time +from .utils import escape, parse_array + class Field(object): @@ -13,7 +15,7 @@ class Field(object): def __init__(self, default=None): self.creation_counter = Field.creation_counter Field.creation_counter += 1 - self.default = default or self.class_default + self.default = self.class_default if default is None else default def to_python(self, value): ''' @@ -36,11 +38,22 @@ class Field(object): if value < min_value or value > max_value: raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) - def get_db_prep_value(self, value): + def to_db_string(self, value, quote=True): ''' - Returns the field's value prepared for interacting with the database. + Returns the field's value prepared for writing to the database. + When quote is true, strings are surrounded by single quotes. ''' - return value + return escape(value, quote) + + def get_sql(self, with_default=True): + ''' + Returns an SQL expression describing the field (e.g. for CREATE TABLE). + ''' + if with_default: + default = self.to_db_string(self.default) + return '%s DEFAULT %s' % (self.db_type, default) + else: + return self.db_type class StringField(Field): @@ -77,8 +90,8 @@ class DateField(Field): def validate(self, value): self._range_check(value, DateField.min_value, DateField.max_value) - def get_db_prep_value(self, value): - return value.isoformat() + def to_db_string(self, value, quote=True): + return escape(value.isoformat(), quote) class DateTimeField(Field): @@ -97,8 +110,8 @@ class DateTimeField(Field): return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) - def get_db_prep_value(self, value): - return int(time.mktime(value.timetuple())) + def to_db_string(self, value, quote=True): + return escape(int(time.mktime(value.timetuple())), quote) class BaseIntField(Field): @@ -187,3 +200,94 @@ class Float64Field(BaseFloatField): db_type = 'Float64' + +class BaseEnumField(Field): + + def __init__(self, enum_cls, default=None): + self.enum_cls = enum_cls + if default is None: + default = list(enum_cls)[0] + super(BaseEnumField, self).__init__(default) + + def to_python(self, value): + if isinstance(value, self.enum_cls): + return value + try: + if isinstance(value, text_type): + return self.enum_cls[value] + if isinstance(value, binary_type): + return self.enum_cls[value.decode('UTF-8')] + if isinstance(value, int): + return self.enum_cls(value) + except (KeyError, ValueError): + pass + raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value)) + + def to_db_string(self, value, quote=True): + return escape(value.name, quote) + + def get_sql(self, with_default=True): + values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls] + sql = '%s(%s)' % (self.db_type, ' ,'.join(values)) + if with_default: + default = self.to_db_string(self.default) + sql = '%s DEFAULT %s' % (sql, default) + return sql + + @classmethod + def create_ad_hoc_field(cls, db_type): + ''' + Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)" + this method returns a matching enum field. + ''' + import re + try: + Enum # exists in Python 3.4+ + except NameError: + from enum import Enum # use the enum34 library instead + members = {} + for match in re.finditer("'(\w+)' = (\d+)", db_type): + members[match.group(1)] = int(match.group(2)) + enum_cls = Enum('AdHocEnum', members) + field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field + return field_class(enum_cls) + + +class Enum8Field(BaseEnumField): + + db_type = 'Enum8' + + +class Enum16Field(BaseEnumField): + + db_type = 'Enum16' + + +class ArrayField(Field): + + class_default = [] + + def __init__(self, inner_field, default=None): + self.inner_field = inner_field + super(ArrayField, self).__init__(default) + + def to_python(self, value): + if isinstance(value, text_type): + value = parse_array(value) + elif isinstance(value, binary_type): + value = parse_array(value.decode('UTF-8')) + elif not isinstance(value, (list, tuple)): + raise ValueError('ArrayField expects list or tuple, not %s' % type(value)) + return [self.inner_field.to_python(v) for v in value] + + def validate(self, value): + for v in value: + self.inner_field.validate(v) + + def to_db_string(self, value, quote=True): + array = [self.inner_field.to_db_string(v, quote=True) for v in value] + return '[' + ', '.join(array) + ']' + + def get_sql(self, with_default=True): + from .utils import escape + return 'Array(%s)' % self.inner_field.get_sql(with_default=False) diff --git a/src/infi/clickhouse_orm/migrations.py b/src/infi/clickhouse_orm/migrations.py index 203943f..8167a20 100644 --- a/src/infi/clickhouse_orm/migrations.py +++ b/src/infi/clickhouse_orm/migrations.py @@ -68,12 +68,11 @@ class AlterTable(Operation): if name not in table_fields: logger.info(' Add column %s', name) assert prev_name, 'Cannot add a column to the beginning of the table' - default = field.get_db_prep_value(field.default) - cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name) + cmd = 'ADD COLUMN %s %s AFTER %s' % (name, field.get_sql(), prev_name) self._alter_table(database, cmd) prev_name = name # Identify fields whose type was changed - model_fields = [(name, field.db_type) for name, field in self.model_class._fields] + model_fields = [(name, field.get_sql(with_default=False)) for name, field in self.model_class._fields] for model_field, table_field in zip(model_fields, self._get_table_fields(database)): assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement' if model_field[1] != table_field[1]: diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index c6a0ff8..a077366 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -4,6 +4,9 @@ from .fields import Field from six import with_metaclass +from logging import getLogger +logger = getLogger('clickhouse_orm') + class ModelBase(type): ''' @@ -28,7 +31,6 @@ class ModelBase(type): @classmethod def create_ad_hoc_model(cls, fields): # fields is a list of tuples (name, db_type) - import infi.clickhouse_orm.fields as orm_fields # Check if model exists in cache fields = list(fields) cache_key = str(fields) @@ -37,15 +39,28 @@ class ModelBase(type): # Create an ad hoc model class attrs = {} for name, db_type in fields: - field_class = db_type + 'Field' - if not hasattr(orm_fields, field_class): - raise NotImplementedError('No field class for %s' % db_type) - attrs[name] = getattr(orm_fields, field_class)() + attrs[name] = cls.create_ad_hoc_field(db_type) model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs) # Add the model class to the cache cls.ad_hoc_model_cache[cache_key] = model_class return model_class + @classmethod + def create_ad_hoc_field(cls, db_type): + import infi.clickhouse_orm.fields as orm_fields + # Enums + if db_type.startswith('Enum'): + return orm_fields.BaseEnumField.create_ad_hoc_field(db_type) + # Arrays + if db_type.startswith('Array'): + inner_field = cls.create_ad_hoc_field(db_type[6 : -1]) + return orm_fields.ArrayField(inner_field) + # Simple fields + name = db_type + 'Field' + if not hasattr(orm_fields, name): + raise NotImplementedError('No field class for %s' % db_type) + return getattr(orm_fields, name)() + class Model(with_metaclass(ModelBase)): ''' @@ -107,8 +122,7 @@ class Model(with_metaclass(ModelBase)): parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())] cols = [] for name, field in cls._fields: - default = field.get_db_prep_value(field.default) - cols.append(' %s %s DEFAULT %s' % (name, field.db_type, escape(default))) + cols.append(' %s %s' % (name, field.get_sql())) parts.append(',\n'.join(cols)) parts.append(')') parts.append('ENGINE = ' + cls.engine.create_table_sql()) @@ -142,6 +156,8 @@ class Model(with_metaclass(ModelBase)): ''' parts = [] for name, field in self._fields: - value = field.get_db_prep_value(field.to_python(getattr(self, name))) - parts.append(escape(value, quote=False)) - return '\t'.join(parts) + value = field.to_db_string(getattr(self, name), quote=False) + parts.append(value) + tsv = '\t'.join(parts) + logger.debug(tsv) + return tsv diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index f52a7ae..bcbaf90 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -1,5 +1,6 @@ from six import string_types, binary_type, text_type, PY3 import codecs +import re SPECIAL_CHARS = { @@ -15,6 +16,11 @@ SPECIAL_CHARS = { def escape(value, quote=True): + ''' + If the value is a string, escapes any special characters and optionally + surrounds it with single quotes. If the value is not a string (e.g. a number), + converts it to one. + ''' if isinstance(value, string_types): chars = (SPECIAL_CHARS.get(c, c) for c in value) value = "'" + "".join(chars) + "'" if quote else "".join(chars) @@ -33,6 +39,40 @@ def parse_tsv(line): return [unescape(value) for value in line.split('\t')] +def parse_array(array_string): + """ + Parse an array string as returned by clickhouse. For example: + "['hello', 'world']" ==> ["hello", "world"] + "[1,2,3]" ==> [1, 2, 3] + """ + # Sanity check + if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']': + raise ValueError('Invalid array string: "%s"' % array_string) + # Drop opening brace + array_string = array_string[1:] + # Go over the string, lopping off each value at the beginning until nothing is left + values = [] + while True: + if array_string == ']': + # End of array + return values + elif array_string[0] in ', ': + # In between values + array_string = array_string[1:] + elif array_string[0] == "'": + # Start of quoted value, find its end + match = re.search(r"[^\\]'", array_string) + if match is None: + raise ValueError('Missing closing quote: "%s"' % array_string) + values.append(array_string[1 : match.start() + 1]) + array_string = array_string[match.end():] + else: + # Start of non-quoted value, find its end + match = re.search(r",|\]", array_string) + values.append(array_string[1 : match.start() + 1]) + array_string = array_string[match.end():] + + def import_submodules(package_name): """ Import all submodules of a module. diff --git a/tests/sample_migrations/0006.py b/tests/sample_migrations/0006.py new file mode 100644 index 0000000..fefb325 --- /dev/null +++ b/tests/sample_migrations/0006.py @@ -0,0 +1,6 @@ +from infi.clickhouse_orm import migrations +from ..test_migrations import * + +operations = [ + migrations.CreateTable(EnumModel1) +] \ No newline at end of file diff --git a/tests/sample_migrations/0007.py b/tests/sample_migrations/0007.py new file mode 100644 index 0000000..da23040 --- /dev/null +++ b/tests/sample_migrations/0007.py @@ -0,0 +1,6 @@ +from infi.clickhouse_orm import migrations +from ..test_migrations import * + +operations = [ + migrations.AlterTable(EnumModel2) +] \ No newline at end of file diff --git a/tests/test_array_fields.py b/tests/test_array_fields.py new file mode 100644 index 0000000..ef3c3c1 --- /dev/null +++ b/tests/test_array_fields.py @@ -0,0 +1,58 @@ +import unittest +from datetime import date + +from infi.clickhouse_orm.database import Database +from infi.clickhouse_orm.models import Model +from infi.clickhouse_orm.fields import * +from infi.clickhouse_orm.engines import * + + +class ArrayFieldsTest(unittest.TestCase): + + def setUp(self): + self.database = Database('test-db') + self.database.create_table(ModelWithArrays) + + def tearDown(self): + self.database.drop_database() + + def test_insert_and_select(self): + instance = ModelWithArrays( + date_field='2016-08-30', + arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'], + arr_date=['2010-01-01'] + ) + self.database.insert([instance]) + query = 'SELECT * from $db.modelwitharrays ORDER BY date_field' + for model_cls in (ModelWithArrays, None): + results = list(self.database.select(query, model_cls)) + self.assertEquals(len(results), 1) + self.assertEquals(results[0].arr_str, instance.arr_str) + self.assertEquals(results[0].arr_int, instance.arr_int) + self.assertEquals(results[0].arr_date, instance.arr_date) + + def test_conversion(self): + instance = ModelWithArrays( + arr_int=('1', '2', '3'), + arr_date=['2010-01-01'] + ) + self.assertEquals(instance.arr_str, []) + self.assertEquals(instance.arr_int, [1, 2, 3]) + self.assertEquals(instance.arr_date, [date(2010, 1, 1)]) + + def test_assignment_error(self): + instance = ModelWithArrays() + for value in (7, 'x', [date.today()], ['aaa'], [None]): + with self.assertRaises(ValueError): + instance.arr_int = value + + +class ModelWithArrays(Model): + + date_field = DateField() + arr_str = ArrayField(StringField()) + arr_int = ArrayField(Int32Field()) + arr_date = ArrayField(DateField()) + + engine = MergeTree('date_field', ('date_field',)) + diff --git a/tests/test_enum_fields.py b/tests/test_enum_fields.py new file mode 100644 index 0000000..78df6d3 --- /dev/null +++ b/tests/test_enum_fields.py @@ -0,0 +1,87 @@ +import unittest + +from infi.clickhouse_orm.database import Database +from infi.clickhouse_orm.models import Model +from infi.clickhouse_orm.fields import * +from infi.clickhouse_orm.engines import * + +try: + Enum # exists in Python 3.4+ +except NameError: + from enum import Enum # use the enum34 library instead + + +class EnumFieldsTest(unittest.TestCase): + + def setUp(self): + self.database = Database('test-db') + self.database.create_table(ModelWithEnum) + self.database.create_table(ModelWithEnumArray) + + def tearDown(self): + self.database.drop_database() + + def test_insert_and_select(self): + self.database.insert([ + ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple), + ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange) + ]) + query = 'SELECT * from $table ORDER BY date_field' + results = list(self.database.select(query, ModelWithEnum)) + self.assertEquals(len(results), 2) + self.assertEquals(results[0].enum_field, Fruit.apple) + self.assertEquals(results[1].enum_field, Fruit.orange) + + def test_ad_hoc_model(self): + self.database.insert([ + ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple), + ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange) + ]) + query = 'SELECT * from $db.modelwithenum ORDER BY date_field' + results = list(self.database.select(query)) + self.assertEquals(len(results), 2) + self.assertEquals(results[0].enum_field.name, Fruit.apple.name) + self.assertEquals(results[0].enum_field.value, Fruit.apple.value) + self.assertEquals(results[1].enum_field.name, Fruit.orange.name) + self.assertEquals(results[1].enum_field.value, Fruit.orange.value) + + def test_conversion(self): + self.assertEquals(ModelWithEnum(enum_field=3).enum_field, Fruit.orange) + self.assertEquals(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple) + self.assertEquals(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana) + + def test_assignment_error(self): + for value in (0, 17, 'pear', '', None, 99.9): + with self.assertRaises(ValueError): + ModelWithEnum(enum_field=value) + + def test_default_value(self): + instance = ModelWithEnum() + self.assertEquals(instance.enum_field, Fruit.apple) + + def test_enum_array(self): + instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange]) + self.database.insert([instance]) + query = 'SELECT * from $table ORDER BY date_field' + results = list(self.database.select(query, ModelWithEnumArray)) + self.assertEquals(len(results), 1) + self.assertEquals(results[0].enum_array, instance.enum_array) + + +Fruit = Enum('Fruit', u'apple banana orange') + + +class ModelWithEnum(Model): + + date_field = DateField() + enum_field = Enum8Field(Fruit) + + engine = MergeTree('date_field', ('date_field',)) + + +class ModelWithEnumArray(Model): + + date_field = DateField() + enum_array = ArrayField(Enum16Field(Fruit)) + + engine = MergeTree('date_field', ('date_field',)) diff --git a/tests/test_migrations.py b/tests/test_migrations.py index 1a03034..39bcb55 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -10,6 +10,11 @@ from infi.clickhouse_orm.migrations import MigrationHistory import sys, os sys.path.append(os.path.dirname(__file__)) +try: + Enum # exists in Python 3.4+ +except NameError: + from enum import Enum # use the enum34 library instead + import logging logging.basicConfig(level=logging.DEBUG, format='%(message)s') logging.getLogger("requests").setLevel(logging.WARNING) @@ -21,6 +26,9 @@ class MigrationsTestCase(unittest.TestCase): self.database = Database('test-db') self.database.drop_table(MigrationHistory) + def tearDown(self): + self.database.drop_database() + def tableExists(self, model_class): query = "EXISTS TABLE $db.`%s`" % model_class.table_name() return next(self.database.select(query)).result == 1 @@ -30,18 +38,28 @@ class MigrationsTestCase(unittest.TestCase): return [(row.name, row.type) for row in self.database.select(query)] def test_migrations(self): + # Creation and deletion of table self.database.migrate('tests.sample_migrations', 1) self.assertTrue(self.tableExists(Model1)) self.database.migrate('tests.sample_migrations', 2) self.assertFalse(self.tableExists(Model1)) self.database.migrate('tests.sample_migrations', 3) self.assertTrue(self.tableExists(Model1)) + # Adding, removing and altering simple fields self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')]) self.database.migrate('tests.sample_migrations', 4) self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')]) self.database.migrate('tests.sample_migrations', 5) self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')]) - + # Altering enum fields + self.database.migrate('tests.sample_migrations', 6) + self.assertTrue(self.tableExists(EnumModel1)) + self.assertEquals(self.getTableFields(EnumModel1), + [('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")]) + self.database.migrate('tests.sample_migrations', 7) + self.assertTrue(self.tableExists(EnumModel1)) + self.assertEquals(self.getTableFields(EnumModel2), + [('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")]) # Several different models with the same table name, to simulate a table that changes over time @@ -86,3 +104,26 @@ class Model3(Model): def table_name(cls): return 'mig' + +class EnumModel1(Model): + + date = DateField() + f1 = Enum8Field(Enum('SomeEnum1', 'dog cat cow')) + + engine = MergeTree('date', ('date',)) + + @classmethod + def table_name(cls): + return 'enum_mig' + + +class EnumModel2(Model): + + date = DateField() + f1 = Enum16Field(Enum('SomeEnum2', 'dog cat horse pig')) # changed type and values + + engine = MergeTree('date', ('date',)) + + @classmethod + def table_name(cls): + return 'enum_mig'