Merge pull request #1 from Infinidat/master

Master
This commit is contained in:
emakarov 2017-02-07 18:37:10 +03:00 committed by GitHub
commit e37a4cebb1
13 changed files with 504 additions and 37 deletions

View File

@ -34,6 +34,20 @@ It is possible to provide a default value for a field, instead of its "natural"
See below for the supported field types and table engines.
Table Names
***********
The table name used for the model is its class name, converted to lowercase. To override the default name,
implement the ``table_name`` method::
class Person(models.Model):
...
@classmethod
def table_name(cls):
return 'people'
Using Models
------------
@ -151,7 +165,7 @@ The ``paginate`` method returns a ``namedtuple`` containing the following fields
- ``objects`` - the list of objects in this page
- ``number_of_objects`` - total number of objects in all pages
- ``pages_total`` - total number of pages
- ``number`` - the page number
- ``number`` - the page number, starting from 1; the special value -1 may be used to retrieve the last page
- ``page_size`` - the number of objects per page
You can optionally pass conditions to the query::
@ -191,8 +205,50 @@ UInt32Field UInt32 int Range 0 to 4294967295
UInt64Field UInt64 int/long Range 0 to 18446744073709551615
Float32Field Float32 float
Float64Field Float64 float
Enum8Field Enum8 Enum See below
Enum16Field Enum16 Enum See below
ArrayField Array list See below
============= ======== ================= ===================================================
Working with enum fields
************************
``Enum8Field`` and ``Enum16Field`` provide support for working with ClickHouse enum columns. They accept
strings or integers as values, and convert them to the matching Pythonic Enum member.
Python 3.4 and higher supports Enums natively. When using previous Python versions you
need to install the `enum34` library.
Example of a model with an enum field::
Gender = Enum('Gender', 'male female unspecified')
class Person(models.Model):
first_name = fields.StringField()
last_name = fields.StringField()
birthday = fields.DateField()
gender = fields.Enum32Field(Gender)
engine = engines.MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
Working with array fields
*************************
You can create array fields containing any data type, for example::
class SensorData(models.Model):
date = fields.DateField()
temperatures = fields.ArrayField(fields.Float32Field())
humidity_levels = fields.ArrayField(fields.UInt8Field())
engine = engines.MergeTree('date', ('date',))
data = SensorData(date=date.today(), temperatures=[25.5, 31.2, 28.7], humidity_levels=[41, 39, 66])
Table Engines
-------------

View File

@ -45,6 +45,8 @@ recipe = infi.recipe.console_scripts
eggs = ${project:name}
ipython
nose
coverage
enum34
infi.unittest
infi.traceback
zc.buildout

View File

@ -18,12 +18,20 @@ class DatabaseException(Exception):
class Database(object):
def __init__(self, db_name, db_url='http://localhost:8123/', username=None, password=None):
def __init__(self, db_name, db_url='http://localhost:8123/', username=None, password=None, readonly=False):
self.db_name = db_name
self.db_url = db_url
self.username = username
self.password = password
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % db_name)
self.readonly = readonly
if not self.readonly:
self.create_database()
def create_database(self):
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
def drop_database(self):
self._send('DROP DATABASE `%s`' % self.db_name)
def create_table(self, model_class):
# TODO check that model has an engine
@ -32,10 +40,7 @@ class Database(object):
def drop_table(self, model_class):
self._send(model_class.drop_table_sql(self.db_name))
def drop_database(self):
self._send('DROP DATABASE `%s`' % self.db_name)
def insert(self, model_instances):
def insert(self, model_instances, batch_size=1000):
from six import next
i = iter(model_instances)
try:
@ -45,11 +50,19 @@ class Database(object):
model_class = first_instance.__class__
def gen():
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
yield first_instance.to_tsv().encode('utf-8')
yield '\n'.encode('utf-8')
yield (first_instance.to_tsv() + '\n').encode('utf-8')
# Collect lines in batches of batch_size
batch = []
for instance in i:
yield instance.to_tsv().encode('utf-8')
yield '\n'.encode('utf-8')
batch.append(instance.to_tsv())
if len(batch) >= batch_size:
# Return the current batch of lines
yield ('\n'.join(batch) + '\n').encode('utf-8')
# Start a new batch
batch = []
# Return any remaining lines in partial batch
if batch:
yield ('\n'.join(batch) + '\n').encode('utf-8')
self._send(gen())
def count(self, model_class, conditions=None):
@ -74,6 +87,10 @@ class Database(object):
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = pages_total
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
offset = (page_num - 1) * page_size
query = 'SELECT * FROM $table'
if conditions:

View File

