Add support for enum fields

This commit is contained in:
Itai Shirav 2016-08-31 15:26:28 +03:00
parent c0af06875c
commit 8fc3a31d4b
9 changed files with 237 additions and 11 deletions

View File

@ -191,8 +191,34 @@ UInt32Field UInt32 int Range 0 to 4294967295
UInt64Field UInt64 int/long Range 0 to 18446744073709551615 UInt64Field UInt64 int/long Range 0 to 18446744073709551615
Float32Field Float32 float Float32Field Float32 float
Float64Field Float64 float Float64Field Float64 float
Enum8Field Enum8 Enum See below
Enum16Field Enum16 Enum See below
============= ======== ================= =================================================== ============= ======== ================= ===================================================
Working with enum fields
************************
``Enum8Field`` and ``Enum16Field`` provide support for working with ClickHouse enum columns. They accept
strings or integers as values, and convert them to the matching Pythonic Enum member.
Python 3.4 and higher supports Enums natively. When using previous Python versions you
need to install the `enum34` library.
Example of a model with an enum field::
Gender = Enum('Gender', 'male female unspecified')
class Person(models.Model):
first_name = fields.StringField()
last_name = fields.StringField()
birthday = fields.DateField()
gender = fields.Enum32Field(Gender)
engine = engines.MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
Table Engines Table Engines
------------- -------------

View File

@ -45,6 +45,8 @@ recipe = infi.recipe.console_scripts
eggs = ${project:name} eggs = ${project:name}
ipython ipython
nose nose
coverage
enum34
infi.unittest infi.unittest
infi.traceback infi.traceback
zc.buildout zc.buildout

View File

@ -13,7 +13,7 @@ class Field(object):
def __init__(self, default=None): def __init__(self, default=None):
self.creation_counter = Field.creation_counter self.creation_counter = Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
self.default = default or self.class_default self.default = self.class_default if default is None else default
def to_python(self, value): def to_python(self, value):
''' '''
@ -42,6 +42,17 @@ class Field(object):
''' '''
return value return value
def get_sql(self, with_default=True):
'''
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
'''
from .utils import escape
if with_default:
default = self.get_db_prep_value(self.default)
return '%s DEFAULT %s' % (self.db_type, escape(default))
else:
return self.db_type
class StringField(Field): class StringField(Field):
@ -187,3 +198,67 @@ class Float64Field(BaseFloatField):
db_type = 'Float64' db_type = 'Float64'
class BaseEnumField(Field):
def __init__(self, enum_cls, default=None):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default)
def to_python(self, value):
if isinstance(value, self.enum_cls):
return value
try:
if isinstance(value, text_type):
return self.enum_cls[value]
if isinstance(value, binary_type):
return self.enum_cls[value.decode('UTF-8')]
if isinstance(value, int):
return self.enum_cls(value)
except (KeyError, ValueError):
pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
def get_db_prep_value(self, value):
return value.name
def get_sql(self, with_default=True):
from .utils import escape
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
if with_default:
default = self.get_db_prep_value(self.default)
sql = '%s DEFAULT %s' % (sql, escape(default))
return sql
@classmethod
def create_ad_hoc_field(cls, db_type):
'''
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field.
'''
import re
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
members = {}
for match in re.finditer("'(\w+)' = (\d+)", db_type):
members[match.group(1)] = int(match.group(2))
enum_cls = Enum('AdHocEnum', members)
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
return field_class(enum_cls)
class Enum8Field(BaseEnumField):
db_type = 'Enum8'
class Enum16Field(BaseEnumField):
db_type = 'Enum16'

View File

@ -68,12 +68,11 @@ class AlterTable(Operation):
if name not in table_fields: if name not in table_fields:
logger.info(' Add column %s', name) logger.info(' Add column %s', name)
assert prev_name, 'Cannot add a column to the beginning of the table' assert prev_name, 'Cannot add a column to the beginning of the table'
default = field.get_db_prep_value(field.default) cmd = 'ADD COLUMN %s %s AFTER %s' % (name, field.get_sql(), prev_name)
cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name)
self._alter_table(database, cmd) self._alter_table(database, cmd)
prev_name = name prev_name = name
# Identify fields whose type was changed # Identify fields whose type was changed
model_fields = [(name, field.db_type) for name, field in self.model_class._fields] model_fields = [(name, field.get_sql(with_default=False)) for name, field in self.model_class._fields]
for model_field, table_field in zip(model_fields, self._get_table_fields(database)): for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement' assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
if model_field[1] != table_field[1]: if model_field[1] != table_field[1]:

View File

