mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-22 17:16:34 +03:00
Add support for array fields
This commit is contained in:
parent
8fc3a31d4b
commit
13bd956fc6
16
README.rst
16
README.rst
|
@ -193,6 +193,7 @@ Float32Field Float32 float
|
||||||
Float64Field Float64 float
|
Float64Field Float64 float
|
||||||
Enum8Field Enum8 Enum See below
|
Enum8Field Enum8 Enum See below
|
||||||
Enum16Field Enum16 Enum See below
|
Enum16Field Enum16 Enum See below
|
||||||
|
ArrayField Array list See below
|
||||||
============= ======== ================= ===================================================
|
============= ======== ================= ===================================================
|
||||||
|
|
||||||
Working with enum fields
|
Working with enum fields
|
||||||
|
@ -219,6 +220,21 @@ Example of a model with an enum field::
|
||||||
|
|
||||||
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
|
suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
|
||||||
|
|
||||||
|
Working with array fields
|
||||||
|
*************************
|
||||||
|
|
||||||
|
You can create array fields containing any data type, for example::
|
||||||
|
|
||||||
|
class SensorData(models.Model):
|
||||||
|
|
||||||
|
date = fields.DateField()
|
||||||
|
temperatures = fields.ArrayField(fields.Float32Field())
|
||||||
|
humidity_levels = fields.ArrayField(fields.UInt8Field())
|
||||||
|
|
||||||
|
engine = engines.MergeTree('date', ('date',))
|
||||||
|
|
||||||
|
data = SensorData(date=date.today(), temperatures=[25.5, 31.2, 28.7], humidity_levels=[41, 39, 66])
|
||||||
|
|
||||||
Table Engines
|
Table Engines
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,8 @@ import datetime
|
||||||
import pytz
|
import pytz
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from .utils import escape, parse_array
|
||||||
|
|
||||||
|
|
||||||
class Field(object):
|
class Field(object):
|
||||||
|
|
||||||
|
@ -36,20 +38,20 @@ class Field(object):
|
||||||
if value < min_value or value > max_value:
|
if value < min_value or value > max_value:
|
||||||
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
|
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def to_db_string(self, value, quote=True):
|
||||||
'''
|
'''
|
||||||
Returns the field's value prepared for interacting with the database.
|
Returns the field's value prepared for writing to the database.
|
||||||
|
When quote is true, strings are surrounded by single quotes.
|
||||||
'''
|
'''
|
||||||
return value
|
return escape(value, quote)
|
||||||
|
|
||||||
def get_sql(self, with_default=True):
|
def get_sql(self, with_default=True):
|
||||||
'''
|
'''
|
||||||
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
|
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
|
||||||
'''
|
'''
|
||||||
from .utils import escape
|
|
||||||
if with_default:
|
if with_default:
|
||||||
default = self.get_db_prep_value(self.default)
|
default = self.to_db_string(self.default)
|
||||||
return '%s DEFAULT %s' % (self.db_type, escape(default))
|
return '%s DEFAULT %s' % (self.db_type, default)
|
||||||
else:
|
else:
|
||||||
return self.db_type
|
return self.db_type
|
||||||
|
|
||||||
|
@ -88,8 +90,8 @@ class DateField(Field):
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
self._range_check(value, DateField.min_value, DateField.max_value)
|
self._range_check(value, DateField.min_value, DateField.max_value)
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def to_db_string(self, value, quote=True):
|
||||||
return value.isoformat()
|
return escape(value.isoformat(), quote)
|
||||||
|
|
||||||
|
|
||||||
class DateTimeField(Field):
|
class DateTimeField(Field):
|
||||||
|
@ -108,8 +110,8 @@ class DateTimeField(Field):
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def to_db_string(self, value, quote=True):
|
||||||
return int(time.mktime(value.timetuple()))
|
return escape(int(time.mktime(value.timetuple())), quote)
|
||||||
|
|
||||||
|
|
||||||
class BaseIntField(Field):
|
class BaseIntField(Field):
|
||||||
|
@ -221,16 +223,15 @@ class BaseEnumField(Field):
|
||||||
pass
|
pass
|
||||||
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
|
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def to_db_string(self, value, quote=True):
|
||||||
return value.name
|
return escape(value.name, quote)
|
||||||
|
|
||||||
def get_sql(self, with_default=True):
|
def get_sql(self, with_default=True):
|
||||||
from .utils import escape
|
|
||||||
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
|
values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
|
||||||
sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
|
sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
|
||||||
if with_default:
|
if with_default:
|
||||||
default = self.get_db_prep_value(self.default)
|
default = self.to_db_string(self.default)
|
||||||
sql = '%s DEFAULT %s' % (sql, escape(default))
|
sql = '%s DEFAULT %s' % (sql, default)
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -262,3 +263,31 @@ class Enum16Field(BaseEnumField):
|
||||||
db_type = 'Enum16'
|
db_type = 'Enum16'
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayField(Field):
    '''
    A field holding a list of values. Conversion, validation and
    serialization of each element are delegated to a single inner field,
    and the column type is rendered as Array(<inner field type>).
    '''

    class_default = []

    def __init__(self, inner_field, default=None):
        '''
        inner_field - the field object applied to every element of the array.
                      Its methods are called on it directly, so it has to be
                      a field instance, e.g. ArrayField(StringField()).
        default     - passed on unchanged to the base class initializer.
        '''
        # Stored before calling the base initializer - presumably because the
        # base class may already need it to process the default value
        # (NOTE(review): confirm against Field.__init__).
        self.inner_field = inner_field
        super(ArrayField, self).__init__(default)

    def to_python(self, value):
        '''
        Converts the value to a list whose elements were each converted
        by the inner field. Text is parsed with parse_array, bytes are decoded
        as UTF-8 and then parsed, lists and tuples are used as is.
        Raises ValueError for any other type.
        '''
        if isinstance(value, text_type):
            value = parse_array(value)
        elif isinstance(value, binary_type):
            value = parse_array(value.decode('UTF-8'))
        elif not isinstance(value, (list, tuple)):
            raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
        return [self.inner_field.to_python(v) for v in value]

    def validate(self, value):
        '''
        Validates each element of the array using the inner field.
        '''
        for v in value:
            self.inner_field.validate(v)

    def to_db_string(self, value, quote=True):
        '''
        Returns the array formatted as "[item, item, ...]".
        The quote argument is not consulted: the elements are always
        serialized by the inner field with quote=True, and the array
        as a whole is never surrounded by quotes.
        '''
        array = [self.inner_field.to_db_string(v, quote=True) for v in value]
        return '[' + ', '.join(array) + ']'

    def get_sql(self, with_default=True):
        '''
        Returns an SQL expression describing the field (e.g. for CREATE TABLE).
        The with_default argument is accepted for compatibility with the base
        class signature but is not used - no DEFAULT clause is produced, neither
        for the array nor for its inner type.
        '''
        return 'Array(%s)' % self.inner_field.get_sql(with_default=False)
