mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-25 10:13:45 +03:00
Add support for enum fields
This commit is contained in:
parent
c0af06875c
commit
8fc3a31d4b
26
README.rst
26
README.rst
|
@ -191,8 +191,34 @@ UInt32Field UInt32 int Range 0 to 4294967295
|
||||||
UInt64Field UInt64 int/long Range 0 to 18446744073709551615
|
UInt64Field UInt64 int/long Range 0 to 18446744073709551615
|
||||||
Float32Field Float32 float
|
Float32Field Float32 float
|
||||||
Float64Field Float64 float
|
Float64Field Float64 float
|
||||||
|
Enum8Field Enum8 Enum See below
|
||||||
|
Enum16Field Enum16 Enum See below
|
||||||
============= ======== ================= ===================================================
|
============= ======== ================= ===================================================
|
||||||
|
|
||||||
|
Working with enum fields
|
||||||
|
************************
|
||||||
|
|
||||||
|
``Enum8Field`` and ``Enum16Field`` provide support for working with ClickHouse enum columns. They accept
|
||||||
|
strings or integers as values, and convert them to the matching Pythonic Enum member.
|
||||||
|
|
||||||
|
Python 3.4 and higher supports Enums natively. When using previous Python versions you
|
||||||
|
need to install the `enum34` library.
|
||||||
|
|
||||||
|
Example of a model with an enum field::
|
||||||
|
|
||||||
|
Gender = Enum('Gender', 'male female unspecified')
|
||||||
|
|
||||||
|
class Person(models.Model):
|
||||||
|
|
||||||
|
first_name = fields.StringField()
|
||||||
|
last_name = fields.StringField()
|
||||||
|
birthday = fields.DateField()
|
||||||
|
gender = fields.Enum32Field(Gender)
|
||||||
|
|
||||||
|
engine = engines.MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
|
||||||
|
|
||||||
|
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
|
||||||
|
|
||||||
Table Engines
|
Table Engines
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,8 @@ recipe = infi.recipe.console_scripts
|
||||||
eggs = ${project:name}
|
eggs = ${project:name}
|
||||||
ipython
|
ipython
|
||||||
nose
|
nose
|
||||||
|
coverage
|
||||||
|
enum34
|
||||||
infi.unittest
|
infi.unittest
|
||||||
infi.traceback
|
infi.traceback
|
||||||
zc.buildout
|
zc.buildout
|
||||||
|
|
|
@ -13,7 +13,7 @@ class Field(object):
|
||||||
def __init__(self, default=None):
|
def __init__(self, default=None):
|
||||||
self.creation_counter = Field.creation_counter
|
self.creation_counter = Field.creation_counter
|
||||||
Field.creation_counter += 1
|
Field.creation_counter += 1
|
||||||
self.default = default or self.class_default
|
self.default = self.class_default if default is None else default
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
'''
|
'''
|
||||||
|
@ -42,6 +42,17 @@ class Field(object):
|
||||||
'''
|
'''
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def get_sql(self, with_default=True):
|
||||||
|
'''
|
||||||
|
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
|
||||||
|
'''
|
||||||
|
from .utils import escape
|
||||||
|
if with_default:
|
||||||
|
default = self.get_db_prep_value(self.default)
|
||||||
|
return '%s DEFAULT %s' % (self.db_type, escape(default))
|
||||||
|
else:
|
||||||
|
return self.db_type
|
||||||
|
|
||||||
|
|
||||||
class StringField(Field):
|
class StringField(Field):
|
||||||
|
|
||||||
|
@ -187,3 +198,67 @@ class Float64Field(BaseFloatField):
|
||||||
|
|
||||||
db_type = 'Float64'
|
db_type = 'Float64'
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEnumField(Field):
|
||||||
|
|
||||||
|
def __init__(self, enum_cls, default=None):
|
||||||
|
self.enum_cls = enum_cls
|
||||||
|
if default is None:
|
||||||
|
default = list(enum_cls)[0]
|
||||||
|
super(BaseEnumField, self).__init__(default)
|
||||||
|
|
||||||
|
def to_python(self, value):
|
||||||
|
if isinstance(value, self.enum_cls):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
if isinstance(value, text_type):
|
||||||
|
return self.enum_cls[value]
|
||||||
|
if isinstance(value, binary_type):
|
||||||
|
return self.enum_cls[value.decode('UTF-8')]
|
||||||
|
if isinstance(value, int):
|
||||||
|
return self.enum_cls(value)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
pass
|
||||||
|
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
|
||||||
|
|
||||||
|
def get_db_prep_value(self, value):
|
||||||
|
return value.name
|
||||||
|
|
||||||
|
def get_sql(self, with_default=True):
|
||||||
|
from .utils import escape
|
||||||
|
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
|
||||||
|
sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
|
||||||
|
if with_default:
|
||||||
|
default = self.get_db_prep_value(self.default)
|
||||||
|
sql = '%s DEFAULT %s' % (sql, escape(default))
|
||||||
|
return sql
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_ad_hoc_field(cls, db_type):
|
||||||
|
'''
|
||||||
|
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
|
||||||
|
this method returns a matching enum field.
|
||||||
|
'''
|
||||||
|
import re
|
||||||
|
try:
|
||||||
|
Enum # exists in Python 3.4+
|
||||||
|
except NameError:
|
||||||
|
from enum import Enum # use the enum34 library instead
|
||||||
|
members = {}
|
||||||
|
for match in re.finditer("'(\w+)' = (\d+)", db_type):
|
||||||
|
members[match.group(1)] = int(match.group(2))
|
||||||
|
enum_cls = Enum('AdHocEnum', members)
|
||||||
|
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
|
||||||
|
return field_class(enum_cls)
|
||||||
|
|
||||||
|
|
||||||
|
class Enum8Field(BaseEnumField):
|
||||||
|
|
||||||
|
db_type = 'Enum8'
|
||||||
|
|
||||||
|
|
||||||
|
class Enum16Field(BaseEnumField):
|
||||||
|
|
||||||
|
db_type = 'Enum16'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -68,12 +68,11 @@ class AlterTable(Operation):
|
||||||
if name not in table_fields:
|
if name not in table_fields:
|
||||||
logger.info(' Add column %s', name)
|
logger.info(' Add column %s', name)
|
||||||
assert prev_name, 'Cannot add a column to the beginning of the table'
|
assert prev_name, 'Cannot add a column to the beginning of the table'
|
||||||
default = field.get_db_prep_value(field.default)
|
cmd = 'ADD COLUMN %s %s AFTER %s' % (name, field.get_sql(), prev_name)
|
||||||
cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name)
|
|
||||||
self._alter_table(database, cmd)
|
self._alter_table(database, cmd)
|
||||||
prev_name = name
|
prev_name = name
|
||||||
# Identify fields whose type was changed
|
# Identify fields whose type was changed
|
||||||
model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
|
model_fields = [(name, field.get_sql(with_default=False)) for name, field in self.model_class._fields]
|
||||||
for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
|
for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
|
||||||
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
|
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
|
||||||
if model_field[1] != table_field[1]:
|
if model_field[1] != table_field[1]:
|
||||||
|
|
|
@ -37,10 +37,13 @@ class ModelBase(type):
|
||||||
# Create an ad hoc model class
|
# Create an ad hoc model class
|
||||||
attrs = {}
|
attrs = {}
|
||||||
for name, db_type in fields:
|
for name, db_type in fields:
|
||||||
field_class = db_type + 'Field'
|
if db_type.startswith('Enum'):
|
||||||
if not hasattr(orm_fields, field_class):
|
attrs[name] = orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
|
||||||
raise NotImplementedError('No field class for %s' % db_type)
|
else:
|
||||||
attrs[name] = getattr(orm_fields, field_class)()
|
field_class = db_type + 'Field'
|
||||||
|
if not hasattr(orm_fields, field_class):
|
||||||
|
raise NotImplementedError('No field class for %s' % db_type)
|
||||||
|
attrs[name] = getattr(orm_fields, field_class)()
|
||||||
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
|
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
|
||||||
# Add the model class to the cache
|
# Add the model class to the cache
|
||||||
cls.ad_hoc_model_cache[cache_key] = model_class
|
cls.ad_hoc_model_cache[cache_key] = model_class
|
||||||
|
@ -107,8 +110,7 @@ class Model(with_metaclass(ModelBase)):
|
||||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
|
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
|
||||||
cols = []
|
cols = []
|
||||||
for name, field in cls._fields:
|
for name, field in cls._fields:
|
||||||
default = field.get_db_prep_value(field.default)
|
cols.append(' %s %s' % (name, field.get_sql()))
|
||||||
cols.append(' %s %s DEFAULT %s' % (name, field.db_type, escape(default)))
|
|
||||||
parts.append(',\n'.join(cols))
|
parts.append(',\n'.join(cols))
|
||||||
parts.append(')')
|
parts.append(')')
|
||||||
parts.append('ENGINE = ' + cls.engine.create_table_sql())
|
parts.append('ENGINE = ' + cls.engine.create_table_sql())
|
||||||
|
|
6
tests/sample_migrations/0006.py
Normal file
6
tests/sample_migrations/0006.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from infi.clickhouse_orm import migrations
|
||||||
|
from ..test_migrations import *
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateTable(EnumModel1)
|
||||||
|
]
|
6
tests/sample_migrations/0007.py
Normal file
6
tests/sample_migrations/0007.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from infi.clickhouse_orm import migrations
|
||||||
|
from ..test_migrations import *
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterTable(EnumModel2)
|
||||||
|
]
|
69
tests/test_enum_fields.py
Normal file
69
tests/test_enum_fields.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from infi.clickhouse_orm.database import Database
|
||||||
|
from infi.clickhouse_orm.models import Model
|
||||||
|
from infi.clickhouse_orm.fields import *
|
||||||
|
from infi.clickhouse_orm.engines import *
|
||||||
|
|
||||||
|
try:
|
||||||
|
Enum # exists in Python 3.4+
|
||||||
|
except NameError:
|
||||||
|
from enum import Enum # use the enum34 library instead
|
||||||
|
|
||||||
|
|
||||||
|
class EnumFieldsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.database = Database('test-db')
|
||||||
|
self.database.create_table(ModelWithEnum)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.database.drop_database()
|
||||||
|
|
||||||
|
def test_insert_and_select(self):
|
||||||
|
self.database.insert([
|
||||||
|
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
|
||||||
|
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
|
||||||
|
])
|
||||||
|
query = 'SELECT * from $table ORDER BY date_field'
|
||||||
|
results = list(self.database.select(query, ModelWithEnum))
|
||||||
|
self.assertEquals(len(results), 2)
|
||||||
|
self.assertEquals(results[0].enum_field, Fruit.apple)
|
||||||
|
self.assertEquals(results[1].enum_field, Fruit.orange)
|
||||||
|
|
||||||
|
def test_ad_hoc_model(self):
|
||||||
|
self.database.insert([
|
||||||
|
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
|
||||||
|
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
|
||||||
|
])
|
||||||
|
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
|
||||||
|
results = list(self.database.select(query))
|
||||||
|
self.assertEquals(len(results), 2)
|
||||||
|
self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
|
||||||
|
self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
|
||||||
|
|
||||||
|
def test_conversion(self):
|
||||||
|
self.assertEquals(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
|
||||||
|
self.assertEquals(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
|
||||||
|
self.assertEquals(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
|
||||||
|
|
||||||
|
def test_assignment_error(self):
|
||||||
|
for value in (0, 17, 'pear', '', None, 99.9):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ModelWithEnum(enum_field=value)
|
||||||
|
|
||||||
|
def test_default_value(self):
|
||||||
|
instance = ModelWithEnum()
|
||||||
|
self.assertEquals(instance.enum_field, Fruit.apple)
|
||||||
|
|
||||||
|
|
||||||
|
Fruit = Enum('Fruit', u'apple banana orange')
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithEnum(Model):
|
||||||
|
|
||||||
|
date_field = DateField()
|
||||||
|
enum_field = Enum8Field(Fruit)
|
||||||
|
|
||||||
|
engine = MergeTree('date_field', ('date_field',))
|
||||||
|
|
|
@ -10,6 +10,11 @@ from infi.clickhouse_orm.migrations import MigrationHistory
|
||||||
import sys, os
|
import sys, os
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
try:
|
||||||
|
Enum # exists in Python 3.4+
|
||||||
|
except NameError:
|
||||||
|
from enum import Enum # use the enum34 library instead
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
|
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
|
||||||
logging.getLogger("requests").setLevel(logging.WARNING)
|
logging.getLogger("requests").setLevel(logging.WARNING)
|
||||||
|
@ -21,6 +26,9 @@ class MigrationsTestCase(unittest.TestCase):
|
||||||
self.database = Database('test-db')
|
self.database = Database('test-db')
|
||||||
self.database.drop_table(MigrationHistory)
|
self.database.drop_table(MigrationHistory)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.database.drop_database()
|
||||||
|
|
||||||
def tableExists(self, model_class):
|
def tableExists(self, model_class):
|
||||||
query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
|
query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
|
||||||
return next(self.database.select(query)).result == 1
|
return next(self.database.select(query)).result == 1
|
||||||
|
@ -30,18 +38,28 @@ class MigrationsTestCase(unittest.TestCase):
|
||||||
return [(row.name, row.type) for row in self.database.select(query)]
|
return [(row.name, row.type) for row in self.database.select(query)]
|
||||||
|
|
||||||
def test_migrations(self):
|
def test_migrations(self):
|
||||||
|
# Creation and deletion of table
|
||||||
self.database.migrate('tests.sample_migrations', 1)
|
self.database.migrate('tests.sample_migrations', 1)
|
||||||
self.assertTrue(self.tableExists(Model1))
|
self.assertTrue(self.tableExists(Model1))
|
||||||
self.database.migrate('tests.sample_migrations', 2)
|
self.database.migrate('tests.sample_migrations', 2)
|
||||||
self.assertFalse(self.tableExists(Model1))
|
self.assertFalse(self.tableExists(Model1))
|
||||||
self.database.migrate('tests.sample_migrations', 3)
|
self.database.migrate('tests.sample_migrations', 3)
|
||||||
self.assertTrue(self.tableExists(Model1))
|
self.assertTrue(self.tableExists(Model1))
|
||||||
|
# Adding, removing and altering simple fields
|
||||||
self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
|
self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
|
||||||
self.database.migrate('tests.sample_migrations', 4)
|
self.database.migrate('tests.sample_migrations', 4)
|
||||||
self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
|
self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
|
||||||
self.database.migrate('tests.sample_migrations', 5)
|
self.database.migrate('tests.sample_migrations', 5)
|
||||||
self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
|
self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
|
||||||
|
# Altering enum fields
|
||||||
|
self.database.migrate('tests.sample_migrations', 6)
|
||||||
|
self.assertTrue(self.tableExists(EnumModel1))
|
||||||
|
self.assertEquals(self.getTableFields(EnumModel1),
|
||||||
|
[('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
|
||||||
|
self.database.migrate('tests.sample_migrations', 7)
|
||||||
|
self.assertTrue(self.tableExists(EnumModel1))
|
||||||
|
self.assertEquals(self.getTableFields(EnumModel2),
|
||||||
|
[('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
|
||||||
|
|
||||||
# Several different models with the same table name, to simulate a table that changes over time
|
# Several different models with the same table name, to simulate a table that changes over time
|
||||||
|
|
||||||
|
@ -86,3 +104,26 @@ class Model3(Model):
|
||||||
def table_name(cls):
|
def table_name(cls):
|
||||||
return 'mig'
|
return 'mig'
|
||||||
|
|
||||||
|
|
||||||
|
class EnumModel1(Model):
|
||||||
|
|
||||||
|
date = DateField()
|
||||||
|
f1 = Enum8Field(Enum('SomeEnum1', 'dog cat cow'))
|
||||||
|
|
||||||
|
engine = MergeTree('date', ('date',))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def table_name(cls):
|
||||||
|
return 'enum_mig'
|
||||||
|
|
||||||
|
|
||||||
|
class EnumModel2(Model):
|
||||||
|
|
||||||
|
date = DateField()
|
||||||
|
f1 = Enum16Field(Enum('SomeEnum2', 'dog cat horse pig')) # changed type and values
|
||||||
|
|
||||||
|
engine = MergeTree('date', ('date',))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def table_name(cls):
|
||||||
|
return 'enum_mig'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user