@ -37,6 +37,9 @@ class ModelBase(type):
# Create an ad hoc model class # Create an ad hoc model class
attrs = {} attrs = {}
for name, db_type in fields: for name, db_type in fields:
if db_type.startswith('Enum'):
attrs[name] = orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
else:
field_class = db_type + 'Field' field_class = db_type + 'Field'
if not hasattr(orm_fields, field_class): if not hasattr(orm_fields, field_class):
raise NotImplementedError('No field class for %s' % db_type) raise NotImplementedError('No field class for %s' % db_type)
@ -107,8 +110,7 @@ class Model(with_metaclass(ModelBase)):
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
cols = [] cols = []
for name, field in cls._fields: for name, field in cls._fields:
default = field.get_db_prep_value(field.default) cols.append(' %s %s' % (name, field.get_sql()))
cols.append(' %s %s DEFAULT %s' % (name, field.db_type, escape(default)))
parts.append(',\n'.join(cols)) parts.append(',\n'.join(cols))
parts.append(')') parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql()) parts.append('ENGINE = ' + cls.engine.create_table_sql())

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(EnumModel1)
]

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(EnumModel2)
]

69
tests/test_enum_fields.py Normal file
View File

@ -0,0 +1,69 @@
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
class EnumFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db')
self.database.create_table(ModelWithEnum)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
])
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, ModelWithEnum))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].enum_field, Fruit.apple)
self.assertEquals(results[1].enum_field, Fruit.orange)
def test_ad_hoc_model(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
])
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
results = list(self.database.select(query))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
def test_conversion(self):
self.assertEquals(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
self.assertEquals(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
self.assertEquals(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
def test_assignment_error(self):
for value in (0, 17, 'pear', '', None, 99.9):
with self.assertRaises(ValueError):
ModelWithEnum(enum_field=value)
def test_default_value(self):
instance = ModelWithEnum()
self.assertEquals(instance.enum_field, Fruit.apple)
Fruit = Enum('Fruit', u'apple banana orange')
class ModelWithEnum(Model):
date_field = DateField()
enum_field = Enum8Field(Fruit)
engine = MergeTree('date_field', ('date_field',))

View File

@ -10,6 +10,11 @@ from infi.clickhouse_orm.migrations import MigrationHistory
import sys, os import sys, os
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
import logging import logging
logging.basicConfig(level=logging.DEBUG, format='%(message)s') logging.basicConfig(level=logging.DEBUG, format='%(message)s')
logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("requests").setLevel(logging.WARNING)
@ -21,6 +26,9 @@ class MigrationsTestCase(unittest.TestCase):
self.database = Database('test-db') self.database = Database('test-db')
self.database.drop_table(MigrationHistory) self.database.drop_table(MigrationHistory)
def tearDown(self):
self.database.drop_database()
def tableExists(self, model_class): def tableExists(self, model_class):
query = "EXISTS TABLE $db.`%s`" % model_class.table_name() query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
return next(self.database.select(query)).result == 1 return next(self.database.select(query)).result == 1
@ -30,18 +38,28 @@ class MigrationsTestCase(unittest.TestCase):
return [(row.name, row.type) for row in self.database.select(query)] return [(row.name, row.type) for row in self.database.select(query)]
def test_migrations(self): def test_migrations(self):
# Creation and deletion of table
self.database.migrate('tests.sample_migrations', 1) self.database.migrate('tests.sample_migrations', 1)
self.assertTrue(self.tableExists(Model1)) self.assertTrue(self.tableExists(Model1))
self.database.migrate('tests.sample_migrations', 2) self.database.migrate('tests.sample_migrations', 2)
self.assertFalse(self.tableExists(Model1)) self.assertFalse(self.tableExists(Model1))
self.database.migrate('tests.sample_migrations', 3) self.database.migrate('tests.sample_migrations', 3)
self.assertTrue(self.tableExists(Model1)) self.assertTrue(self.tableExists(Model1))
# Adding, removing and altering simple fields
self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')]) self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
self.database.migrate('tests.sample_migrations', 4) self.database.migrate('tests.sample_migrations', 4)
self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')]) self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
self.database.migrate('tests.sample_migrations', 5) self.database.migrate('tests.sample_migrations', 5)
self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')]) self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
# Altering enum fields
self.database.migrate('tests.sample_migrations', 6)
self.assertTrue(self.tableExists(EnumModel1))
self.assertEquals(self.getTableFields(EnumModel1),
[('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
self.database.migrate('tests.sample_migrations', 7)
self.assertTrue(self.tableExists(EnumModel1))
self.assertEquals(self.getTableFields(EnumModel2),
[('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
# Several different models with the same table name, to simulate a table that changes over time # Several different models with the same table name, to simulate a table that changes over time
@ -86,3 +104,26 @@ class Model3(Model):
def table_name(cls): def table_name(cls):
return 'mig' return 'mig'
class EnumModel1(Model):
date = DateField()
f1 = Enum8Field(Enum('SomeEnum1', 'dog cat cow'))
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'enum_mig'
class EnumModel2(Model):
date = DateField()
f1 = Enum16Field(Enum('SomeEnum2', 'dog cat horse pig')) # changed type and values
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'enum_mig'