|
||||||
|
|
|
@ -4,6 +4,9 @@ from .fields import Field
|
||||||
|
|
||||||
from six import with_metaclass
|
from six import with_metaclass
|
||||||
|
|
||||||
|
from logging import getLogger
|
||||||
|
logger = getLogger('clickhouse_orm')
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(type):
|
class ModelBase(type):
|
||||||
'''
|
'''
|
||||||
|
@ -28,7 +31,6 @@ class ModelBase(type):
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_ad_hoc_model(cls, fields):
|
def create_ad_hoc_model(cls, fields):
|
||||||
# fields is a list of tuples (name, db_type)
|
# fields is a list of tuples (name, db_type)
|
||||||
import infi.clickhouse_orm.fields as orm_fields
|
|
||||||
# Check if model exists in cache
|
# Check if model exists in cache
|
||||||
fields = list(fields)
|
fields = list(fields)
|
||||||
cache_key = str(fields)
|
cache_key = str(fields)
|
||||||
|
@ -37,18 +39,28 @@ class ModelBase(type):
|
||||||
# Create an ad hoc model class
|
# Create an ad hoc model class
|
||||||
attrs = {}
|
attrs = {}
|
||||||
for name, db_type in fields:
|
for name, db_type in fields:
|
||||||
if db_type.startswith('Enum'):
|
attrs[name] = cls.create_ad_hoc_field(db_type)
|
||||||
attrs[name] = orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
|
|
||||||
else:
|
|
||||||
field_class = db_type + 'Field'
|
|
||||||
if not hasattr(orm_fields, field_class):
|
|
||||||
raise NotImplementedError('No field class for %s' % db_type)
|
|
||||||
attrs[name] = getattr(orm_fields, field_class)()
|
|
||||||
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
|
model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
|
||||||
# Add the model class to the cache
|
# Add the model class to the cache
|
||||||
cls.ad_hoc_model_cache[cache_key] = model_class
|
cls.ad_hoc_model_cache[cache_key] = model_class
|
||||||
return model_class
|
return model_class
|
||||||
|
|
||||||
|
@classmethod
def create_ad_hoc_field(cls, db_type):
    '''
    Builds a field object for the given database type name.
    Types starting with "Enum" are handed to BaseEnumField.create_ad_hoc_field,
    types starting with "Array" become an ArrayField wrapping a field built
    recursively from the text between "Array(" and the last character, and
    anything else is looked up in the fields module as <db_type>Field.
    Raises NotImplementedError when no such field class exists.
    '''
    import infi.clickhouse_orm.fields as orm_fields
    if db_type.startswith('Enum'):
        return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
    elif db_type.startswith('Array'):
        # Strip the leading "Array(" (6 characters) and the trailing character
        element_type = db_type[6:-1]
        return orm_fields.ArrayField(cls.create_ad_hoc_field(element_type))
    # Not an enum or an array - expect a field class named after the type
    class_name = db_type + 'Field'
    if hasattr(orm_fields, class_name):
        field_class = getattr(orm_fields, class_name)
        return field_class()
    raise NotImplementedError('No field class for %s' % db_type)