@ -3,6 +3,8 @@ import datetime
import pytz
import time
from .utils import escape, parse_array
class Field(object):
@ -13,7 +15,7 @@ class Field(object):
def __init__(self, default=None):
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
self.default = default or self.class_default
self.default = self.class_default if default is None else default
def to_python(self, value):
'''
@ -36,11 +38,22 @@ class Field(object):
if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
def get_db_prep_value(self, value):
def to_db_string(self, value, quote=True):
'''
Returns the field's value prepared for interacting with the database.
Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes.
'''
return value
return escape(value, quote)
def get_sql(self, with_default=True):
'''
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
'''
if with_default:
default = self.to_db_string(self.default)
return '%s DEFAULT %s' % (self.db_type, default)
else:
return self.db_type
class StringField(Field):
@ -66,6 +79,8 @@ class DateField(Field):
def to_python(self, value):
if isinstance(value, datetime.date):
return value
if isinstance(value, datetime.datetime):
return value.date()
if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, string_types):
@ -77,8 +92,8 @@ class DateField(Field):
def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value)
def get_db_prep_value(self, value):
return value.isoformat()
def to_db_string(self, value, quote=True):
return escape(value.isoformat(), quote)
class DateTimeField(Field):
@ -94,11 +109,13 @@ class DateTimeField(Field):
if isinstance(value, int):
return datetime.datetime.fromtimestamp(value, pytz.utc)
if isinstance(value, string_types):
if value == '0000-00-00 00:00:00':
return self.class_default
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def get_db_prep_value(self, value):
return int(time.mktime(value.timetuple()))
def to_db_string(self, value, quote=True):
return escape(int(time.mktime(value.timetuple())), quote)
class BaseIntField(Field):
@ -187,3 +204,94 @@ class Float64Field(BaseFloatField):
db_type = 'Float64'
class BaseEnumField(Field):
def __init__(self, enum_cls, default=None):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default)
def to_python(self, value):
if isinstance(value, self.enum_cls):
return value
try:
if isinstance(value, text_type):
return self.enum_cls[value]
if isinstance(value, binary_type):
return self.enum_cls[value.decode('UTF-8')]
if isinstance(value, int):
return self.enum_cls(value)
except (KeyError, ValueError):
pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
def to_db_string(self, value, quote=True):
return escape(value.name, quote)
def get_sql(self, with_default=True):
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
if with_default:
default = self.to_db_string(self.default)
sql = '%s DEFAULT %s' % (sql, default)
return sql
@classmethod
def create_ad_hoc_field(cls, db_type):
'''
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field.
'''
import re
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
members = {}
for match in re.finditer("'(\w+)' = (\d+)", db_type):
members[match.group(1)] = int(match.group(2))
enum_cls = Enum('AdHocEnum', members)
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
return field_class(enum_cls)
class Enum8Field(BaseEnumField):
db_type = 'Enum8'
class Enum16Field(BaseEnumField):
db_type = 'Enum16'
class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None):
self.inner_field = inner_field
super(ArrayField, self).__init__(default)
def to_python(self, value):
if isinstance(value, text_type):
value = parse_array(value)
elif isinstance(value, binary_type):
value = parse_array(value.decode('UTF-8'))
elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
return [self.inner_field.to_python(v) for v in value]
def validate(self, value):
for v in value:
self.inner_field.validate(v)
def to_db_string(self, value, quote=True):
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
return '[' + ', '.join(array) + ']'
def get_sql(self, with_default=True):
from .utils import escape
return 'Array(%s)' % self.inner_field.get_sql(with_default=False)

View File

@ -68,12 +68,11 @@ class AlterTable(Operation):
if name not in table_fields:
logger.info(' Add column %s', name)
assert prev_name, 'Cannot add a column to the beginning of the table'
default = field.get_db_prep_value(field.default)
cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name)
cmd = 'ADD COLUMN %s %s AFTER %s' % (name, field.get_sql(), prev_name)
self._alter_table(database, cmd)
prev_name = name
# Identify fields whose type was changed
model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
model_fields = [(name, field.get_sql(with_default=False)) for name, field in self.model_class._fields]
for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
if model_field[1] != table_field[1]:

View File

@ -4,6 +4,9 @@ from .fields import Field
from six import with_metaclass
from logging import getLogger
logger = getLogger('clickhouse_orm')
class ModelBase(type):
'''
@ -28,7 +31,6 @@ class ModelBase(type):
@classmethod
def create_ad_hoc_model(cls, fields):
# fields is a list of tuples (name, db_type)
import infi.clickhouse_orm.fields as orm_fields
# Check if model exists in cache
fields = list(fields)
cache_key = str(fields)
@ -37,15 +39,28 @@ class ModelBase(type):
# Create an ad hoc model class
attrs = {}
for name, db_type in fields:
field_class = db_type + 'Field'
if not hasattr(orm_fields, field_class):
raise NotImplementedError('No field class for %s' % db_type)
attrs[name] = getattr(orm_fields, field_class)()
attrs[name] = cls.create_ad_hoc_field(db_type)
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
# Add the model class to the cache
cls.ad_hoc_model_cache[cache_key] = model_class
return model_class
@classmethod
def create_ad_hoc_field(cls, db_type):
import infi.clickhouse_orm.fields as orm_fields
# Enums
if db_type.startswith('Enum'):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
# Arrays
if db_type.startswith('Array'):
inner_field = cls.create_ad_hoc_field(db_type[6 : -1])
return orm_fields.ArrayField(inner_field)
# Simple fields
name = db_type + 'Field'
if not hasattr(orm_fields, name):
raise NotImplementedError('No field class for %s' % db_type)
return getattr(orm_fields, name)()
class Model(with_metaclass(ModelBase)):
'''
@ -107,8 +122,7 @@ class Model(with_metaclass(ModelBase)):
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
cols = []
for name, field in cls._fields:
default = field.get_db_prep_value(field.default)
cols.append(' %s %s DEFAULT %s' % (name, field.db_type, escape(default)))
cols.append(' %s %s' % (name, field.get_sql()))
parts.append(',\n'.join(cols))
parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql())
@ -140,8 +154,5 @@ class Model(with_metaclass(ModelBase)):
'''
Returns the instance's column values as a tab-separated line. A newline is not included.
'''
parts = []
for name, field in self._fields:
value = field.get_db_prep_value(field.to_python(getattr(self, name)))
parts.append(escape(value, quote=False))
return '\t'.join(parts)
data = self.__dict__
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in self._fields)

View File

@ -1,5 +1,6 @@
from six import string_types, binary_type, text_type, PY3
import codecs
import re
SPECIAL_CHARS = {
@ -13,11 +14,20 @@ SPECIAL_CHARS = {
"'" : "\\'"
}
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
def escape(value, quote=True):
'''
If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one.
'''
if isinstance(value, string_types):
chars = (SPECIAL_CHARS.get(c, c) for c in value)
value = "'" + "".join(chars) + "'" if quote else "".join(chars)
if SPECIAL_CHARS_REGEX.search(value):
value = "".join(SPECIAL_CHARS.get(c, c) for c in value)
if quote:
value = "'" + value + "'"
return text_type(value)
@ -33,6 +43,40 @@ def parse_tsv(line):
return [unescape(value) for value in line.split('\t')]
def parse_array(array_string):
"""
Parse an array string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"]
"[1,2,3]" ==> [1, 2, 3]
"""
# Sanity check
if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']':
raise ValueError('Invalid array string: "%s"' % array_string)
# Drop opening brace
array_string = array_string[1:]
# Go over the string, lopping off each value at the beginning until nothing is left
values = []
while True:
if array_string == ']':
# End of array
return values
elif array_string[0] in ', ':
# In between values
array_string = array_string[1:]
elif array_string[0] == "'":
# Start of quoted value, find its end
match = re.search(r"[^\\]'", array_string)
if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end():]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]", array_string)
values.append(array_string[0 : match.start()])
array_string = array_string[match.end() - 1:]
def import_submodules(package_name):
"""
Import all submodules of a module.

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(EnumModel1)
]

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(EnumModel2)
]

View File

@ -0,0 +1,73 @@
import unittest
from datetime import date
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class ArrayFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db')
self.database.create_table(ModelWithArrays)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
instance = ModelWithArrays(
date_field='2016-08-30',
arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
arr_date=['2010-01-01']
)
self.database.insert([instance])
query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
for model_cls in (ModelWithArrays, None):
results = list(self.database.select(query, model_cls))
self.assertEquals(len(results), 1)
self.assertEquals(results[0].arr_str, instance.arr_str)
self.assertEquals(results[0].arr_int, instance.arr_int)
self.assertEquals(results[0].arr_date, instance.arr_date)
def test_conversion(self):
instance = ModelWithArrays(
arr_int=('1', '2', '3'),
arr_date=['2010-01-01']
)
self.assertEquals(instance.arr_str, [])
self.assertEquals(instance.arr_int, [1, 2, 3])
self.assertEquals(instance.arr_date, [date(2010, 1, 1)])
def test_assignment_error(self):
instance = ModelWithArrays()
for value in (7, 'x', [date.today()], ['aaa'], [None]):
with self.assertRaises(ValueError):
instance.arr_int = value
def test_parse_array(self):
from infi.clickhouse_orm.utils import parse_array, unescape
self.assertEquals(parse_array("[]"), [])
self.assertEquals(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
self.assertEquals(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
self.assertEquals(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
for s in ("",
"[",
"]",
"[1, 2",
"3, 4]",
"['aaa', 'aaa]"):
with self.assertRaises(ValueError):
parse_array(s)
class ModelWithArrays(Model):
date_field = DateField()
arr_str = ArrayField(StringField())
arr_int = ArrayField(Int32Field())
arr_date = ArrayField(DateField())
engine = MergeTree('date_field', ('date_field',))

View File

@ -93,6 +93,23 @@ class DatabaseTestCase(unittest.TestCase):
# Verify that all instances were returned
self.assertEquals(len(instances), len(data))
def test_pagination_last_page(self):
self._insert_and_check(self._sample_data(), len(data))
# Try different page sizes
for page_size in (1, 2, 7, 10, 30, 100, 150):
# Ask for the last page in two different ways and verify equality
page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
self.assertEquals(page_a[1:], page_b[1:])
self.assertEquals([obj.to_tsv() for obj in page_a.objects],
[obj.to_tsv() for obj in page_b.objects])
def test_pagination_invalid_page(self):
self._insert_and_check(self._sample_data(), len(data))
for page_num in (0, -2, -100):
with self.assertRaises(ValueError):
self.database.paginate(Person, 'first_name, last_name', page_num, 100)
def test_special_chars(self):
s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
p = Person(first_name=s)

87
tests/test_enum_fields.py Normal file
View File

@ -0,0 +1,87 @@
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
class EnumFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db')
self.database.create_table(ModelWithEnum)
self.database.create_table(ModelWithEnumArray)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
])
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, ModelWithEnum))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].enum_field, Fruit.apple)
self.assertEquals(results[1].enum_field, Fruit.orange)
def test_ad_hoc_model(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
])
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
results = list(self.database.select(query))
self.assertEquals(len(results), 2)
self.assertEquals(results[0].enum_field.name, Fruit.apple.name)
self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
self.assertEquals(results[1].enum_field.name, Fruit.orange.name)
self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
def test_conversion(self):
self.assertEquals(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
self.assertEquals(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
self.assertEquals(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
def test_assignment_error(self):
for value in (0, 17, 'pear', '', None, 99.9):
with self.assertRaises(ValueError):
ModelWithEnum(enum_field=value)
def test_default_value(self):
instance = ModelWithEnum()
self.assertEquals(instance.enum_field, Fruit.apple)
def test_enum_array(self):
instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
self.database.insert([instance])
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, ModelWithEnumArray))
self.assertEquals(len(results), 1)
self.assertEquals(results[0].enum_array, instance.enum_array)
Fruit = Enum('Fruit', u'apple banana orange')
class ModelWithEnum(Model):
date_field = DateField()
enum_field = Enum8Field(Fruit)
engine = MergeTree('date_field', ('date_field',))
class ModelWithEnumArray(Model):
date_field = DateField()
enum_array = ArrayField(Enum16Field(Fruit))
engine = MergeTree('date_field', ('date_field',))

View File

@ -10,6 +10,11 @@ from infi.clickhouse_orm.migrations import MigrationHistory
import sys, os
sys.path.append(os.path.dirname(__file__))
try:
Enum # exists in Python 3.4+
except NameError:
from enum import Enum # use the enum34 library instead
import logging
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
logging.getLogger("requests").setLevel(logging.WARNING)
@ -21,6 +26,9 @@ class MigrationsTestCase(unittest.TestCase):
self.database = Database('test-db')
self.database.drop_table(MigrationHistory)
def tearDown(self):
self.database.drop_database()
def tableExists(self, model_class):
query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
return next(self.database.select(query)).result == 1
@ -30,18 +38,28 @@ class MigrationsTestCase(unittest.TestCase):
return [(row.name, row.type) for row in self.database.select(query)]
def test_migrations(self):
# Creation and deletion of table
self.database.migrate('tests.sample_migrations', 1)
self.assertTrue(self.tableExists(Model1))
self.database.migrate('tests.sample_migrations', 2)
self.assertFalse(self.tableExists(Model1))
self.database.migrate('tests.sample_migrations', 3)
self.assertTrue(self.tableExists(Model1))
# Adding, removing and altering simple fields
self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
self.database.migrate('tests.sample_migrations', 4)
self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
self.database.migrate('tests.sample_migrations', 5)
self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
# Altering enum fields
self.database.migrate('tests.sample_migrations', 6)
self.assertTrue(self.tableExists(EnumModel1))
self.assertEquals(self.getTableFields(EnumModel1),
[('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
self.database.migrate('tests.sample_migrations', 7)
self.assertTrue(self.tableExists(EnumModel1))
self.assertEquals(self.getTableFields(EnumModel2),
[('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
# Several different models with the same table name, to simulate a table that changes over time
@ -86,3 +104,26 @@ class Model3(Model):
def table_name(cls):
return 'mig'
class EnumModel1(Model):
date = DateField()
f1 = Enum8Field(Enum('SomeEnum1', 'dog cat cow'))
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'enum_mig'
class EnumModel2(Model):
date = DateField()
f1 = Enum16Field(Enum('SomeEnum2', 'dog cat horse pig')) # changed type and values
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'enum_mig'