Add support for array fields

This commit is contained in:
Itai Shirav 2016-09-01 15:25:48 +03:00
parent 8fc3a31d4b
commit 13bd956fc6
6 changed files with 202 additions and 27 deletions

View File

@ -193,6 +193,7 @@ Float32Field Float32 float
Float64Field Float64 float Float64Field Float64 float
Enum8Field Enum8 Enum See below Enum8Field Enum8 Enum See below
Enum16Field Enum16 Enum See below Enum16Field Enum16 Enum See below
ArrayField Array list See below
============= ======== ================= =================================================== ============= ======== ================= ===================================================
Working with enum fields Working with enum fields
@ -219,6 +220,21 @@ Example of a model with an enum field::
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female) suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
Working with array fields
*************************
You can create array fields containing any data type, for example::
class SensorData(models.Model):
date = fields.DateField()
temperatures = fields.ArrayField(fields.Float32Field)
humidity_levels = fields.ArrayField(fields.UInt8Field)
engine = engines.MergeTree('date', ('date',))
data = SensorData(date=date.today(), temperatures=[25.5, 31.2, 28.7], humidity_levels=[41, 39, 66])
Table Engines Table Engines
------------- -------------

View File

@ -3,6 +3,8 @@ import datetime
import pytz import pytz
import time import time
from .utils import escape, parse_array
class Field(object): class Field(object):
@ -36,20 +38,20 @@ class Field(object):
if value < min_value or value > max_value: if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
def get_db_prep_value(self, value): def to_db_string(self, value, quote=True):
''' '''
Returns the field's value prepared for interacting with the database. Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes.
''' '''
return value return escape(value, quote)
def get_sql(self, with_default=True): def get_sql(self, with_default=True):
''' '''
Returns an SQL expression describing the field (e.g. for CREATE TABLE). Returns an SQL expression describing the field (e.g. for CREATE TABLE).
''' '''
from .utils import escape
if with_default: if with_default:
default = self.get_db_prep_value(self.default) default = self.to_db_string(self.default)
return '%s DEFAULT %s' % (self.db_type, escape(default)) return '%s DEFAULT %s' % (self.db_type, default)
else: else:
return self.db_type return self.db_type
@ -88,8 +90,8 @@ class DateField(Field):
def validate(self, value): def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value) self._range_check(value, DateField.min_value, DateField.max_value)
def get_db_prep_value(self, value): def to_db_string(self, value, quote=True):
return value.isoformat() return escape(value.isoformat(), quote)
class DateTimeField(Field): class DateTimeField(Field):
@ -108,8 +110,8 @@ class DateTimeField(Field):
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def get_db_prep_value(self, value): def to_db_string(self, value, quote=True):
return int(time.mktime(value.timetuple())) return escape(int(time.mktime(value.timetuple())), quote)
class BaseIntField(Field): class BaseIntField(Field):
@ -221,16 +223,15 @@ class BaseEnumField(Field):
pass pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value)) raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
def get_db_prep_value(self, value): def to_db_string(self, value, quote=True):
return value.name return escape(value.name, quote)
def get_sql(self, with_default=True): def get_sql(self, with_default=True):
from .utils import escape
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls] values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
sql = '%s(%s)' % (self.db_type, ' ,'.join(values)) sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
if with_default: if with_default:
default = self.get_db_prep_value(self.default) default = self.to_db_string(self.default)
sql = '%s DEFAULT %s' % (sql, escape(default)) sql = '%s DEFAULT %s' % (sql, default)
return sql return sql
@classmethod @classmethod
@ -262,3 +263,31 @@ class Enum16Field(BaseEnumField):
db_type = 'Enum16' db_type = 'Enum16'
class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None):
self.inner_field = inner_field
super(ArrayField, self).__init__(default)
def to_python(self, value):
if isinstance(value, text_type):
value = parse_array(value)
elif isinstance(value, binary_type):
value = parse_array(value.decode('UTF-8'))
elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
return [self.inner_field.to_python(v) for v in value]
def validate(self, value):
for v in value:
self.inner_field.validate(v)
def to_db_string(self, value, quote=True):
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
return '[' + ', '.join(array) + ']'
def get_sql(self, with_default=True):
from .utils import escape
return 'Array(%s)' % self.inner_field.get_sql(with_default=False)