|
||||||
|
|
||||||
|
|
||||||
class Model(with_metaclass(ModelBase)):
|
class Model(with_metaclass(ModelBase)):
|
||||||
'''
|
'''
|
||||||
|
@ -144,6 +156,8 @@ class Model(with_metaclass(ModelBase)):
|
||||||
'''
|
'''
|
||||||
parts = []
|
parts = []
|
||||||
for name, field in self._fields:
|
for name, field in self._fields:
|
||||||
value = field.get_db_prep_value(field.to_python(getattr(self, name)))
|
value = field.to_db_string(getattr(self, name), quote=False)
|
||||||
parts.append(escape(value, quote=False))
|
parts.append(value)
|
||||||
return '\t'.join(parts)
|
tsv = '\t'.join(parts)
|
||||||
|
logger.debug(tsv)
|
||||||
|
return tsv
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from six import string_types, binary_type, text_type, PY3
|
from six import string_types, binary_type, text_type, PY3
|
||||||
import codecs
|
import codecs
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
SPECIAL_CHARS = {
|
SPECIAL_CHARS = {
|
||||||
|
@ -15,6 +16,11 @@ SPECIAL_CHARS = {
|
||||||
|
|
||||||
|
|
||||||
def escape(value, quote=True):
|
def escape(value, quote=True):
|
||||||
|
'''
|
||||||
|
If the value is a string, escapes any special characters and optionally
|
||||||
|
surrounds it with single quotes. If the value is not a string (e.g. a number),
|
||||||
|
converts it to one.
|
||||||
|
'''
|
||||||
if isinstance(value, string_types):
|
if isinstance(value, string_types):
|
||||||
chars = (SPECIAL_CHARS.get(c, c) for c in value)
|
chars = (SPECIAL_CHARS.get(c, c) for c in value)
|
||||||
value = "'" + "".join(chars) + "'" if quote else "".join(chars)
|
value = "'" + "".join(chars) + "'" if quote else "".join(chars)
|
||||||
|
@ -33,6 +39,40 @@ def parse_tsv(line):
|
||||||
return [unescape(value) for value in line.split('\t')]
|
return [unescape(value) for value in line.split('\t')]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_array(array_string):
    """
    Parse an array string as returned by clickhouse, and return its items
    as a list of strings. For example:
        "['hello', 'world']" ==> ["hello", "world"]
        "[1,2,3]" ==> ["1", "2", "3"]
    Quoted items are returned without their surrounding quotes, and their
    contents are not unescaped. Non-quoted items are returned as they appear
    in the string - converting them to other types is up to the caller.
    Nested arrays are not handled.
    Raises ValueError if the string is not surrounded by square brackets,
    or if a quoted item has no closing quote.
    """
    # Sanity check
    if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']':
        raise ValueError('Invalid array string: "%s"' % array_string)
    # Drop opening brace
    array_string = array_string[1:]
    # Go over the string, lopping off each value at the beginning until nothing is left
    values = []
    while True:
        if array_string == ']':
            # End of array
            return values
        elif array_string[0] in ', ':
            # In between values
            array_string = array_string[1:]
        elif array_string[0] == "'":
            # Start of quoted value, find its end: the first quote
            # that is not preceded by a backslash
            match = re.search(r"[^\\]'", array_string)
            if match is None:
                raise ValueError('Missing closing quote: "%s"' % array_string)
            values.append(array_string[1 : match.start() + 1])
            array_string = array_string[match.end():]
        else:
            # Start of non-quoted value, find its end
            match = re.search(r",|\]", array_string)
            if match is None:
                # Not expected, since the string was verified to end with a brace
                raise ValueError('Invalid array string: "%s"' % array_string)
            # The value is everything before the delimiter. The delimiter itself
            # is left in place, so that a closing brace is still seen by the
            # end-of-array check above and a comma gets skipped as usual.
            values.append(array_string[:match.start()])
            array_string = array_string[match.start():]
|
||||||
|
|
||||||
|
|
||||||
def import_submodules(package_name):
|
def import_submodules(package_name):
|
||||||
"""
|
"""
|
||||||
Import all submodules of a module.
|
Import all submodules of a module.
|
||||||
|
|
58
tests/test_array_fields.py
Normal file
58
tests/test_array_fields.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
import unittest
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from infi.clickhouse_orm.database import Database
|
||||||
|
from infi.clickhouse_orm.models import Model
|
||||||
|
from infi.clickhouse_orm.fields import *
|
||||||
|
from infi.clickhouse_orm.engines import *
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFieldsTest(unittest.TestCase):
    '''
    Tests for models with array fields: a round trip through the database,
    conversion of assigned values, and rejection of invalid values.
    '''

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithArrays)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        # arr_int is deliberately not given, so it keeps its default value
        instance = ModelWithArrays(
            date_field='2016-08-30',
            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
            arr_date=['2010-01-01']
        )
        self.database.insert([instance])
        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
        # Select once with the model class and once without specifying any class
        for model_cls in (ModelWithArrays, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0].arr_str, instance.arr_str)
            self.assertEqual(results[0].arr_int, instance.arr_int)
            self.assertEqual(results[0].arr_date, instance.arr_date)

    def test_conversion(self):
        # A tuple of numeric strings and a list of date strings are expected
        # to be converted to lists of ints and dates respectively
        instance = ModelWithArrays(
            arr_int=('1', '2', '3'),
            arr_date=['2010-01-01']
        )
        self.assertEqual(instance.arr_str, [])
        self.assertEqual(instance.arr_int, [1, 2, 3])
        self.assertEqual(instance.arr_date, [date(2010, 1, 1)])

    def test_assignment_error(self):
        # Neither non-array values nor arrays with non-integer items
        # may be assigned to an array of integers
        instance = ModelWithArrays()
        for value in (7, 'x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.arr_int = value
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithArrays(Model):
    '''
    Model used by ArrayFieldsTest: a date field plus arrays of
    strings, 32-bit integers and dates.
    '''

    date_field = DateField()
    arr_str = ArrayField(StringField())
    arr_int = ArrayField(Int32Field())
    arr_date = ArrayField(DateField())

    engine = MergeTree('date_field', ('date_field',))
|
||||||
|
|
|
@ -16,6 +16,7 @@ class EnumFieldsTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.database = Database('test-db')
|
self.database = Database('test-db')
|
||||||
self.database.create_table(ModelWithEnum)
|
self.database.create_table(ModelWithEnum)
|
||||||
|
self.database.create_table(ModelWithEnumArray)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.database.drop_database()
|
self.database.drop_database()
|
||||||
|
@ -39,7 +40,9 @@ class EnumFieldsTest(unittest.TestCase):
|
||||||
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
|
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
|
||||||
results = list(self.database.select(query))
|
results = list(self.database.select(query))
|
||||||
self.assertEquals(len(results), 2)
|
self.assertEquals(len(results), 2)
|
||||||
|
self.assertEquals(results[0].enum_field.name, Fruit.apple.name)
|
||||||
self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
|
self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
|
||||||
|
self.assertEquals(results[1].enum_field.name, Fruit.orange.name)
|
||||||
self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
|
self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
|
||||||
|
|
||||||
def test_conversion(self):
|
def test_conversion(self):
|
||||||
|
@ -56,6 +59,14 @@ class EnumFieldsTest(unittest.TestCase):
|
||||||
instance = ModelWithEnum()
|
instance = ModelWithEnum()
|
||||||
self.assertEquals(instance.enum_field, Fruit.apple)
|
self.assertEquals(instance.enum_field, Fruit.apple)
|
||||||
|
|
||||||
|
def test_enum_array(self):
    # Insert a row holding an array of enum members, and verify that
    # the same members are read back
    instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
    self.database.insert([instance])
    query = 'SELECT * from $table ORDER BY date_field'
    results = list(self.database.select(query, ModelWithEnumArray))
    self.assertEqual(len(results), 1)
    self.assertEqual(results[0].enum_array, instance.enum_array)
|
||||||
|
|
||||||
|
|
||||||
Fruit = Enum('Fruit', u'apple banana orange')
|
Fruit = Enum('Fruit', u'apple banana orange')
|
||||||
|
|
||||||
|
@ -67,3 +78,10 @@ class ModelWithEnum(Model):
|
||||||
|
|
||||||
engine = MergeTree('date_field', ('date_field',))
|
engine = MergeTree('date_field', ('date_field',))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithEnumArray(Model):
    '''
    Model used by test_enum_array: a date field plus an array
    of Enum16 values taken from the Fruit enum.
    '''

    date_field = DateField()
    enum_array = ArrayField(Enum16Field(Fruit))

    engine = MergeTree('date_field', ('date_field',))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user