View File

@ -4,6 +4,9 @@ from .fields import Field
from six import with_metaclass from six import with_metaclass
from logging import getLogger
logger = getLogger('clickhouse_orm')
class ModelBase(type): class ModelBase(type):
''' '''
@ -28,7 +31,6 @@ class ModelBase(type):
@classmethod @classmethod
def create_ad_hoc_model(cls, fields): def create_ad_hoc_model(cls, fields):
# fields is a list of tuples (name, db_type) # fields is a list of tuples (name, db_type)
import infi.clickhouse_orm.fields as orm_fields
# Check if model exists in cache # Check if model exists in cache
fields = list(fields) fields = list(fields)
cache_key = str(fields) cache_key = str(fields)
@ -37,18 +39,28 @@ class ModelBase(type):
# Create an ad hoc model class # Create an ad hoc model class
attrs = {} attrs = {}
for name, db_type in fields: for name, db_type in fields:
if db_type.startswith('Enum'): attrs[name] = cls.create_ad_hoc_field(db_type)
attrs[name] = orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
else:
field_class = db_type + 'Field'
if not hasattr(orm_fields, field_class):
raise NotImplementedError('No field class for %s' % db_type)
attrs[name] = getattr(orm_fields, field_class)()
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs) model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
# Add the model class to the cache # Add the model class to the cache
cls.ad_hoc_model_cache[cache_key] = model_class cls.ad_hoc_model_cache[cache_key] = model_class
return model_class return model_class
@classmethod
def create_ad_hoc_field(cls, db_type):
import infi.clickhouse_orm.fields as orm_fields
# Enums
if db_type.startswith('Enum'):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
# Arrays
if db_type.startswith('Array'):
inner_field = cls.create_ad_hoc_field(db_type[6 : -1])
return orm_fields.ArrayField(inner_field)
# Simple fields
name = db_type + 'Field'
if not hasattr(orm_fields, name):
raise NotImplementedError('No field class for %s' % db_type)
return getattr(orm_fields, name)()
class Model(with_metaclass(ModelBase)): class Model(with_metaclass(ModelBase)):
''' '''
@ -144,6 +156,8 @@ class Model(with_metaclass(ModelBase)):
''' '''
parts = [] parts = []
for name, field in self._fields: for name, field in self._fields:
value = field.get_db_prep_value(field.to_python(getattr(self, name))) value = field.to_db_string(getattr(self, name), quote=False)
parts.append(escape(value, quote=False)) parts.append(value)
return '\t'.join(parts) tsv = '\t'.join(parts)
logger.debug(tsv)
return tsv

View File

@ -1,5 +1,6 @@
from six import string_types, binary_type, text_type, PY3 from six import string_types, binary_type, text_type, PY3
import codecs import codecs
import re
SPECIAL_CHARS = { SPECIAL_CHARS = {
@ -15,6 +16,11 @@ SPECIAL_CHARS = {
def escape(value, quote=True): def escape(value, quote=True):
'''
If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one.
'''
if isinstance(value, string_types): if isinstance(value, string_types):
chars = (SPECIAL_CHARS.get(c, c) for c in value) chars = (SPECIAL_CHARS.get(c, c) for c in value)
value = "'" + "".join(chars) + "'" if quote else "".join(chars) value = "'" + "".join(chars) + "'" if quote else "".join(chars)
@ -33,6 +39,40 @@ def parse_tsv(line):
return [unescape(value) for value in line.split('\t')] return [unescape(value) for value in line.split('\t')]
def parse_array(array_string):
"""
Parse an array string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"]
"[1,2,3]" ==> [1, 2, 3]
"""
# Sanity check
if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']':
raise ValueError('Invalid array string: "%s"' % array_string)
# Drop opening brace
array_string = array_string[1:]
# Go over the string, lopping off each value at the beginning until nothing is left
values = []
while True:
if array_string == ']':
# End of array
return values
elif array_string[0] in ', ':
# In between values
array_string = array_string[1:]
elif array_string[0] == "'":
# Start of quoted value, find its end
match = re.search(r"[^\\]'", array_string)
if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end():]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]", array_string)
values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end():]
def import_submodules(package_name): def import_submodules(package_name):
""" """
Import all submodules of a module. Import all submodules of a module.

View File

@ -0,0 +1,58 @@
import unittest
from datetime import date
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class ArrayFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db')
self.database.create_table(ModelWithArrays)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
instance = ModelWithArrays(
date_field='2016-08-30',
arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
arr_date=['2010-01-01']
)
self.database.insert([instance])
query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
for model_cls in (ModelWithArrays, None):
results = list(self.database.select(query, model_cls))
self.assertEquals(len(results), 1)
self.assertEquals(results[0].arr_str, instance.arr_str)
self.assertEquals(results[0].arr_int, instance.arr_int)
self.assertEquals(results[0].arr_date, instance.arr_date)
def test_conversion(self):
instance = ModelWithArrays(
arr_int=('1', '2', '3'),
arr_date=['2010-01-01']
)
self.assertEquals(instance.arr_str, [])
self.assertEquals(instance.arr_int, [1, 2, 3])
self.assertEquals(instance.arr_date, [date(2010, 1, 1)])
def test_assignment_error(self):
instance = ModelWithArrays()
for value in (7, 'x', [date.today()], ['aaa'], [None]):
with self.assertRaises(ValueError):
instance.arr_int = value
class ModelWithArrays(Model):
date_field = DateField()
arr_str = ArrayField(StringField())
arr_int = ArrayField(Int32Field())
arr_date = ArrayField(DateField())
engine = MergeTree('date_field', ('date_field',))

View File

@ -16,6 +16,7 @@ class EnumFieldsTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test-db') self.database = Database('test-db')
self.database.create_table(ModelWithEnum) self.database.create_table(ModelWithEnum)
self.database.create_table(ModelWithEnumArray)
def tearDown(self): def tearDown(self):
self.database.drop_database() self.database.drop_database()
@ -39,7 +40,9 @@ class EnumFieldsTest(unittest.TestCase):
query = 'SELECT * from $db.modelwithenum ORDER BY date_field' query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
results = list(self.database.select(query)) results = list(self.database.select(query))
self.assertEquals(len(results), 2) self.assertEquals(len(results), 2)
self.assertEquals(results[0].enum_field.name, Fruit.apple.name)
self.assertEquals(results[0].enum_field.value, Fruit.apple.value) self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
self.assertEquals(results[1].enum_field.name, Fruit.orange.name)
self.assertEquals(results[1].enum_field.value, Fruit.orange.value) self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
def test_conversion(self): def test_conversion(self):
@ -56,6 +59,14 @@ class EnumFieldsTest(unittest.TestCase):
instance = ModelWithEnum() instance = ModelWithEnum()
self.assertEquals(instance.enum_field, Fruit.apple) self.assertEquals(instance.enum_field, Fruit.apple)
def test_enum_array(self):
instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
self.database.insert([instance])
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, ModelWithEnumArray))
self.assertEquals(len(results), 1)
self.assertEquals(results[0].enum_array, instance.enum_array)
Fruit = Enum('Fruit', u'apple banana orange') Fruit = Enum('Fruit', u'apple banana orange')
@ -67,3 +78,10 @@ class ModelWithEnum(Model):
engine = MergeTree('date_field', ('date_field',)) engine = MergeTree('date_field', ('date_field',))
class ModelWithEnumArray(Model):
date_field = DateField()
enum_array = ArrayField(Enum16Field(Fruit))
engine = MergeTree('date_field', ('date_field',))