Chore: blacken tests

olliemath 2021-07-27 23:19:15 +01:00
parent f2eb81371a
commit 9ccaacc640
49 changed files with 1326 additions and 1139 deletions
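For context, "blacken" means running the Black code formatter over the test suite: it normalizes string quotes to double quotes, re-wraps long call arguments, and adds trailing commas without changing behavior. A minimal sketch of that kind of rewrite (not part of the commit; the programmatic black.format_str call and the 120-character line length are assumptions inferred from the diff below):

import black

# One line in the style of the changes below: single quotes become double quotes.
src = "self.database = Database('test-db', log_statements=True)\n"
print(black.format_str(src, mode=black.Mode(line_length=120)))
# -> self.database = Database("test-db", log_statements=True)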

View File

@@ -7,13 +7,13 @@ from clickhouse_orm.fields import *
 from clickhouse_orm.engines import *
 import logging
 logging.getLogger("requests").setLevel(logging.WARNING)
 class TestCaseWithData(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         self.database.create_table(Person)
     def tearDown(self):
@@ -35,7 +35,6 @@ class TestCaseWithData(unittest.TestCase):
             yield Person(**entry)
 class Person(Model):
     first_name = StringField()
@@ -44,16 +43,12 @@ class Person(Model):
     height = Float32Field()
     passport = NullableField(UInt32Field())
-    engine = MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
+    engine = MergeTree("birthday", ("first_name", "last_name", "birthday"))
 data = [
-    {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63",
-     "passport": 35052255},
-    {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74",
-     "passport": 36052255},
+    {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", "passport": 35052255},
+    {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", "passport": 36052255},
     {"first_name": "Adena", "last_name": "Norman", "birthday": "1979-05-14", "height": "1.66"},
     {"first_name": "Aline", "last_name": "Crane", "birthday": "1988-05-01", "height": "1.62"},
     {"first_name": "Althea", "last_name": "Barrett", "birthday": "2004-07-28", "height": "1.71"},
@@ -151,5 +146,5 @@ data = [
     {"first_name": "Whitney", "last_name": "Durham", "birthday": "1977-09-15", "height": "1.72"},
     {"first_name": "Whitney", "last_name": "Scott", "birthday": "1971-07-04", "height": "1.70"},
     {"first_name": "Wynter", "last_name": "Garcia", "birthday": "1975-01-10", "height": "1.69"},
-    {"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"}
+    {"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"},
 ]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(Model1)
-]
+operations = [migrations.CreateTable(Model1)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.DropTable(Model1)
-]
+operations = [migrations.DropTable(Model1)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(Model1)
-]
+operations = [migrations.CreateTable(Model1)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTable(Model2)
-]
+operations = [migrations.AlterTable(Model2)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTable(Model3)
-]
+operations = [migrations.AlterTable(Model3)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(EnumModel1)
-]
+operations = [migrations.CreateTable(EnumModel1)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTable(EnumModel2)
-]
+operations = [migrations.AlterTable(EnumModel2)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(MaterializedModel)
-]
+operations = [migrations.CreateTable(MaterializedModel)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(AliasModel)
-]
+operations = [migrations.CreateTable(AliasModel)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(Model4Buffer)
-]
+operations = [migrations.CreateTable(Model4Buffer)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTableWithBuffer(Model4Buffer_changed)
-]
+operations = [migrations.AlterTableWithBuffer(Model4Buffer_changed)]

View File

@@ -2,8 +2,10 @@ from clickhouse_orm import migrations
 operations = [
     migrations.RunSQL("INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-01', 1, 1, 'test') "),
-    migrations.RunSQL([
-        "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-02', 2, 2, 'test2') ",
-        "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-03', 3, 3, 'test3') ",
-    ])
+    migrations.RunSQL(
+        [
+            "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-02', 2, 2, 'test2') ",
+            "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-03', 3, 3, 'test3') ",
+        ]
+    ),
 ]

View File

@@ -5,11 +5,7 @@ from test_migrations import Model3
 def forward(database):
-    database.insert([
-        Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4='test4')
-    ])
+    database.insert([Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4="test4")])
 
-operations = [
-    migrations.RunPython(forward)
-]
+operations = [migrations.RunPython(forward)]

View File

@@ -1,7 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTable(MaterializedModel1),
-    migrations.AlterTable(AliasModel1)
-]
+operations = [migrations.AlterTable(MaterializedModel1), migrations.AlterTable(AliasModel1)]

View File

@@ -1,7 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterTable(Model4_compressed),
-    migrations.AlterTable(Model2LowCardinality)
-]
+operations = [migrations.AlterTable(Model4_compressed), migrations.AlterTable(Model2LowCardinality)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(ModelWithConstraints)
-]
+operations = [migrations.CreateTable(ModelWithConstraints)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterConstraints(ModelWithConstraints2)
-]
+operations = [migrations.AlterConstraints(ModelWithConstraints2)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.CreateTable(ModelWithIndex)
-]
+operations = [migrations.CreateTable(ModelWithIndex)]

View File

@@ -1,6 +1,4 @@
 from clickhouse_orm import migrations
 from ..test_migrations import *
-operations = [
-    migrations.AlterIndexes(ModelWithIndex2, reindex=True)
-]
+operations = [migrations.AlterIndexes(ModelWithIndex2, reindex=True)]

View File

@@ -9,24 +9,21 @@ from clickhouse_orm.funcs import F
 class AliasFieldsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         self.database.create_table(ModelWithAliasFields)
     def tearDown(self):
         self.database.drop_database()
     def test_insert_and_select(self):
-        instance = ModelWithAliasFields(
-            date_field='2016-08-30',
-            int_field=-10,
-            str_field='TEST'
-        )
+        instance = ModelWithAliasFields(date_field="2016-08-30", int_field=-10, str_field="TEST")
         self.database.insert([instance])
         # We can't select * from table, as it doesn't select materialized and alias fields
-        query = 'SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func' \
-                ' FROM $db.%s ORDER BY alias_date' % ModelWithAliasFields.table_name()
+        query = (
+            "SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func"
+            " FROM $db.%s ORDER BY alias_date" % ModelWithAliasFields.table_name()
+        )
         for model_cls in (ModelWithAliasFields, None):
             results = list(self.database.select(query, model_cls))
             self.assertEqual(len(results), 1)
@@ -41,7 +38,7 @@ class AliasFieldsTest(unittest.TestCase):
     def test_assignment_error(self):
         # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
         instance = ModelWithAliasFields()
-        for value in ('x', [date.today()], ['aaa'], [None]):
+        for value in ("x", [date.today()], ["aaa"], [None]):
             with self.assertRaises(ValueError):
                 instance.alias_date = value
@@ -51,10 +48,10 @@ class AliasFieldsTest(unittest.TestCase):
     def test_duplicate_default(self):
         with self.assertRaises(AssertionError):
-            StringField(alias='str_field', default='with default')
+            StringField(alias="str_field", default="with default")
         with self.assertRaises(AssertionError):
-            StringField(alias='str_field', materialized='str_field')
+            StringField(alias="str_field", materialized="str_field")
     def test_default_value(self):
         instance = ModelWithAliasFields()
@@ -70,9 +67,9 @@ class ModelWithAliasFields(Model):
     date_field = DateField()
     str_field = StringField()
-    alias_str = StringField(alias=u'str_field')
-    alias_int = Int32Field(alias='int_field')
-    alias_date = DateField(alias='date_field')
+    alias_str = StringField(alias=u"str_field")
+    alias_int = Int32Field(alias="int_field")
+    alias_date = DateField(alias="date_field")
     alias_func = Int32Field(alias=F.toYYYYMM(date_field))
-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))

View File

@@ -8,9 +8,8 @@ from clickhouse_orm.engines import *
 class ArrayFieldsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         self.database.create_table(ModelWithArrays)
     def tearDown(self):
@@ -18,12 +17,12 @@ class ArrayFieldsTest(unittest.TestCase):
     def test_insert_and_select(self):
         instance = ModelWithArrays(
-            date_field='2016-08-30',
-            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
-            arr_date=['2010-01-01'],
+            date_field="2016-08-30",
+            arr_str=["goodbye,", "cruel", "world", "special chars: ,\"\\'` \n\t\\[]"],
+            arr_date=["2010-01-01"],
         )
         self.database.insert([instance])
-        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
+        query = "SELECT * from $db.modelwitharrays ORDER BY date_field"
         for model_cls in (ModelWithArrays, None):
             results = list(self.database.select(query, model_cls))
             self.assertEqual(len(results), 1)
@@ -32,32 +31,25 @@ class ArrayFieldsTest(unittest.TestCase):
         self.assertEqual(results[0].arr_date, instance.arr_date)
     def test_conversion(self):
-        instance = ModelWithArrays(
-            arr_int=('1', '2', '3'),
-            arr_date=['2010-01-01']
-        )
+        instance = ModelWithArrays(arr_int=("1", "2", "3"), arr_date=["2010-01-01"])
         self.assertEqual(instance.arr_str, [])
         self.assertEqual(instance.arr_int, [1, 2, 3])
         self.assertEqual(instance.arr_date, [date(2010, 1, 1)])
     def test_assignment_error(self):
         instance = ModelWithArrays()
-        for value in (7, 'x', [date.today()], ['aaa'], [None]):
+        for value in (7, "x", [date.today()], ["aaa"], [None]):
             with self.assertRaises(ValueError):
                 instance.arr_int = value
     def test_parse_array(self):
         from clickhouse_orm.utils import parse_array, unescape
         self.assertEqual(parse_array("[]"), [])
         self.assertEqual(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
         self.assertEqual(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
         self.assertEqual(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
-        for s in ("",
-                  "[",
-                  "]",
-                  "[1, 2",
-                  "3, 4]",
-                  "['aaa', 'aaa]"):
+        for s in ("", "[", "]", "[1, 2", "3, 4]", "['aaa', 'aaa]"):
             with self.assertRaises(ValueError):
                 parse_array(s)
@@ -74,4 +66,4 @@ class ModelWithArrays(Model):
     arr_int = ArrayField(Int32Field())
     arr_date = ArrayField(DateField())
-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))

View File

@@ -7,7 +7,6 @@ from .base_test_with_data import *
 class BufferTestCase(TestCaseWithData):
     def _insert_and_check_buffer(self, data, count):
         self.database.insert(data)
         self.assertEqual(count, self.database.count(PersonBuffer))

View File

@@ -10,9 +10,8 @@ from clickhouse_orm.utils import parse_tsv
 class CompressedFieldsTestCase(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         self.database.create_table(CompressedModel)
     def tearDown(self):
@@ -24,7 +23,7 @@ class CompressedFieldsTestCase(unittest.TestCase):
         self.database.insert([instance])
         self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
         self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
-        self.assertEqual(instance.string_field, 'dozo')
+        self.assertEqual(instance.string_field, "dozo")
         self.assertEqual(instance.int64_field, 42)
         self.assertEqual(instance.float_field, 0)
         self.assertEqual(instance.nullable_field, None)
@@ -36,11 +35,11 @@ class CompressedFieldsTestCase(unittest.TestCase):
             uint64_field=217,
             date_field=datetime.date(1973, 12, 6),
             datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
-            string_field='aloha',
+            string_field="aloha",
             int64_field=-50,
             float_field=3.14,
             nullable_field=-2.718281,
-            array_field=['123456789123456','','a']
+            array_field=["123456789123456", "", "a"],
         )
         instance = CompressedModel(**kwargs)
         self.database.insert([instance])
@@ -49,75 +48,91 @@ class CompressedFieldsTestCase(unittest.TestCase):
     def test_string_conversion(self):
         # Check field conversion from string during construction
-        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', nullable_field=None, array_field='[a,b,c]')
+        instance = CompressedModel(
+            date_field="1973-12-06", int64_field="100", float_field="7", nullable_field=None, array_field="[a,b,c]"
+        )
         self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
         self.assertEqual(instance.int64_field, 100)
         self.assertEqual(instance.float_field, 7)
         self.assertEqual(instance.nullable_field, None)
-        self.assertEqual(instance.array_field, ['a', 'b', 'c'])
+        self.assertEqual(instance.array_field, ["a", "b", "c"])
         # Check field conversion from string during assignment
-        instance.int64_field = '99'
+        instance.int64_field = "99"
         self.assertEqual(instance.int64_field, 99)
     def test_to_dict(self):
-        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
-        self.assertDictEqual(instance.to_dict(), {
-            "date_field": datetime.date(1973, 12, 6),
-            "int64_field": 100,
-            "float_field": 7.0,
-            "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
-            "alias_field": NO_VALUE,
-            'string_field': 'dozo',
-            'nullable_field': None,
-            'uint64_field': 0,
-            'array_field': ['a','b','c']
-        })
-        self.assertDictEqual(instance.to_dict(include_readonly=False), {
-            "date_field": datetime.date(1973, 12, 6),
-            "int64_field": 100,
-            "float_field": 7.0,
-            "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
-            'string_field': 'dozo',
-            'nullable_field': None,
-            'uint64_field': 0,
-            'array_field': ['a', 'b', 'c']
-        })
-        self.assertDictEqual(
-            instance.to_dict(include_readonly=False, field_names=('int64_field', 'alias_field', 'datetime_field')), {
-                "int64_field": 100,
-                "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
-            })
+        instance = CompressedModel(date_field="1973-12-06", int64_field="100", float_field="7", array_field="[a,b,c]")
+        self.assertDictEqual(
+            instance.to_dict(),
+            {
+                "date_field": datetime.date(1973, 12, 6),
+                "int64_field": 100,
+                "float_field": 7.0,
+                "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+                "alias_field": NO_VALUE,
+                "string_field": "dozo",
+                "nullable_field": None,
+                "uint64_field": 0,
+                "array_field": ["a", "b", "c"],
+            },
+        )
+        self.assertDictEqual(
+            instance.to_dict(include_readonly=False),
+            {
+                "date_field": datetime.date(1973, 12, 6),
+                "int64_field": 100,
+                "float_field": 7.0,
+                "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+                "string_field": "dozo",
+                "nullable_field": None,
+                "uint64_field": 0,
+                "array_field": ["a", "b", "c"],
+            },
+        )
+        self.assertDictEqual(
+            instance.to_dict(include_readonly=False, field_names=("int64_field", "alias_field", "datetime_field")),
+            {"int64_field": 100, "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc)},
+        )
     def test_confirm_compression_codec(self):
         if self.database.server_version < (19, 17):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
-        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
+        instance = CompressedModel(date_field="1973-12-06", int64_field="100", float_field="7", array_field="[a,b,c]")
         self.database.insert([instance])
-        r = self.database.raw("select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(instance.table_name(), self.database.db_name))
+        r = self.database.raw(
+            "select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(
+                instance.table_name(), self.database.db_name
+            )
+        )
         lines = r.splitlines()
         field_names = parse_tsv(lines[0])
         field_types = parse_tsv(lines[1])
         data = [tuple(parse_tsv(line)) for line in lines[2:]]
-        self.assertListEqual(data, [('uint64_field', 'CODEC(ZSTD(10))'),
-                                    ('datetime_field', 'CODEC(Delta(4), ZSTD(1))'),
-                                    ('date_field', 'CODEC(Delta(4), ZSTD(22))'),
-                                    ('int64_field', 'CODEC(LZ4)'),
-                                    ('string_field', 'CODEC(LZ4HC(10))'),
-                                    ('nullable_field', 'CODEC(ZSTD(1))'),
-                                    ('array_field', 'CODEC(Delta(2), LZ4HC(0))'),
-                                    ('float_field', 'CODEC(NONE)'),
-                                    ('alias_field', 'CODEC(ZSTD(4))')])
+        self.assertListEqual(
+            data,
+            [
+                ("uint64_field", "CODEC(ZSTD(10))"),
+                ("datetime_field", "CODEC(Delta(4), ZSTD(1))"),
+                ("date_field", "CODEC(Delta(4), ZSTD(22))"),
+                ("int64_field", "CODEC(LZ4)"),
+                ("string_field", "CODEC(LZ4HC(10))"),
+                ("nullable_field", "CODEC(ZSTD(1))"),
+                ("array_field", "CODEC(Delta(2), LZ4HC(0))"),
+                ("float_field", "CODEC(NONE)"),
+                ("alias_field", "CODEC(ZSTD(4))"),
+            ],
+        )
 class CompressedModel(Model):
-    uint64_field = UInt64Field(codec='ZSTD(10)')
-    datetime_field = DateTimeField(codec='Delta,ZSTD')
-    date_field = DateField(codec='Delta(4),ZSTD(22)')
-    int64_field = Int64Field(default=42, codec='LZ4')
-    string_field = StringField(default='dozo', codec='LZ4HC(10)')
-    nullable_field = NullableField(Float32Field(), codec='ZSTD')
-    array_field = ArrayField(FixedStringField(length=15), codec='Delta(2),LZ4HC')
-    float_field = Float32Field(codec='NONE')
-    alias_field = Float32Field(alias='float_field', codec='ZSTD(4)')
-    engine = MergeTree('datetime_field', ('uint64_field', 'datetime_field'))
+    uint64_field = UInt64Field(codec="ZSTD(10)")
+    datetime_field = DateTimeField(codec="Delta,ZSTD")
+    date_field = DateField(codec="Delta(4),ZSTD(22)")
+    int64_field = Int64Field(default=42, codec="LZ4")
+    string_field = StringField(default="dozo", codec="LZ4HC(10)")
+    nullable_field = NullableField(Float32Field(), codec="ZSTD")
+    array_field = ArrayField(FixedStringField(length=15), codec="Delta(2),LZ4HC")
+    float_field = Float32Field(codec="NONE")
+    alias_field = Float32Field(alias="float_field", codec="ZSTD(4)")
+    engine = MergeTree("datetime_field", ("uint64_field", "datetime_field"))

View File

@@ -5,40 +5,37 @@ from .base_test_with_data import Person
 class ConstraintsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         if self.database.server_version < (19, 14, 3, 3):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
         self.database.create_table(PersonWithConstraints)
     def tearDown(self):
         self.database.drop_database()
     def test_insert_valid_values(self):
-        self.database.insert([
-            PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)
-        ])
+        self.database.insert(
+            [PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)]
+        )
     def test_insert_invalid_values(self):
         with self.assertRaises(ServerError) as e:
-            self.database.insert([
-                PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)
-            ])
+            self.database.insert(
+                [PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)]
+            )
         self.assertEqual(e.code, 469)
-        self.assertTrue('Constraint `birthday_in_the_past`' in e.message)
+        self.assertTrue("Constraint `birthday_in_the_past`" in e.message)
         with self.assertRaises(ServerError) as e:
-            self.database.insert([
-                PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)
-            ])
+            self.database.insert(
+                [PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)]
+            )
         self.assertEqual(e.code, 469)
-        self.assertTrue('Constraint `max_height`' in e.message)
+        self.assertTrue("Constraint `max_height`" in e.message)
 class PersonWithConstraints(Person):
     birthday_in_the_past = Constraint(Person.birthday <= F.today())
     max_height = Constraint(Person.height <= 2.75)

View File

@@ -6,9 +6,8 @@ from clickhouse_orm.engines import Memory
 class CustomFieldsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
     def tearDown(self):
         self.database.drop_database()
@@ -19,15 +18,18 @@ class CustomFieldsTest(unittest.TestCase):
             i = Int16Field()
             f = BooleanField()
             engine = Memory()
         self.database.create_table(TestModel)
         # Check valid values
-        for index, value in enumerate([1, '1', True, 0, '0', False]):
+        for index, value in enumerate([1, "1", True, 0, "0", False]):
             rec = TestModel(i=index, f=value)
             self.database.insert([rec])
-        self.assertEqual([rec.f for rec in TestModel.objects_in(self.database).order_by('i')],
-                         [True, True, True, False, False, False])
+        self.assertEqual(
+            [rec.f for rec in TestModel.objects_in(self.database).order_by("i")],
+            [True, True, True, False, False, False],
+        )
         # Check invalid values
-        for value in [None, 'zzz', -5, 7]:
+        for value in [None, "zzz", -5, 7]:
             with self.assertRaises(ValueError):
                 TestModel(i=1, f=value)
@@ -35,21 +37,20 @@ class CustomFieldsTest(unittest.TestCase):
 class BooleanField(Field):
     # The ClickHouse column type to use
-    db_type = 'UInt8'
+    db_type = "UInt8"
     # The default value if empty
     class_default = False
     def to_python(self, value, timezone_in_use):
         # Convert valid values to bool
-        if value in (1, '1', True):
+        if value in (1, "1", True):
             return True
-        elif value in (0, '0', False):
+        elif value in (0, "0", False):
             return False
         else:
-            raise ValueError('Invalid value for BooleanField: %r' % value)
+            raise ValueError("Invalid value for BooleanField: %r" % value)
     def to_db_string(self, value, quote=True):
         # The value was already converted by to_python, so it's a bool
-        return '1' if value else '0'
+        return "1" if value else "0"

View File

@@ -12,7 +12,6 @@ from .base_test_with_data import *
 class DatabaseTestCase(TestCaseWithData):
     def test_insert__generator(self):
         self._insert_and_check(self._sample_data(), len(data))
@@ -33,17 +32,19 @@ class DatabaseTestCase(TestCaseWithData):
     def test_insert__funcs_as_default_values(self):
         if self.database.server_version < (20, 1, 2, 4):
-            raise unittest.SkipTest('Buggy in server versions before 20.1.2.4')
+            raise unittest.SkipTest("Buggy in server versions before 20.1.2.4")
         class TestModel(Model):
             a = DateTimeField(default=datetime.datetime(2020, 1, 1))
             b = DateField(default=F.toDate(a))
             c = Int32Field(default=7)
             d = Int32Field(default=c * 5)
             engine = Memory()
         self.database.create_table(TestModel)
         self.database.insert([TestModel()])
         t = TestModel.objects_in(self.database)[0]
-        self.assertEqual(str(t.b), '2020-01-01')
+        self.assertEqual(str(t.b), "2020-01-01")
         self.assertEqual(t.d, 35)
     def test_count(self):
@@ -63,9 +64,9 @@ class DatabaseTestCase(TestCaseWithData):
         query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
         results = list(self.database.select(query, Person))
         self.assertEqual(len(results), 2)
-        self.assertEqual(results[0].last_name, 'Durham')
+        self.assertEqual(results[0].last_name, "Durham")
         self.assertEqual(results[0].height, 1.72)
-        self.assertEqual(results[1].last_name, 'Scott')
+        self.assertEqual(results[1].last_name, "Scott")
         self.assertEqual(results[1].height, 1.70)
         self.assertEqual(results[0].get_database(), self.database)
         self.assertEqual(results[1].get_database(), self.database)
@@ -79,10 +80,10 @@ class DatabaseTestCase(TestCaseWithData):
         query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
         results = list(self.database.select(query, Person))
         self.assertEqual(len(results), 2)
-        self.assertEqual(results[0].last_name, 'Durham')
+        self.assertEqual(results[0].last_name, "Durham")
         self.assertEqual(results[0].height, 0)  # default value
-        self.assertEqual(results[1].last_name, 'Scott')
+        self.assertEqual(results[1].last_name, "Scott")
         self.assertEqual(results[1].height, 0)  # default value
         self.assertEqual(results[0].get_database(), self.database)
         self.assertEqual(results[1].get_database(), self.database)
@@ -91,10 +92,10 @@ class DatabaseTestCase(TestCaseWithData):
         query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
         results = list(self.database.select(query))
         self.assertEqual(len(results), 2)
-        self.assertEqual(results[0].__class__.__name__, 'AdHocModel')
-        self.assertEqual(results[0].last_name, 'Durham')
+        self.assertEqual(results[0].__class__.__name__, "AdHocModel")
+        self.assertEqual(results[0].last_name, "Durham")
         self.assertEqual(results[0].height, 1.72)
-        self.assertEqual(results[1].last_name, 'Scott')
+        self.assertEqual(results[1].last_name, "Scott")
         self.assertEqual(results[1].height, 1.70)
         self.assertEqual(results[0].get_database(), self.database)
         self.assertEqual(results[1].get_database(), self.database)
@@ -116,7 +117,7 @@ class DatabaseTestCase(TestCaseWithData):
         page_num = 1
         instances = set()
         while True:
-            page = self.database.paginate(Person, 'first_name, last_name', page_num, page_size)
+            page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
             self.assertEqual(page.number_of_objects, len(data))
             self.assertGreater(page.pages_total, 0)
             [instances.add(obj.to_tsv()) for obj in page.objects]
@@ -131,15 +132,16 @@ class DatabaseTestCase(TestCaseWithData):
         # Try different page sizes
         for page_size in (1, 2, 7, 10, 30, 100, 150):
             # Ask for the last page in two different ways and verify equality
-            page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
-            page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
+            page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
+            page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
             self.assertEqual(page_a[1:], page_b[1:])
-            self.assertEqual([obj.to_tsv() for obj in page_a.objects],
-                             [obj.to_tsv() for obj in page_b.objects])
+            self.assertEqual([obj.to_tsv() for obj in page_a.objects], [obj.to_tsv() for obj in page_b.objects])
     def test_pagination_empty_page(self):
         for page_num in (-1, 1, 2):
-            page = self.database.paginate(Person, 'first_name, last_name', page_num, 10, conditions="first_name = 'Ziggy'")
+            page = self.database.paginate(
+                Person, "first_name, last_name", page_num, 10, conditions="first_name = 'Ziggy'"
+            )
             self.assertEqual(page.number_of_objects, 0)
             self.assertEqual(page.objects, [])
             self.assertEqual(page.pages_total, 0)
@@ -149,22 +151,22 @@ class DatabaseTestCase(TestCaseWithData):
         self._insert_and_check(self._sample_data(), len(data))
         for page_num in (0, -2, -100):
             with self.assertRaises(ValueError):
-                self.database.paginate(Person, 'first_name, last_name', page_num, 100)
+                self.database.paginate(Person, "first_name, last_name", page_num, 100)
     def test_pagination_with_conditions(self):
         self._insert_and_check(self._sample_data(), len(data))
         # Conditions as string
-        page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions="first_name < 'Ava'")
+        page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
         self.assertEqual(page.number_of_objects, 10)
         # Conditions as expression
-        page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Person.first_name < 'Ava')
+        page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Person.first_name < "Ava")
         self.assertEqual(page.number_of_objects, 10)
         # Conditions as Q object
-        page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Q(first_name__lt='Ava'))
+        page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
         self.assertEqual(page.number_of_objects, 10)
     def test_special_chars(self):
-        s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
+        s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
         p = Person(first_name=s)
         self.database.insert([p])
         p = list(self.database.select("SELECT * from $table", Person))[0]
@@ -178,18 +180,18 @@ class DatabaseTestCase(TestCaseWithData):
     def test_invalid_user(self):
         with self.assertRaises(ServerError) as cm:
-            Database(self.database.db_name, username='default', password='wrong')
+            Database(self.database.db_name, username="default", password="wrong")
         exc = cm.exception
         if exc.code == 193:  # ClickHouse version < 20.3
-            self.assertTrue(exc.message.startswith('Wrong password for user default'))
+            self.assertTrue(exc.message.startswith("Wrong password for user default"))
         elif exc.code == 516:  # ClickHouse version >= 20.3
-            self.assertTrue(exc.message.startswith('default: Authentication failed'))
+            self.assertTrue(exc.message.startswith("default: Authentication failed"))
         else:
-            raise Exception('Unexpected error code - %s' % exc.code)
+            raise Exception("Unexpected error code - %s" % exc.code)
     def test_nonexisting_db(self):
-        db = Database('db_not_here', autocreate=False)
+        db = Database("db_not_here", autocreate=False)
         with self.assertRaises(ServerError) as cm:
             db.create_table(Person)
         exc = cm.exception
@@ -212,25 +214,28 @@ class DatabaseTestCase(TestCaseWithData):
     def test_missing_engine(self):
         class EnginelessModel(Model):
             float_field = Float32Field()
         with self.assertRaises(DatabaseException) as cm:
             self.database.create_table(EnginelessModel)
-        self.assertEqual(str(cm.exception), 'EnginelessModel class must define an engine')
+        self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")
     def test_potentially_problematic_field_names(self):
         class Model1(Model):
             system = StringField()
             readonly = StringField()
             engine = Memory()
-        instance = Model1(system='s', readonly='r')
-        self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
+        instance = Model1(system="s", readonly="r")
+        self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))
         self.database.create_table(Model1)
         self.database.insert([instance])
         instance = Model1.objects_in(self.database)[0]
-        self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
+        self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))
     def test_does_table_exist(self):
         class Person2(Person):
             pass
         self.assertTrue(self.database.does_table_exist(Person))
         self.assertFalse(self.database.does_table_exist(Person2))
@@ -239,32 +244,31 @@ class DatabaseTestCase(TestCaseWithData):
         with self.assertRaises(AssertionError):
             self.database.add_setting(0, 1)
         # Add a setting and see that it makes the query fail
-        self.database.add_setting('max_columns_to_read', 1)
+        self.database.add_setting("max_columns_to_read", 1)
         with self.assertRaises(ServerError):
-            list(self.database.select('SELECT * from system.tables'))
+            list(self.database.select("SELECT * from system.tables"))
         # Remove the setting and see that now it works
-        self.database.add_setting('max_columns_to_read', None)
-        list(self.database.select('SELECT * from system.tables'))
+        self.database.add_setting("max_columns_to_read", None)
+        list(self.database.select("SELECT * from system.tables"))
     def test_create_ad_hoc_field(self):
         # Tests that create_ad_hoc_field works for all column types in the database
         from clickhouse_orm.models import ModelBase
         query = "SELECT DISTINCT type FROM system.columns"
         for row in self.database.select(query):
             ModelBase.create_ad_hoc_field(row.type)
     def test_get_model_for_table(self):
         # Tests that get_model_for_table works for a non-system model
-        model = self.database.get_model_for_table('person')
+        model = self.database.get_model_for_table("person")
         self.assertFalse(model.is_system_model())
         self.assertFalse(model.is_read_only())
-        self.assertEqual(model.table_name(), 'person')
+        self.assertEqual(model.table_name(), "person")
         # Read a few records
         list(model.objects_in(self.database)[:10])
         # Inserts should work too
-        self.database.insert([
-            model(first_name='aaa', last_name='bbb', height=1.77)
-        ])
+        self.database.insert([model(first_name="aaa", last_name="bbb", height=1.77)])
     def test_get_model_for_table__system(self):
         # Tests that get_model_for_table works for all system tables
@@ -279,7 +283,7 @@ class DatabaseTestCase(TestCaseWithData):
             try:
                 list(model.objects_in(self.database)[:10])
             except ServerError as e:
-                if 'Not enough privileges' in e.message:
+                if "Not enough privileges" in e.message:
                     pass
                 else:
                     raise

View File

@@ -9,33 +9,35 @@ from clickhouse_orm.engines import *
 class DateFieldsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         if self.database.server_version < (20, 1, 2, 4):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
         self.database.create_table(ModelWithDate)
     def tearDown(self):
         self.database.drop_database()
     def test_ad_hoc_model(self):
-        self.database.insert([
-            ModelWithDate(
-                date_field='2016-08-30',
-                datetime_field='2016-08-30 03:50:00',
-                datetime64_field='2016-08-30 03:50:00.123456',
-                datetime64_3_field='2016-08-30 03:50:00.123456'
-            ),
-            ModelWithDate(
-                date_field='2016-08-31',
-                datetime_field='2016-08-31 01:30:00',
-                datetime64_field='2016-08-31 01:30:00.123456',
-                datetime64_3_field='2016-08-31 01:30:00.123456')
-        ])
+        self.database.insert(
+            [
+                ModelWithDate(
+                    date_field="2016-08-30",
+                    datetime_field="2016-08-30 03:50:00",
+                    datetime64_field="2016-08-30 03:50:00.123456",
+                    datetime64_3_field="2016-08-30 03:50:00.123456",
+                ),
+                ModelWithDate(
+                    date_field="2016-08-31",
+                    datetime_field="2016-08-31 01:30:00",
+                    datetime64_field="2016-08-31 01:30:00.123456",
+                    datetime64_3_field="2016-08-31 01:30:00.123456",
+                ),
+            ]
+        )
         # toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
-        query = 'SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field'
+        query = "SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field"
         results = list(self.database.select(query))
         self.assertEqual(len(results), 2)
         self.assertEqual(results[0].date_field, datetime.date(2016, 8, 30))
@@ -46,11 +48,13 @@ class DateFieldsTest(unittest.TestCase):
         self.assertEqual(results[1].hour_start, datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC))
         self.assertEqual(results[0].datetime64_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123456, tzinfo=pytz.UTC))
-        self.assertEqual(results[0].datetime64_3_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123000,
-                                                                          tzinfo=pytz.UTC))
+        self.assertEqual(
+            results[0].datetime64_3_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123000, tzinfo=pytz.UTC)
+        )
         self.assertEqual(results[1].datetime64_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123456, tzinfo=pytz.UTC))
-        self.assertEqual(results[1].datetime64_3_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123000,
-                                                                          tzinfo=pytz.UTC))
+        self.assertEqual(
+            results[1].datetime64_3_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123000, tzinfo=pytz.UTC)
+        )
 class ModelWithDate(Model):
@@ -59,45 +63,46 @@ class ModelWithDate(Model):
     datetime64_field = DateTime64Field()
     datetime64_3_field = DateTime64Field(precision=3)
-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))
 class ModelWithTz(Model):
     datetime_no_tz_field = DateTimeField()  # server tz
-    datetime_tz_field = DateTimeField(timezone='Europe/Madrid')
-    datetime64_tz_field = DateTime64Field(timezone='Europe/Madrid')
+    datetime_tz_field = DateTimeField(timezone="Europe/Madrid")
+    datetime64_tz_field = DateTime64Field(timezone="Europe/Madrid")
     datetime_utc_field = DateTimeField(timezone=pytz.UTC)
-    engine = MergeTree('datetime_no_tz_field', ('datetime_no_tz_field',))
+    engine = MergeTree("datetime_no_tz_field", ("datetime_no_tz_field",))
 class DateTimeFieldWithTzTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         if self.database.server_version < (20, 1, 2, 4):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
         self.database.create_table(ModelWithTz)
     def tearDown(self):
         self.database.drop_database()
     def test_ad_hoc_model(self):
-        self.database.insert([
-            ModelWithTz(
-                datetime_no_tz_field='2020-06-11 04:00:00',
-                datetime_tz_field='2020-06-11 04:00:00',
-                datetime64_tz_field='2020-06-11 04:00:00',
-                datetime_utc_field='2020-06-11 04:00:00',
-            ),
-            ModelWithTz(
-                datetime_no_tz_field='2020-06-11 07:00:00+0300',
-                datetime_tz_field='2020-06-11 07:00:00+0300',
-                datetime64_tz_field='2020-06-11 07:00:00+0300',
-                datetime_utc_field='2020-06-11 07:00:00+0300',
-            ),
-        ])
-        query = 'SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field'
+        self.database.insert(
+            [
+                ModelWithTz(
+                    datetime_no_tz_field="2020-06-11 04:00:00",
+                    datetime_tz_field="2020-06-11 04:00:00",
+                    datetime64_tz_field="2020-06-11 04:00:00",
+                    datetime_utc_field="2020-06-11 04:00:00",
+                ),
+                ModelWithTz(
+                    datetime_no_tz_field="2020-06-11 07:00:00+0300",
+                    datetime_tz_field="2020-06-11 07:00:00+0300",
+                    datetime64_tz_field="2020-06-11 07:00:00+0300",
+                    datetime_utc_field="2020-06-11 07:00:00+0300",
+                ),
+            ]
+        )
+        query = "SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field"
         results = list(self.database.select(query))
         self.assertEqual(results[0].datetime_no_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
@@ -110,10 +115,10 @@ class DateTimeFieldWithTzTest(unittest.TestCase):
         self.assertEqual(results[1].datetime_utc_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
         self.assertEqual(results[0].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
-        self.assertEqual(results[0].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
-        self.assertEqual(results[0].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
-        self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)
+        self.assertEqual(results[0].datetime_tz_field.tzinfo.zone, pytz.timezone("Europe/Madrid").zone)
+        self.assertEqual(results[0].datetime64_tz_field.tzinfo.zone, pytz.timezone("Europe/Madrid").zone)
+        self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)
         self.assertEqual(results[1].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
-        self.assertEqual(results[1].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
-        self.assertEqual(results[1].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
-        self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)
+        self.assertEqual(results[1].datetime_tz_field.tzinfo.zone, pytz.timezone("Europe/Madrid").zone)
+        self.assertEqual(results[1].datetime64_tz_field.tzinfo.zone, pytz.timezone("Europe/Madrid").zone)
+        self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)

View File

@@ -9,9 +9,8 @@ from clickhouse_orm.engines import *
 class DecimalFieldsTest(unittest.TestCase):
     def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
         try:
             self.database.create_table(DecimalModel)
         except ServerError as e:
@@ -22,56 +21,58 @@ class DecimalFieldsTest(unittest.TestCase):
         self.database.drop_database()
     def _insert_sample_data(self):
-        self.database.insert([
-            DecimalModel(date_field='2016-08-20'),
-            DecimalModel(date_field='2016-08-21', dec=Decimal('1.234')),
-            DecimalModel(date_field='2016-08-22', dec32=Decimal('12342.2345')),
-            DecimalModel(date_field='2016-08-23', dec64=Decimal('12342.23456')),
-            DecimalModel(date_field='2016-08-24', dec128=Decimal('-4545456612342.234567')),
-        ])
+        self.database.insert(
+            [
+                DecimalModel(date_field="2016-08-20"),
+                DecimalModel(date_field="2016-08-21", dec=Decimal("1.234")),
+                DecimalModel(date_field="2016-08-22", dec32=Decimal("12342.2345")),
+                DecimalModel(date_field="2016-08-23", dec64=Decimal("12342.23456")),
+                DecimalModel(date_field="2016-08-24", dec128=Decimal("-4545456612342.234567")),
+            ]
+        )
     def _assert_sample_data(self, results):
         self.assertEqual(len(results), 5)
         self.assertEqual(results[0].dec, Decimal(0))
         self.assertEqual(results[0].dec32, Decimal(17))
-        self.assertEqual(results[1].dec, Decimal('1.234'))
-        self.assertEqual(results[2].dec32, Decimal('12342.2345'))
-        self.assertEqual(results[3].dec64, Decimal('12342.23456'))
-        self.assertEqual(results[4].dec128, Decimal('-4545456612342.234567'))
+        self.assertEqual(results[1].dec, Decimal("1.234"))
+        self.assertEqual(results[2].dec32, Decimal("12342.2345"))
+        self.assertEqual(results[3].dec64, Decimal("12342.23456"))
+        self.assertEqual(results[4].dec128, Decimal("-4545456612342.234567"))
     def test_insert_and_select(self):
         self._insert_sample_data()
-        query = 'SELECT * from $table ORDER BY date_field'
+        query = "SELECT * from $table ORDER BY date_field"
        results = list(self.database.select(query, DecimalModel))
         self._assert_sample_data(results)
     def test_ad_hoc_model(self):
         self._insert_sample_data()
-        query = 'SELECT * from decimalmodel ORDER BY date_field'
+        query = "SELECT * from decimalmodel ORDER BY date_field"
         results = list(self.database.select(query))
         self._assert_sample_data(results)
     def test_rounding(self):
-        d = Decimal('11111.2340000000000000001')
-        self.database.insert([DecimalModel(date_field='2016-08-20', dec=d, dec32=d, dec64=d, dec128=d)])
+        d = Decimal("11111.2340000000000000001")
+        self.database.insert([DecimalModel(date_field="2016-08-20", dec=d, dec32=d, dec64=d, dec128=d)])
         m = DecimalModel.objects_in(self.database)[0]
         for val in (m.dec, m.dec32, m.dec64, m.dec128):
-            self.assertEqual(val, Decimal('11111.234'))
+            self.assertEqual(val, Decimal("11111.234"))
     def test_assignment_ok(self):
-        for value in (True, False, 17, 3.14, '20.5', Decimal('20.5')):
+        for value in (True, False, 17, 3.14, "20.5", Decimal("20.5")):
             DecimalModel(dec=value)
     def test_assignment_error(self):
-        for value in ('abc', u'זה ארוך', None, float('NaN'), Decimal('-Infinity')):
+        for value in ("abc", u"זה ארוך", None, float("NaN"), Decimal("-Infinity")):
             with self.assertRaises(ValueError):
                 DecimalModel(dec=value)
     def test_aggregation(self):
         self._insert_sample_data()
-        result = DecimalModel.objects_in(self.database).aggregate(m='min(dec)', n='max(dec)')
+        result = DecimalModel.objects_in(self.database).aggregate(m="min(dec)", n="max(dec)")
         self.assertEqual(result[0].m, Decimal(0))
-        self.assertEqual(result[0].n, Decimal('1.234'))
+        self.assertEqual(result[0].n, Decimal("1.234"))
     def test_precision_and_scale(self):
         # Go over all valid combinations
@@ -86,36 +87,36 @@ class DecimalFieldsTest(unittest.TestCase):
     def test_min_max(self):
         # In range
         f = DecimalField(3, 1)
-        f.validate(f.to_python('99.9', None))
-        f.validate(f.to_python('-99.9', None))
+        f.validate(f.to_python("99.9", None))
+        f.validate(f.to_python("-99.9", None))
         # In range after rounding
-        f.validate(f.to_python('99.94', None))
-        f.validate(f.to_python('-99.94', None))
+        f.validate(f.to_python("99.94", None))
+        f.validate(f.to_python("-99.94", None))
         # Out of range
         with self.assertRaises(ValueError):
-            f.validate(f.to_python('99.99', None))
+            f.validate(f.to_python("99.99", None))
         with self.assertRaises(ValueError):
-            f.validate(f.to_python('-99.99', None))
+            f.validate(f.to_python("-99.99", None))
         # In range
         f = Decimal32Field(4)
-        f.validate(f.to_python('99999.9999', None))
-        f.validate(f.to_python('-99999.9999', None))
+        f.validate(f.to_python("99999.9999", None))
+        f.validate(f.to_python("-99999.9999", None))
         # In range after rounding
-        f.validate(f.to_python('99999.99994', None))
-        f.validate(f.to_python('-99999.99994', None))
+        f.validate(f.to_python("99999.99994", None))
+        f.validate(f.to_python("-99999.99994", None))
         # Out of range
         with self.assertRaises(ValueError):
-            f.validate(f.to_python('100000', None))
+            f.validate(f.to_python("100000", None))
         with self.assertRaises(ValueError):
-            f.validate(f.to_python('-100000', None))
+            f.validate(f.to_python("-100000", None))
 class DecimalModel(Model):
     date_field = DateField()
     dec = DecimalField(15, 3)
     dec32 = Decimal32Field(4, default=17)
     dec64 = Decimal64Field(5)
     dec128 = Decimal128Field(6)
     engine = Memory()
View File
@ -5,37 +5,36 @@ from clickhouse_orm import *
class DictionaryTestMixin: class DictionaryTestMixin:
def setUp(self): def setUp(self):
self.database = Database('test-db', log_statements=True) self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 11, 73): if self.database.server_version < (20, 1, 11, 73):
raise unittest.SkipTest('ClickHouse version too old') raise unittest.SkipTest("ClickHouse version too old")
self._create_dictionary() self._create_dictionary()
def tearDown(self): def tearDown(self):
self.database.drop_database() self.database.drop_database()
def _test_func(self, func, expected_value): def _test_func(self, func, expected_value):
sql = 'SELECT %s AS value' % func.to_sql() sql = "SELECT %s AS value" % func.to_sql()
logging.info(sql) logging.info(sql)
result = list(self.database.select(sql)) result = list(self.database.select(sql))
logging.info('\t==> %s', result[0].value if result else '<empty>') logging.info("\t==> %s", result[0].value if result else "<empty>")
print('Comparing %s to %s' % (result[0].value, expected_value)) print("Comparing %s to %s" % (result[0].value, expected_value))
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase): class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def _create_dictionary(self): def _create_dictionary(self):
# Create a table to be used as source for the dictionary # Create a table to be used as source for the dictionary
self.database.create_table(NumberName) self.database.create_table(NumberName)
self.database.insert( self.database.insert(
NumberName(number=i, name=name) NumberName(number=i, name=name)
for i, name in enumerate('Zero One Two Three Four Five Six Seven Eight Nine Ten'.split()) for i, name in enumerate("Zero One Two Three Four Five Six Seven Eight Nine Ten".split())
) )
# Create the dictionary # Create the dictionary
self.database.raw(""" self.database.raw(
"""
CREATE DICTIONARY numbers_dict( CREATE DICTIONARY numbers_dict(
number UInt64, number UInt64,
name String DEFAULT '?' name String DEFAULT '?'
@ -46,16 +45,17 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
)) ))
LIFETIME(100) LIFETIME(100)
LAYOUT(HASHED()); LAYOUT(HASHED());
""") """
self.dict_name = 'test-db.numbers_dict' )
self.dict_name = "test-db.numbers_dict"
def test_dictget(self): def test_dictget(self):
self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(3)), 'Three') self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(3)), "Three")
self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(99)), '?') self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(99)), "?")
def test_dictgetordefault(self): def test_dictgetordefault(self):
self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(3), 'n/a'), 'Three') self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(3), "n/a"), "Three")
self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(99), 'n/a'), 'n/a') self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(99), "n/a"), "n/a")
def test_dicthas(self): def test_dicthas(self):
self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1) self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -63,19 +63,21 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase): class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def _create_dictionary(self): def _create_dictionary(self):
# Create a table to be used as source for the dictionary # Create a table to be used as source for the dictionary
self.database.create_table(Region) self.database.create_table(Region)
self.database.insert([ self.database.insert(
Region(region_id=1, parent_region=0, region_name='Russia'), [
Region(region_id=2, parent_region=1, region_name='Moscow'), Region(region_id=1, parent_region=0, region_name="Russia"),
Region(region_id=3, parent_region=2, region_name='Center'), Region(region_id=2, parent_region=1, region_name="Moscow"),
Region(region_id=4, parent_region=0, region_name='Great Britain'), Region(region_id=3, parent_region=2, region_name="Center"),
Region(region_id=5, parent_region=4, region_name='London'), Region(region_id=4, parent_region=0, region_name="Great Britain"),
]) Region(region_id=5, parent_region=4, region_name="London"),
]
)
# Create the dictionary # Create the dictionary
self.database.raw(""" self.database.raw(
"""
CREATE DICTIONARY regions_dict( CREATE DICTIONARY regions_dict(
region_id UInt64, region_id UInt64,
parent_region UInt64 HIERARCHICAL, parent_region UInt64 HIERARCHICAL,
@ -87,17 +89,18 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
)) ))
LIFETIME(100) LIFETIME(100)
LAYOUT(HASHED()); LAYOUT(HASHED());
""") """
self.dict_name = 'test-db.regions_dict' )
self.dict_name = "test-db.regions_dict"
def test_dictget(self): def test_dictget(self):
self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(3)), 'Center') self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(3)), "Center")
self._test_func(F.dictGet(self.dict_name, 'parent_region', F.toUInt64(3)), 2) self._test_func(F.dictGet(self.dict_name, "parent_region", F.toUInt64(3)), 2)
self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(99)), '?') self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(99)), "?")
def test_dictgetordefault(self): def test_dictgetordefault(self):
self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(3), 'n/a'), 'Center') self._test_func(F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(3), "n/a"), "Center")
self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(99), 'n/a'), 'n/a') self._test_func(F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(99), "n/a"), "n/a")
def test_dicthas(self): def test_dicthas(self):
self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1) self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -114,7 +117,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class NumberName(Model): class NumberName(Model):
''' A table to act as a source for the dictionary ''' """A table to act as a source for the dictionary"""
number = UInt64Field() number = UInt64Field()
name = StringField() name = StringField()
View File
@ -4,13 +4,13 @@ import datetime
from clickhouse_orm import * from clickhouse_orm import *
import logging import logging
logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("requests").setLevel(logging.WARNING)
class _EnginesHelperTestCase(unittest.TestCase): class _EnginesHelperTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test-db', log_statements=True) self.database = Database("test-db", log_statements=True)
def tearDown(self): def tearDown(self):
self.database.drop_database() self.database.drop_database()
@ -19,32 +19,47 @@ class _EnginesHelperTestCase(unittest.TestCase):
class EnginesTestCase(_EnginesHelperTestCase): class EnginesTestCase(_EnginesHelperTestCase):
def _create_and_insert(self, model_class): def _create_and_insert(self, model_class):
self.database.create_table(model_class) self.database.create_table(model_class)
self.database.insert([ self.database.insert(
model_class(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1) [model_class(date="2017-01-01", event_id=23423, event_group=13, event_count=7, event_version=1)]
]) )
def test_merge_tree(self): def test_merge_tree(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group')) engine = MergeTree("date", ("date", "event_id", "event_group"))
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_merge_tree_with_sampling(self): def test_merge_tree_with_sampling(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group', 'intHash32(event_id)'), sampling_expr='intHash32(event_id)') engine = MergeTree(
"date", ("date", "event_id", "event_group", "intHash32(event_id)"), sampling_expr="intHash32(event_id)"
)
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_merge_tree_with_sampling__funcs(self): def test_merge_tree_with_sampling__funcs(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group', F.intHash32(SampleModel.event_id)), sampling_expr=F.intHash32(SampleModel.event_id)) engine = MergeTree(
"date",
("date", "event_id", "event_group", F.intHash32(SampleModel.event_id)),
sampling_expr=F.intHash32(SampleModel.event_id),
)
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_merge_tree_with_granularity(self): def test_merge_tree_with_granularity(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group'), index_granularity=4096) engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_replicated_merge_tree(self): def test_replicated_merge_tree(self):
engine = MergeTree('date', ('date', 'event_id', 'event_group'), replica_table_path='/clickhouse/tables/{layer}-{shard}/hits', replica_name='{replica}') engine = MergeTree(
"date",
("date", "event_id", "event_group"),
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
replica_name="{replica}",
)
# In ClickHouse 1.1.54310 custom partitioning key was introduced and new syntax is used # In ClickHouse 1.1.54310 custom partitioning key was introduced and new syntax is used
if self.database.server_version >= (1, 1, 54310): if self.database.server_version >= (1, 1, 54310):
expected = "ReplicatedMergeTree('/clickhouse/tables/{layer}-{shard}/hits', '{replica}') PARTITION BY (toYYYYMM(`date`)) ORDER BY (date, event_id, event_group) SETTINGS index_granularity=8192" expected = "ReplicatedMergeTree('/clickhouse/tables/{layer}-{shard}/hits', '{replica}') PARTITION BY (toYYYYMM(`date`)) ORDER BY (date, event_id, event_group) SETTINGS index_granularity=8192"
@ -54,38 +69,48 @@ class EnginesTestCase(_EnginesHelperTestCase):
def test_replicated_merge_tree_incomplete(self): def test_replicated_merge_tree_incomplete(self):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
MergeTree('date', ('date', 'event_id', 'event_group'), replica_table_path='/clickhouse/tables/{layer}-{shard}/hits') MergeTree(
"date",
("date", "event_id", "event_group"),
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
MergeTree('date', ('date', 'event_id', 'event_group'), replica_name='{replica}') MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")
def test_collapsing_merge_tree(self): def test_collapsing_merge_tree(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = CollapsingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_version') engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_summing_merge_tree(self): def test_summing_merge_tree(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = SummingMergeTree('date', ('date', 'event_group'), ('event_count',)) engine = SummingMergeTree("date", ("date", "event_group"), ("event_count",))
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_replacing_merge_tree(self): def test_replacing_merge_tree(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = ReplacingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_uversion') engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_tiny_log(self): def test_tiny_log(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = TinyLog() engine = TinyLog()
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_log(self): def test_log(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = Log() engine = Log()
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_memory(self): def test_memory(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = Memory() engine = Memory()
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
def test_merge(self): def test_merge(self):
@ -96,7 +121,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
engine = TinyLog() engine = TinyLog()
class TestMergeModel(MergeModel, SampleModel): class TestMergeModel(MergeModel, SampleModel):
engine = Merge('^testmodel') engine = Merge("^testmodel")
self.database.create_table(TestModel1) self.database.create_table(TestModel1)
self.database.create_table(TestModel2) self.database.create_table(TestModel2)
@ -104,54 +129,57 @@ class EnginesTestCase(_EnginesHelperTestCase):
# Insert operations are restricted for this model type # Insert operations are restricted for this model type
with self.assertRaises(DatabaseException): with self.assertRaises(DatabaseException):
self.database.insert([ self.database.insert(
TestMergeModel(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1) [TestMergeModel(date="2017-01-01", event_id=23423, event_group=13, event_count=7, event_version=1)]
]) )
# Testing select # Testing select
self.database.insert([ self.database.insert([TestModel1(date="2017-01-01", event_id=1, event_group=1, event_count=1, event_version=1)])
TestModel1(date='2017-01-01', event_id=1, event_group=1, event_count=1, event_version=1) self.database.insert([TestModel2(date="2017-01-02", event_id=2, event_group=2, event_count=2, event_version=2)])
])
self.database.insert([
TestModel2(date='2017-01-02', event_id=2, event_group=2, event_count=2, event_version=2)
])
# event_uversion is materialized field. So * won't select it and it will be zero # event_uversion is materialized field. So * won't select it and it will be zero
res = self.database.select('SELECT *, _table, event_uversion FROM $table ORDER BY event_id', model_class=TestMergeModel) res = self.database.select(
"SELECT *, _table, event_uversion FROM $table ORDER BY event_id", model_class=TestMergeModel
)
res = list(res) res = list(res)
self.assertEqual(2, len(res)) self.assertEqual(2, len(res))
self.assertDictEqual({ self.assertDictEqual(
'_table': 'testmodel1', {
'date': datetime.date(2017, 1, 1), "_table": "testmodel1",
'event_id': 1, "date": datetime.date(2017, 1, 1),
'event_group': 1, "event_id": 1,
'event_count': 1, "event_group": 1,
'event_version': 1, "event_count": 1,
'event_uversion': 1 "event_version": 1,
}, res[0].to_dict(include_readonly=True)) "event_uversion": 1,
self.assertDictEqual({ },
'_table': 'testmodel2', res[0].to_dict(include_readonly=True),
'date': datetime.date(2017, 1, 2), )
'event_id': 2, self.assertDictEqual(
'event_group': 2, {
'event_count': 2, "_table": "testmodel2",
'event_version': 2, "date": datetime.date(2017, 1, 2),
'event_uversion': 2 "event_id": 2,
}, res[1].to_dict(include_readonly=True)) "event_group": 2,
"event_count": 2,
"event_version": 2,
"event_uversion": 2,
},
res[1].to_dict(include_readonly=True),
)
def test_custom_partitioning(self): def test_custom_partitioning(self):
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree( engine = MergeTree(
order_by=('date', 'event_id', 'event_group'), order_by=("date", "event_id", "event_group"), partition_key=("toYYYYMM(date)", "event_group")
partition_key=('toYYYYMM(date)', 'event_group')
) )
class TestCollapseModel(SampleModel): class TestCollapseModel(SampleModel):
sign = Int8Field() sign = Int8Field()
engine = CollapsingMergeTree( engine = CollapsingMergeTree(
sign_col='sign', sign_col="sign",
order_by=('date', 'event_id', 'event_group'), order_by=("date", "event_id", "event_group"),
partition_key=('toYYYYMM(date)', 'event_group') partition_key=("toYYYYMM(date)", "event_group"),
) )
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
@ -161,30 +189,30 @@ class EnginesTestCase(_EnginesHelperTestCase):
parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table) parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
self.assertEqual(2, len(parts)) self.assertEqual(2, len(parts))
self.assertEqual('testcollapsemodel', parts[0].table) self.assertEqual("testcollapsemodel", parts[0].table)
self.assertEqual('(201701, 13)'.replace(' ', ''), parts[0].partition.replace(' ', '')) self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
self.assertEqual('testmodel', parts[1].table) self.assertEqual("testmodel", parts[1].table)
self.assertEqual('(201701, 13)'.replace(' ', ''), parts[1].partition.replace(' ', '')) self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))
def test_custom_primary_key(self): def test_custom_primary_key(self):
if self.database.server_version < (18, 1): if self.database.server_version < (18, 1):
raise unittest.SkipTest('ClickHouse version too old') raise unittest.SkipTest("ClickHouse version too old")
class TestModel(SampleModel): class TestModel(SampleModel):
engine = MergeTree( engine = MergeTree(
order_by=('date', 'event_id', 'event_group'), order_by=("date", "event_id", "event_group"),
partition_key=('toYYYYMM(date)',), partition_key=("toYYYYMM(date)",),
primary_key=('date', 'event_id') primary_key=("date", "event_id"),
) )
class TestCollapseModel(SampleModel): class TestCollapseModel(SampleModel):
sign = Int8Field() sign = Int8Field()
engine = CollapsingMergeTree( engine = CollapsingMergeTree(
sign_col='sign', sign_col="sign",
order_by=('date', 'event_id', 'event_group'), order_by=("date", "event_id", "event_group"),
partition_key=('toYYYYMM(date)',), partition_key=("toYYYYMM(date)",),
primary_key=('date', 'event_id') primary_key=("date", "event_id"),
) )
self._create_and_insert(TestModel) self._create_and_insert(TestModel)
@ -195,28 +223,28 @@ class EnginesTestCase(_EnginesHelperTestCase):
class SampleModel(Model): class SampleModel(Model):
date = DateField() date = DateField()
event_id = UInt32Field() event_id = UInt32Field()
event_group = UInt32Field() event_group = UInt32Field()
event_count = UInt16Field() event_count = UInt16Field()
event_version = Int8Field() event_version = Int8Field()
event_uversion = UInt8Field(materialized='abs(event_version)') event_uversion = UInt8Field(materialized="abs(event_version)")
class DistributedTestCase(_EnginesHelperTestCase): class DistributedTestCase(_EnginesHelperTestCase):
def test_without_table_name(self): def test_without_table_name(self):
engine = Distributed('my_cluster') engine = Distributed("my_cluster")
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
engine.create_table_sql(self.database) engine.create_table_sql(self.database)
exc = cm.exception exc = cm.exception
self.assertEqual(str(exc), 'Cannot create Distributed engine: specify an underlying table') self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")
def test_with_table_name(self): def test_with_table_name(self):
engine = Distributed('my_cluster', 'foo') engine = Distributed("my_cluster", "foo")
sql = engine.create_table_sql(self.database) sql = engine.create_table_sql(self.database)
self.assertEqual(sql, 'Distributed(`my_cluster`, `test-db`, `foo`)') self.assertEqual(sql, "Distributed(`my_cluster`, `test-db`, `foo`)")
class TestModel(SampleModel): class TestModel(SampleModel):
engine = TinyLog() engine = TinyLog()
@ -231,7 +259,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
def test_bad_cluster_name(self): def test_bad_cluster_name(self):
with self.assertRaises(ServerError) as cm: with self.assertRaises(ServerError) as cm:
d_model = self._create_distributed('cluster_name') d_model = self._create_distributed("cluster_name")
self.database.count(d_model) self.database.count(d_model)
exc = cm.exception exc = cm.exception
@ -243,7 +271,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
engine = Log() engine = Log()
class TestDistributedModel(DistributedModel, self.TestModel, TestModel2): class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
engine = Distributed('test_shard_localhost', self.TestModel) engine = Distributed("test_shard_localhost", self.TestModel)
self.database.create_table(self.TestModel) self.database.create_table(self.TestModel)
self.database.create_table(TestDistributedModel) self.database.create_table(TestDistributedModel)
@ -251,7 +279,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
def test_minimal_engine(self): def test_minimal_engine(self):
class TestDistributedModel(DistributedModel, self.TestModel): class TestDistributedModel(DistributedModel, self.TestModel):
engine = Distributed('test_shard_localhost') engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel) self.database.create_table(self.TestModel)
self.database.create_table(TestDistributedModel) self.database.create_table(TestDistributedModel)
@ -263,64 +291,78 @@ class DistributedTestCase(_EnginesHelperTestCase):
engine = Log() engine = Log()
class TestDistributedModel(DistributedModel, self.TestModel, TestModel2): class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
engine = Distributed('test_shard_localhost') engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel) self.database.create_table(self.TestModel)
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
self.database.create_table(TestDistributedModel) self.database.create_table(TestDistributedModel)
exc = cm.exception exc = cm.exception
self.assertEqual(str(exc), 'When defining Distributed engine without the table_name ensure ' self.assertEqual(
'that your model has exactly one non-distributed superclass') str(exc),
"When defining Distributed engine without the table_name ensure "
"that your model has exactly one non-distributed superclass",
)
def test_minimal_engine_no_superclasses(self): def test_minimal_engine_no_superclasses(self):
class TestDistributedModel(DistributedModel): class TestDistributedModel(DistributedModel):
engine = Distributed('test_shard_localhost') engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel) self.database.create_table(self.TestModel)
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
self.database.create_table(TestDistributedModel) self.database.create_table(TestDistributedModel)
exc = cm.exception exc = cm.exception
self.assertEqual(str(exc), 'When defining Distributed engine without the table_name ensure ' self.assertEqual(
'that your model has a parent model') str(exc),
"When defining Distributed engine without the table_name ensure " "that your model has a parent model",
)
def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True): def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
d_model = self._create_distributed('test_shard_localhost', underlying=test_model) d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
if local_to_distributed: if local_to_distributed:
to_insert, to_select = test_model, d_model to_insert, to_select = test_model, d_model
else: else:
to_insert, to_select = d_model, test_model to_insert, to_select = d_model, test_model
self.database.insert([ self.database.insert(
to_insert(date='2017-01-01', event_id=1, event_group=1, event_count=1, event_version=1), [
to_insert(date='2017-01-02', event_id=2, event_group=2, event_count=2, event_version=2) to_insert(date="2017-01-01", event_id=1, event_group=1, event_count=1, event_version=1),
]) to_insert(date="2017-01-02", event_id=2, event_group=2, event_count=2, event_version=2),
]
)
# event_uversion is materialized field. So * won't select it and it will be zero # event_uversion is materialized field. So * won't select it and it will be zero
res = self.database.select('SELECT *, event_uversion FROM $table ORDER BY event_id', res = self.database.select("SELECT *, event_uversion FROM $table ORDER BY event_id", model_class=to_select)
model_class=to_select)
res = [row for row in res] res = [row for row in res]
self.assertEqual(2, len(res)) self.assertEqual(2, len(res))
self.assertDictEqual({ self.assertDictEqual(
'date': datetime.date(2017, 1, 1), {
'event_id': 1, "date": datetime.date(2017, 1, 1),
'event_group': 1, "event_id": 1,
'event_count': 1, "event_group": 1,
'event_version': 1, "event_count": 1,
'event_uversion': 1 "event_version": 1,
}, res[0].to_dict(include_readonly=include_readonly)) "event_uversion": 1,
self.assertDictEqual({ },
'date': datetime.date(2017, 1, 2), res[0].to_dict(include_readonly=include_readonly),
'event_id': 2, )
'event_group': 2, self.assertDictEqual(
'event_count': 2, {
'event_version': 2, "date": datetime.date(2017, 1, 2),
'event_uversion': 2 "event_id": 2,
}, res[1].to_dict(include_readonly=include_readonly)) "event_group": 2,
"event_count": 2,
"event_version": 2,
"event_uversion": 2,
},
res[1].to_dict(include_readonly=include_readonly),
)
@unittest.skip("Bad support of materialized fields in Distributed tables " @unittest.skip(
"https://groups.google.com/forum/#!topic/clickhouse/XEYRRwZrsSc") "Bad support of materialized fields in Distributed tables "
"https://groups.google.com/forum/#!topic/clickhouse/XEYRRwZrsSc"
)
def test_insert_distributed_select_local(self): def test_insert_distributed_select_local(self):
return self._test_insert_select(local_to_distributed=False) return self._test_insert_select(local_to_distributed=False)
View File
@ -9,9 +9,8 @@ from enum import Enum
class EnumFieldsTest(unittest.TestCase): class EnumFieldsTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test-db', log_statements=True) self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithEnum) self.database.create_table(ModelWithEnum)
self.database.create_table(ModelWithEnumArray) self.database.create_table(ModelWithEnumArray)
@ -19,12 +18,14 @@ class EnumFieldsTest(unittest.TestCase):
self.database.drop_database() self.database.drop_database()
def test_insert_and_select(self): def test_insert_and_select(self):
self.database.insert([ self.database.insert(
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple), [
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange), ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry) ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
]) ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
query = 'SELECT * from $table ORDER BY date_field' ]
)
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, ModelWithEnum)) results = list(self.database.select(query, ModelWithEnum))
self.assertEqual(len(results), 3) self.assertEqual(len(results), 3)
self.assertEqual(results[0].enum_field, Fruit.apple) self.assertEqual(results[0].enum_field, Fruit.apple)
@ -32,12 +33,14 @@ class EnumFieldsTest(unittest.TestCase):
self.assertEqual(results[2].enum_field, Fruit.cherry) self.assertEqual(results[2].enum_field, Fruit.cherry)
def test_ad_hoc_model(self): def test_ad_hoc_model(self):
self.database.insert([ self.database.insert(
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple), [
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange), ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry) ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
]) ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
query = 'SELECT * from $db.modelwithenum ORDER BY date_field' ]
)
query = "SELECT * from $db.modelwithenum ORDER BY date_field"
results = list(self.database.select(query)) results = list(self.database.select(query))
self.assertEqual(len(results), 3) self.assertEqual(len(results), 3)
self.assertEqual(results[0].enum_field.name, Fruit.apple.name) self.assertEqual(results[0].enum_field.name, Fruit.apple.name)
@ -50,11 +53,11 @@ class EnumFieldsTest(unittest.TestCase):
def test_conversion(self): def test_conversion(self):
self.assertEqual(ModelWithEnum(enum_field=3).enum_field, Fruit.orange) self.assertEqual(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
self.assertEqual(ModelWithEnum(enum_field=-7).enum_field, Fruit.cherry) self.assertEqual(ModelWithEnum(enum_field=-7).enum_field, Fruit.cherry)
self.assertEqual(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple) self.assertEqual(ModelWithEnum(enum_field="apple").enum_field, Fruit.apple)
self.assertEqual(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana) self.assertEqual(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
def test_assignment_error(self): def test_assignment_error(self):
for value in (0, 17, 'pear', '', None, 99.9): for value in (0, 17, "pear", "", None, 99.9):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ModelWithEnum(enum_field=value) ModelWithEnum(enum_field=value)
@ -63,15 +66,15 @@ class EnumFieldsTest(unittest.TestCase):
self.assertEqual(instance.enum_field, Fruit.apple) self.assertEqual(instance.enum_field, Fruit.apple)
def test_enum_array(self): def test_enum_array(self):
instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange]) instance = ModelWithEnumArray(date_field="2016-08-30", enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
self.database.insert([instance]) self.database.insert([instance])
query = 'SELECT * from $table ORDER BY date_field' query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, ModelWithEnumArray)) results = list(self.database.select(query, ModelWithEnumArray))
self.assertEqual(len(results), 1) self.assertEqual(len(results), 1)
self.assertEqual(results[0].enum_array, instance.enum_array) self.assertEqual(results[0].enum_array, instance.enum_array)
Fruit = Enum('Fruit', [('apple', 1), ('banana', 2), ('orange', 3), ('cherry', -7)]) Fruit = Enum("Fruit", [("apple", 1), ("banana", 2), ("orange", 3), ("cherry", -7)])
class ModelWithEnum(Model): class ModelWithEnum(Model):
@ -79,7 +82,7 @@ class ModelWithEnum(Model):
date_field = DateField() date_field = DateField()
enum_field = Enum8Field(Fruit) enum_field = Enum8Field(Fruit)
engine = MergeTree('date_field', ('date_field',)) engine = MergeTree("date_field", ("date_field",))
class ModelWithEnumArray(Model): class ModelWithEnumArray(Model):
@ -87,5 +90,4 @@ class ModelWithEnumArray(Model):
date_field = DateField() date_field = DateField()
enum_array = ArrayField(Enum16Field(Fruit)) enum_array = ArrayField(Enum16Field(Fruit))
engine = MergeTree('date_field', ('date_field',)) engine = MergeTree("date_field", ("date_field",))
View File
@ -8,43 +8,44 @@ from clickhouse_orm.engines import *
class FixedStringFieldsTest(unittest.TestCase): class FixedStringFieldsTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test-db', log_statements=True) self.database = Database("test-db", log_statements=True)
self.database.create_table(FixedStringModel) self.database.create_table(FixedStringModel)
def tearDown(self): def tearDown(self):
self.database.drop_database() self.database.drop_database()
def _insert_sample_data(self): def _insert_sample_data(self):
self.database.insert([ self.database.insert(
FixedStringModel(date_field='2016-08-30', fstr_field=''), [
FixedStringModel(date_field='2016-08-30'), FixedStringModel(date_field="2016-08-30", fstr_field=""),
FixedStringModel(date_field='2016-08-31', fstr_field='foo'), FixedStringModel(date_field="2016-08-30"),
FixedStringModel(date_field='2016-08-31', fstr_field=u'לילה') FixedStringModel(date_field="2016-08-31", fstr_field="foo"),
]) FixedStringModel(date_field="2016-08-31", fstr_field=u"לילה"),
]
)
def _assert_sample_data(self, results): def _assert_sample_data(self, results):
self.assertEqual(len(results), 4) self.assertEqual(len(results), 4)
self.assertEqual(results[0].fstr_field, '') self.assertEqual(results[0].fstr_field, "")
self.assertEqual(results[1].fstr_field, 'ABCDEFGHIJK') self.assertEqual(results[1].fstr_field, "ABCDEFGHIJK")
self.assertEqual(results[2].fstr_field, 'foo') self.assertEqual(results[2].fstr_field, "foo")
self.assertEqual(results[3].fstr_field, u'לילה') self.assertEqual(results[3].fstr_field, u"לילה")
def test_insert_and_select(self): def test_insert_and_select(self):
self._insert_sample_data() self._insert_sample_data()
query = 'SELECT * from $table ORDER BY date_field' query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, FixedStringModel)) results = list(self.database.select(query, FixedStringModel))
self._assert_sample_data(results) self._assert_sample_data(results)
def test_ad_hoc_model(self): def test_ad_hoc_model(self):
self._insert_sample_data() self._insert_sample_data()
query = 'SELECT * from $db.fixedstringmodel ORDER BY date_field' query = "SELECT * from $db.fixedstringmodel ORDER BY date_field"
results = list(self.database.select(query)) results = list(self.database.select(query))
self._assert_sample_data(results) self._assert_sample_data(results)
def test_assignment_error(self): def test_assignment_error(self):
for value in (17, 'this is too long', u'זה ארוך', None, 99.9): for value in (17, "this is too long", u"זה ארוך", None, 99.9):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
FixedStringModel(fstr_field=value) FixedStringModel(fstr_field=value)
@ -52,6 +53,6 @@ class FixedStringFieldsTest(unittest.TestCase):
class FixedStringModel(Model): class FixedStringModel(Model):
date_field = DateField() date_field = DateField()
fstr_field = FixedStringField(12, default='ABCDEFGHIJK') fstr_field = FixedStringField(12, default="ABCDEFGHIJK")
engine = MergeTree('date_field', ('date_field',)) engine = MergeTree("date_field", ("date_field",))
View File
@ -13,7 +13,6 @@ from clickhouse_orm.funcs import F
class FuncsTestCase(TestCaseWithData): class FuncsTestCase(TestCaseWithData):
def setUp(self): def setUp(self):
super(FuncsTestCase, self).setUp() super(FuncsTestCase, self).setUp()
self.database.insert(self._sample_data()) self.database.insert(self._sample_data())
@ -23,24 +22,24 @@ class FuncsTestCase(TestCaseWithData):
count = 0 count = 0
for instance in qs: for instance in qs:
count += 1 count += 1
logging.info('\t[%d]\t%s' % (count, instance.to_dict())) logging.info("\t[%d]\t%s" % (count, instance.to_dict()))
self.assertEqual(count, expected_count) self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count) self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=NO_VALUE): def _test_func(self, func, expected_value=NO_VALUE):
sql = 'SELECT %s AS value' % func.to_sql() sql = "SELECT %s AS value" % func.to_sql()
logging.info(sql) logging.info(sql)
try: try:
result = list(self.database.select(sql)) result = list(self.database.select(sql))
logging.info('\t==> %s', result[0].value if result else '<empty>') logging.info("\t==> %s", result[0].value if result else "<empty>")
if expected_value != NO_VALUE: if expected_value != NO_VALUE:
print('Comparing %s to %s' % (result[0].value, expected_value)) print("Comparing %s to %s" % (result[0].value, expected_value))
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
except ServerError as e: except ServerError as e:
if 'Unknown function' in e.message: if "Unknown function" in e.message:
logging.warning(e.message) logging.warning(e.message)
return # ignore functions that don't exist in the used ClickHouse version return # ignore functions that don't exist in the used ClickHouse version
raise raise
def _test_aggr(self, func, expected_value=NO_VALUE): def _test_aggr(self, func, expected_value=NO_VALUE):
@ -48,45 +47,45 @@ class FuncsTestCase(TestCaseWithData):
logging.info(qs.as_sql()) logging.info(qs.as_sql())
try: try:
result = list(qs) result = list(qs)
logging.info('\t==> %s', result[0].value if result else '<empty>') logging.info("\t==> %s", result[0].value if result else "<empty>")
if expected_value != NO_VALUE: if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
except ServerError as e: except ServerError as e:
if 'Unknown function' in e.message: if "Unknown function" in e.message:
logging.warning(e.message) logging.warning(e.message)
return # ignore functions that don't exist in the used ClickHouse version return # ignore functions that don't exist in the used ClickHouse version
raise raise
def test_func_to_sql(self): def test_func_to_sql(self):
# No args # No args
self.assertEqual(F('func').to_sql(), 'func()') self.assertEqual(F("func").to_sql(), "func()")
# String args # String args
self.assertEqual(F('func', "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')") self.assertEqual(F("func", "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
# Numeric args # Numeric args
self.assertEqual(F('func', 1, 1.1, Decimal('3.3')).to_sql(), "func(1, 1.1, 3.3)") self.assertEqual(F("func", 1, 1.1, Decimal("3.3")).to_sql(), "func(1, 1.1, 3.3)")
# Date args # Date args
self.assertEqual(F('func', date(2018, 12, 31)).to_sql(), "func(toDate('2018-12-31'))") self.assertEqual(F("func", date(2018, 12, 31)).to_sql(), "func(toDate('2018-12-31'))")
# Datetime args # Datetime args
self.assertEqual(F('func', datetime(2018, 12, 31)).to_sql(), "func(toDateTime('1546214400'))") self.assertEqual(F("func", datetime(2018, 12, 31)).to_sql(), "func(toDateTime('1546214400'))")
# Boolean args # Boolean args
self.assertEqual(F('func', True, False).to_sql(), "func(1, 0)") self.assertEqual(F("func", True, False).to_sql(), "func(1, 0)")
# Timezone args # Timezone args
self.assertEqual(F('func', pytz.utc).to_sql(), "func('UTC')") self.assertEqual(F("func", pytz.utc).to_sql(), "func('UTC')")
self.assertEqual(F('func', pytz.timezone('Europe/Athens')).to_sql(), "func('Europe/Athens')") self.assertEqual(F("func", pytz.timezone("Europe/Athens")).to_sql(), "func('Europe/Athens')")
# Null args # Null args
self.assertEqual(F('func', None).to_sql(), "func(NULL)") self.assertEqual(F("func", None).to_sql(), "func(NULL)")
# Fields as args # Fields as args
self.assertEqual(F('func', SampleModel.color).to_sql(), "func(`color`)") self.assertEqual(F("func", SampleModel.color).to_sql(), "func(`color`)")
# Funcs as args # Funcs as args
self.assertEqual(F('func', F('sqrt', 25)).to_sql(), 'func(sqrt(25))') self.assertEqual(F("func", F("sqrt", 25)).to_sql(), "func(sqrt(25))")
# Iterables as args # Iterables as args
x = [1, 'z', F('foo', 17)] x = [1, "z", F("foo", 17)]
for y in [x, iter(x)]: for y in [x, iter(x)]:
self.assertEqual(F('func', y, 5).to_sql(), "func([1, 'z', foo(17)], 5)") self.assertEqual(F("func", y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
# Tuples as args # Tuples as args
self.assertEqual(F('func', [(1, 2), (3, 4)]).to_sql(), "func([(1, 2), (3, 4)])") self.assertEqual(F("func", [(1, 2), (3, 4)]).to_sql(), "func([(1, 2), (3, 4)])")
self.assertEqual(F('func', tuple(x), 5).to_sql(), "func((1, 'z', foo(17)), 5)") self.assertEqual(F("func", tuple(x), 5).to_sql(), "func((1, 'z', foo(17)), 5)")
# Binary operator functions # Binary operator functions
self.assertEqual(F.plus(1, 2).to_sql(), "(1 + 2)") self.assertEqual(F.plus(1, 2).to_sql(), "(1 + 2)")
self.assertEqual(F.lessOrEquals(1, 2).to_sql(), "(1 <= 2)") self.assertEqual(F.lessOrEquals(1, 2).to_sql(), "(1 <= 2)")
@ -106,32 +105,32 @@ class FuncsTestCase(TestCaseWithData):
def test_filter_date_field(self): def test_filter_date_field(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
# People born on the 30th # People born on the 30th
self._test_qs(qs.filter(F('equals', F('toDayOfMonth', Person.birthday), 30)), 3) self._test_qs(qs.filter(F("equals", F("toDayOfMonth", Person.birthday), 30)), 3)
self._test_qs(qs.filter(F('toDayOfMonth', Person.birthday) == 30), 3) self._test_qs(qs.filter(F("toDayOfMonth", Person.birthday) == 30), 3)
self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3) self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3)
# People born on Sunday # People born on Sunday
self._test_qs(qs.filter(F('equals', F('toDayOfWeek', Person.birthday), 7)), 18) self._test_qs(qs.filter(F("equals", F("toDayOfWeek", Person.birthday), 7)), 18)
self._test_qs(qs.filter(F('toDayOfWeek', Person.birthday) == 7), 18) self._test_qs(qs.filter(F("toDayOfWeek", Person.birthday) == 7), 18)
self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18) self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18)
# People born on 1976-10-01 # People born on 1976-10-01
self._test_qs(qs.filter(F('equals', Person.birthday, '1976-10-01')), 1) self._test_qs(qs.filter(F("equals", Person.birthday, "1976-10-01")), 1)
self._test_qs(qs.filter(F('equals', Person.birthday, date(1976, 10, 1))), 1) self._test_qs(qs.filter(F("equals", Person.birthday, date(1976, 10, 1))), 1)
self._test_qs(qs.filter(Person.birthday == date(1976, 10, 1)), 1) self._test_qs(qs.filter(Person.birthday == date(1976, 10, 1)), 1)
def test_func_as_field_value(self): def test_func_as_field_value(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96) self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96)
self._test_qs(qs.exclude(birthday=F.today()), 100) self._test_qs(qs.exclude(birthday=F.today()), 100)
self._test_qs(qs.filter(birthday__between=['1970-01-01', F.today()]), 100) self._test_qs(qs.filter(birthday__between=["1970-01-01", F.today()]), 100)
def test_in_and_not_in(self): def test_in_and_not_in(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(Person.first_name.isIn(['Ciaran', 'Elton'])), 4) self._test_qs(qs.filter(Person.first_name.isIn(["Ciaran", "Elton"])), 4)
self._test_qs(qs.filter(~Person.first_name.isIn(['Ciaran', 'Elton'])), 96) self._test_qs(qs.filter(~Person.first_name.isIn(["Ciaran", "Elton"])), 96)
self._test_qs(qs.filter(Person.first_name.isNotIn(['Ciaran', 'Elton'])), 96) self._test_qs(qs.filter(Person.first_name.isNotIn(["Ciaran", "Elton"])), 96)
self._test_qs(qs.exclude(Person.first_name.isIn(['Ciaran', 'Elton'])), 96) self._test_qs(qs.exclude(Person.first_name.isIn(["Ciaran", "Elton"])), 96)
# In subquery # In subquery
subquery = qs.filter(F.startsWith(Person.last_name, 'M')).only(Person.first_name) subquery = qs.filter(F.startsWith(Person.last_name, "M")).only(Person.first_name)
self._test_qs(qs.filter(Person.first_name.isIn(subquery)), 4) self._test_qs(qs.filter(Person.first_name.isIn(subquery)), 4)
def test_comparison_operators(self): def test_comparison_operators(self):
@ -213,14 +212,14 @@ class FuncsTestCase(TestCaseWithData):
dt = datetime(2018, 12, 31, 11, 22, 33) dt = datetime(2018, 12, 31, 11, 22, 33)
self._test_func(F.toYear(d), 2018) self._test_func(F.toYear(d), 2018)
self._test_func(F.toYear(dt), 2018) self._test_func(F.toYear(dt), 2018)
self._test_func(F.toISOYear(dt, 'Europe/Athens'), 2019) # 2018-12-31 is ISO year 2019, week 1, day 1 self._test_func(F.toISOYear(dt, "Europe/Athens"), 2019) # 2018-12-31 is ISO year 2019, week 1, day 1
self._test_func(F.toQuarter(d), 4) self._test_func(F.toQuarter(d), 4)
self._test_func(F.toQuarter(dt), 4) self._test_func(F.toQuarter(dt), 4)
self._test_func(F.toMonth(d), 12) self._test_func(F.toMonth(d), 12)
self._test_func(F.toMonth(dt), 12) self._test_func(F.toMonth(dt), 12)
self._test_func(F.toWeek(d), 52) self._test_func(F.toWeek(d), 52)
self._test_func(F.toWeek(dt), 52) self._test_func(F.toWeek(dt), 52)
self._test_func(F.toISOWeek(d), 1) # 2018-12-31 is ISO year 2019, week 1, day 1 self._test_func(F.toISOWeek(d), 1) # 2018-12-31 is ISO year 2019, week 1, day 1
self._test_func(F.toISOWeek(dt), 1) self._test_func(F.toISOWeek(dt), 1)
self._test_func(F.toDayOfYear(d), 365) self._test_func(F.toDayOfYear(d), 365)
self._test_func(F.toDayOfYear(dt), 365) self._test_func(F.toDayOfYear(dt), 365)
@ -246,182 +245,218 @@ class FuncsTestCase(TestCaseWithData):
self._test_func(F.toStartOfTenMinutes(dt), datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc)) self._test_func(F.toStartOfTenMinutes(dt), datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc))
self._test_func(F.toStartOfWeek(dt), date(2018, 12, 30)) self._test_func(F.toStartOfWeek(dt), date(2018, 12, 30))
self._test_func(F.toTime(dt), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc)) self._test_func(F.toTime(dt), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
self._test_func(F.toUnixTimestamp(dt, 'UTC'), int(dt.replace(tzinfo=pytz.utc).timestamp())) self._test_func(F.toUnixTimestamp(dt, "UTC"), int(dt.replace(tzinfo=pytz.utc).timestamp()))
self._test_func(F.toYYYYMM(d), 201812) self._test_func(F.toYYYYMM(d), 201812)
self._test_func(F.toYYYYMM(dt), 201812) self._test_func(F.toYYYYMM(dt), 201812)
self._test_func(F.toYYYYMM(dt, 'Europe/Athens'), 201812) self._test_func(F.toYYYYMM(dt, "Europe/Athens"), 201812)
self._test_func(F.toYYYYMMDD(d), 20181231) self._test_func(F.toYYYYMMDD(d), 20181231)
self._test_func(F.toYYYYMMDD(dt), 20181231) self._test_func(F.toYYYYMMDD(dt), 20181231)
self._test_func(F.toYYYYMMDD(dt, 'Europe/Athens'), 20181231) self._test_func(F.toYYYYMMDD(dt, "Europe/Athens"), 20181231)
self._test_func(F.toYYYYMMDDhhmmss(d), 20181231000000) self._test_func(F.toYYYYMMDDhhmmss(d), 20181231000000)
self._test_func(F.toYYYYMMDDhhmmss(dt, 'Europe/Athens'), 20181231132233) self._test_func(F.toYYYYMMDDhhmmss(dt, "Europe/Athens"), 20181231132233)
self._test_func(F.toRelativeYearNum(dt), 2018) self._test_func(F.toRelativeYearNum(dt), 2018)
self._test_func(F.toRelativeYearNum(dt, 'Europe/Athens'), 2018) self._test_func(F.toRelativeYearNum(dt, "Europe/Athens"), 2018)
self._test_func(F.toRelativeMonthNum(dt), 2018 * 12 + 12) self._test_func(F.toRelativeMonthNum(dt), 2018 * 12 + 12)
self._test_func(F.toRelativeMonthNum(dt, 'Europe/Athens'), 2018 * 12 + 12) self._test_func(F.toRelativeMonthNum(dt, "Europe/Athens"), 2018 * 12 + 12)
self._test_func(F.toRelativeWeekNum(dt), 2557) self._test_func(F.toRelativeWeekNum(dt), 2557)
self._test_func(F.toRelativeWeekNum(dt, 'Europe/Athens'), 2557) self._test_func(F.toRelativeWeekNum(dt, "Europe/Athens"), 2557)
self._test_func(F.toRelativeDayNum(dt), 17896) self._test_func(F.toRelativeDayNum(dt), 17896)
self._test_func(F.toRelativeDayNum(dt, 'Europe/Athens'), 17896) self._test_func(F.toRelativeDayNum(dt, "Europe/Athens"), 17896)
self._test_func(F.toRelativeHourNum(dt), 429515) self._test_func(F.toRelativeHourNum(dt), 429515)
self._test_func(F.toRelativeHourNum(dt, 'Europe/Athens'), 429515) self._test_func(F.toRelativeHourNum(dt, "Europe/Athens"), 429515)
self._test_func(F.toRelativeMinuteNum(dt), 25770922) self._test_func(F.toRelativeMinuteNum(dt), 25770922)
self._test_func(F.toRelativeMinuteNum(dt, 'Europe/Athens'), 25770922) self._test_func(F.toRelativeMinuteNum(dt, "Europe/Athens"), 25770922)
self._test_func(F.toRelativeSecondNum(dt), 1546255353) self._test_func(F.toRelativeSecondNum(dt), 1546255353)
self._test_func(F.toRelativeSecondNum(dt, 'Europe/Athens'), 1546255353) self._test_func(F.toRelativeSecondNum(dt, "Europe/Athens"), 1546255353)
self._test_func(F.timeSlot(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc)) self._test_func(F.timeSlot(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc))
self._test_func(F.timeSlots(dt, 300), [datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc)]) self._test_func(F.timeSlots(dt, 300), [datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc)])
self._test_func(F.formatDateTime(dt, '%D %T', 'Europe/Athens'), '12/31/18 13:22:33') self._test_func(F.formatDateTime(dt, "%D %T", "Europe/Athens"), "12/31/18 13:22:33")
self._test_func(F.addDays(d, 7), date(2019, 1, 7)) self._test_func(F.addDays(d, 7), date(2019, 1, 7))
self._test_func(F.addDays(dt, 7, 'Europe/Athens')) self._test_func(F.addDays(dt, 7, "Europe/Athens"))
self._test_func(F.addHours(dt, 7, 'Europe/Athens')) self._test_func(F.addHours(dt, 7, "Europe/Athens"))
self._test_func(F.addMinutes(dt, 7, 'Europe/Athens')) self._test_func(F.addMinutes(dt, 7, "Europe/Athens"))
self._test_func(F.addMonths(d, 7), date(2019, 7, 31)) self._test_func(F.addMonths(d, 7), date(2019, 7, 31))
self._test_func(F.addMonths(dt, 7, 'Europe/Athens')) self._test_func(F.addMonths(dt, 7, "Europe/Athens"))
self._test_func(F.addQuarters(d, 7)) self._test_func(F.addQuarters(d, 7))
self._test_func(F.addQuarters(dt, 7, 'Europe/Athens')) self._test_func(F.addQuarters(dt, 7, "Europe/Athens"))
self._test_func(F.addSeconds(d, 7)) self._test_func(F.addSeconds(d, 7))
self._test_func(F.addSeconds(dt, 7, 'Europe/Athens')) self._test_func(F.addSeconds(dt, 7, "Europe/Athens"))
self._test_func(F.addWeeks(d, 7)) self._test_func(F.addWeeks(d, 7))
self._test_func(F.addWeeks(dt, 7, 'Europe/Athens')) self._test_func(F.addWeeks(dt, 7, "Europe/Athens"))
self._test_func(F.addYears(d, 7)) self._test_func(F.addYears(d, 7))
self._test_func(F.addYears(dt, 7, 'Europe/Athens')) self._test_func(F.addYears(dt, 7, "Europe/Athens"))
self._test_func(F.subtractDays(d, 3)) self._test_func(F.subtractDays(d, 3))
self._test_func(F.subtractDays(dt, 3, 'Europe/Athens')) self._test_func(F.subtractDays(dt, 3, "Europe/Athens"))
self._test_func(F.subtractHours(d, 3)) self._test_func(F.subtractHours(d, 3))
self._test_func(F.subtractHours(dt, 3, 'Europe/Athens')) self._test_func(F.subtractHours(dt, 3, "Europe/Athens"))
self._test_func(F.subtractMinutes(d, 3)) self._test_func(F.subtractMinutes(d, 3))
self._test_func(F.subtractMinutes(dt, 3, 'Europe/Athens')) self._test_func(F.subtractMinutes(dt, 3, "Europe/Athens"))
self._test_func(F.subtractMonths(d, 3)) self._test_func(F.subtractMonths(d, 3))
self._test_func(F.subtractMonths(dt, 3, 'Europe/Athens')) self._test_func(F.subtractMonths(dt, 3, "Europe/Athens"))
self._test_func(F.subtractQuarters(d, 3)) self._test_func(F.subtractQuarters(d, 3))
self._test_func(F.subtractQuarters(dt, 3, 'Europe/Athens')) self._test_func(F.subtractQuarters(dt, 3, "Europe/Athens"))
self._test_func(F.subtractSeconds(d, 3)) self._test_func(F.subtractSeconds(d, 3))
self._test_func(F.subtractSeconds(dt, 3, 'Europe/Athens')) self._test_func(F.subtractSeconds(dt, 3, "Europe/Athens"))
self._test_func(F.subtractWeeks(d, 3)) self._test_func(F.subtractWeeks(d, 3))
self._test_func(F.subtractWeeks(dt, 3, 'Europe/Athens')) self._test_func(F.subtractWeeks(dt, 3, "Europe/Athens"))
self._test_func(F.subtractYears(d, 3)) self._test_func(F.subtractYears(d, 3))
self._test_func(F.subtractYears(dt, 3, 'Europe/Athens')) self._test_func(F.subtractYears(dt, 3, "Europe/Athens"))
self._test_func(F.now() + F.toIntervalSecond(3) + F.toIntervalMinute(3) + F.toIntervalHour(3) + F.toIntervalDay(3)) self._test_func(
self._test_func(F.now() + F.toIntervalWeek(3) + F.toIntervalMonth(3) + F.toIntervalQuarter(3) + F.toIntervalYear(3)) F.now() + F.toIntervalSecond(3) + F.toIntervalMinute(3) + F.toIntervalHour(3) + F.toIntervalDay(3)
self._test_func(F.now() + F.toIntervalSecond(3000) - F.toIntervalDay(3000) == F.now() + timedelta(seconds=3000, days=-3000)) )
self._test_func(
F.now() + F.toIntervalWeek(3) + F.toIntervalMonth(3) + F.toIntervalQuarter(3) + F.toIntervalYear(3)
)
self._test_func(
F.now() + F.toIntervalSecond(3000) - F.toIntervalDay(3000) == F.now() + timedelta(seconds=3000, days=-3000)
)
    def test_date_functions__utc_only(self):
        if self.database.server_timezone != pytz.utc:
            raise unittest.SkipTest("This test must run with UTC as the server timezone")
        d = date(2018, 12, 31)
        dt = datetime(2018, 12, 31, 11, 22, 33)
        athens_tz = pytz.timezone("Europe/Athens")
        self._test_func(F.toHour(dt), 11)
        self._test_func(F.toStartOfDay(dt), datetime(2018, 12, 31, 0, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.toTime(dt, pytz.utc), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
        self._test_func(F.toTime(dt, "Europe/Athens"), athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)))
        self._test_func(F.toTime(dt, athens_tz), athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)))
        self._test_func(F.toTimeZone(dt, "Europe/Athens"), athens_tz.localize(datetime(2018, 12, 31, 13, 22, 33)))
        self._test_func(
            F.now(), datetime.utcnow().replace(tzinfo=pytz.utc, microsecond=0)
        )  # FIXME this may fail if the timing is just right
        self._test_func(F.today(), datetime.utcnow().date())
        self._test_func(F.yesterday(), datetime.utcnow().date() - timedelta(days=1))
        self._test_func(F.toYYYYMMDDhhmmss(dt), 20181231112233)
        self._test_func(F.formatDateTime(dt, "%D %T"), "12/31/18 11:22:33")
        self._test_func(F.addHours(d, 7), datetime(2018, 12, 31, 7, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.addMinutes(d, 7), datetime(2018, 12, 31, 0, 7, 0, tzinfo=pytz.utc))
    def test_type_conversion_functions(self):
        for f in (
            F.toUInt8,
            F.toUInt16,
            F.toUInt32,
            F.toUInt64,
            F.toInt8,
            F.toInt16,
            F.toInt32,
            F.toInt64,
            F.toFloat32,
            F.toFloat64,
        ):
            self._test_func(f(17), 17)
            self._test_func(f("17"), 17)
        for f in (
            F.toUInt8OrZero,
            F.toUInt16OrZero,
            F.toUInt32OrZero,
            F.toUInt64OrZero,
            F.toInt8OrZero,
            F.toInt16OrZero,
            F.toInt32OrZero,
            F.toInt64OrZero,
            F.toFloat32OrZero,
            F.toFloat64OrZero,
        ):
            self._test_func(f("17"), 17)
            self._test_func(f("a"), 0)
        for f in (F.toDecimal32, F.toDecimal64, F.toDecimal128):
            self._test_func(f(17.17, 2), Decimal("17.17"))
            self._test_func(f("17.17", 2), Decimal("17.17"))
        self._test_func(F.toDate("2018-12-31"), date(2018, 12, 31))
        self._test_func(F.toString(123), "123")
        self._test_func(F.toFixedString("123", 5), "123")
        self._test_func(F.toStringCutToZero("123\0"), "123")
        self._test_func(F.CAST(17, "String"), "17")
        self._test_func(F.parseDateTimeBestEffort("31/12/2019 10:05AM", "Europe/Athens"))
        with self.assertRaises(ServerError):
            self._test_func(F.parseDateTimeBestEffort("foo"))
        self._test_func(F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM", "Europe/Athens"))
        self._test_func(F.parseDateTimeBestEffortOrNull("foo"), None)
        self._test_func(F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM", "Europe/Athens"))
        self._test_func(F.parseDateTimeBestEffortOrZero("foo"), DateTimeField.class_default)

    def test_type_conversion_functions__utc_only(self):
        if self.database.server_timezone != pytz.utc:
            raise unittest.SkipTest("This test must run with UTC as the server timezone")
        self._test_func(F.toDateTime("2018-12-31 11:22:33"), datetime(2018, 12, 31, 11, 22, 33, tzinfo=pytz.utc))
        self._test_func(
            F.toDateTime64("2018-12-31 11:22:33.001", 6), datetime(2018, 12, 31, 11, 22, 33, 1000, tzinfo=pytz.utc)
        )
        self._test_func(F.parseDateTimeBestEffort("31/12/2019 10:05AM"), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc))
        self._test_func(
            F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM"), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc)
        )
        self._test_func(
            F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM"), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc)
        )
    def test_string_functions(self):
        self._test_func(F.empty(""), 1)
        self._test_func(F.empty("x"), 0)
        self._test_func(F.notEmpty(""), 0)
        self._test_func(F.notEmpty("x"), 1)
        self._test_func(F.length("x"), 1)
        self._test_func(F.lengthUTF8("x"), 1)
        self._test_func(F.lower("Ab"), "ab")
        self._test_func(F.upper("Ab"), "AB")
        self._test_func(F.lowerUTF8("Ab"), "ab")
        self._test_func(F.upperUTF8("Ab"), "AB")
        self._test_func(F.reverse("Ab"), "bA")
        self._test_func(F.reverseUTF8("Ab"), "bA")
        self._test_func(F.concat("Ab", "Cd", "Ef"), "AbCdEf")
        self._test_func(F.substring("123456", 3, 2), "34")
        self._test_func(F.substringUTF8("123456", 3, 2), "34")
        self._test_func(F.appendTrailingCharIfAbsent("Hello", "!"), "Hello!")
        self._test_func(F.appendTrailingCharIfAbsent("Hello!", "!"), "Hello!")
        self._test_func(F.convertCharset(F.convertCharset("Hello", "latin1", "utf16"), "utf16", "latin1"), "Hello")
        self._test_func(F.startsWith("aaa", "aa"), True)
        self._test_func(F.startsWith("aaa", "bb"), False)
        self._test_func(F.endsWith("aaa", "aa"), True)
        self._test_func(F.endsWith("aaa", "bb"), False)
        self._test_func(F.trimLeft(" abc "), "abc ")
        self._test_func(F.trimRight(" abc "), " abc")
        self._test_func(F.trimBoth(" abc "), "abc")
        self._test_func(F.CRC32("whoops"), 3361378926)

    def test_string_search_functions(self):
        self._test_func(F.position("Hello, world!", "!"), 13)
        self._test_func(F.positionCaseInsensitive("Hello, world!", "hello"), 1)
        self._test_func(F.positionUTF8("Привет, мир!", "!"), 12)
        self._test_func(F.positionCaseInsensitiveUTF8("Привет, мир!", "Мир"), 9)
        self._test_func(F.like("Hello, world!", "%ll%"), 1)
        self._test_func(F.notLike("Hello, world!", "%ll%"), 0)
        self._test_func(F.match("Hello, world!", "[lmnop]{3}"), 1)
        self._test_func(F.extract("Hello, world!", "[lmnop]{3}"), "llo")
        self._test_func(F.extractAll("Hello, world!", "[a-z]+"), ["ello", "world"])
        self._test_func(F.ngramDistance("Hello", "Hello"), 0)
        self._test_func(F.ngramDistanceCaseInsensitive("Hello", "hello"), 0)
        self._test_func(F.ngramDistanceUTF8("Hello", "Hello"), 0)
        self._test_func(F.ngramDistanceCaseInsensitiveUTF8("Hello", "hello"), 0)
        self._test_func(F.ngramSearch("Hello", "Hello"), 1)
        self._test_func(F.ngramSearchCaseInsensitive("Hello", "hello"), 1)
        self._test_func(F.ngramSearchUTF8("Hello", "Hello"), 1)
        self._test_func(F.ngramSearchCaseInsensitiveUTF8("Hello", "hello"), 1)

    def test_base64_functions(self):
        try:
            self._test_func(F.base64Decode(F.base64Encode("Hello")), "Hello")
            self._test_func(F.tryBase64Decode(F.base64Encode("Hello")), "Hello")
            self._test_func(F.tryBase64Decode(":-)"))
        except ServerError as e:
            # ClickHouse version that doesn't support these functions
            raise unittest.SkipTest(e.message)

    def test_replace_functions(self):
        haystack = "hello"
        self._test_func(F.replace(haystack, "l", "L"), "heLLo")
        self._test_func(F.replaceAll(haystack, "l", "L"), "heLLo")
        self._test_func(F.replaceOne(haystack, "l", "L"), "heLlo")
        self._test_func(F.replaceRegexpAll(haystack, "[eo]", "X"), "hXllX")
        self._test_func(F.replaceRegexpOne(haystack, "[eo]", "X"), "hXllo")
        self._test_func(F.regexpQuoteMeta("[eo]"), "\\[eo\\]")
    def test_math_functions(self):
        x = 17

@@ -515,15 +550,15 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.arrayDifference(arr), [0, 1, 1])
        self._test_func(F.arrayDistinct(arr + arr), arr)
        self._test_func(F.arrayIntersect(arr, [3, 4]), [3])
        self._test_func(F.arrayReduce("min", arr), 1)
        self._test_func(F.arrayReverse(arr), [3, 2, 1])

    def test_split_and_merge_functions(self):
        self._test_func(F.splitByChar("_", "a_b_c"), ["a", "b", "c"])
        self._test_func(F.splitByString("__", "a__b__c"), ["a", "b", "c"])
        self._test_func(F.arrayStringConcat(["a", "b", "c"]), "abc")
        self._test_func(F.arrayStringConcat(["a", "b", "c"], "_"), "a_b_c")
        self._test_func(F.alphaTokens("aaa.bbb.111"), ["aaa", "bbb"])

    def test_bit_functions(self):
        x = 17

@@ -546,10 +581,12 @@ class FuncsTestCase(TestCaseWithData):
    def test_bitmap_functions(self):
        self._test_func(F.bitmapToArray(F.bitmapBuild([1, 2, 3])), [1, 2, 3])
        self._test_func(F.bitmapContains(F.bitmapBuild([1, 5, 7, 9]), F.toUInt32(9)), 1)
        self._test_func(F.bitmapHasAny(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 1)
        self._test_func(F.bitmapHasAll(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 0)
        self._test_func(F.bitmapToArray(F.bitmapAnd(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [3])
        self._test_func(
            F.bitmapToArray(F.bitmapOr(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2, 3, 4, 5]
        )
        self._test_func(F.bitmapToArray(F.bitmapXor(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2, 4, 5])
        self._test_func(F.bitmapToArray(F.bitmapAndnot(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2])
        self._test_func(F.bitmapCardinality(F.bitmapBuild([1, 2, 3, 4, 5])), 5)

@@ -559,10 +596,10 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.bitmapAndnotCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 2)

    def test_hash_functions(self):
        args = ["x", "y", "z"]
        x = 17
        s = "hello"
        url = "http://example.com/a/b/c/d"
        self._test_func(F.hex(F.MD5(s)))
        self._test_func(F.hex(F.sipHash128(s)))
        self._test_func(F.hex(F.cityHash64(*args)))

@@ -594,17 +631,18 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.rand(17))
        self._test_func(F.rand64())
        self._test_func(F.rand64(17))
        if self.database.server_version >= (19, 15):  # buggy in older versions
            self._test_func(F.randConstant())
            self._test_func(F.randConstant(17))
    def test_encoding_functions(self):
        self._test_func(F.hex(F.unhex("0FA1")), "0FA1")
        self._test_func(F.bitmaskToArray(17))
        self._test_func(F.bitmaskToList(18))

    def test_uuid_functions(self):
        from uuid import UUID

        uuid = self._test_func(F.generateUUIDv4())
        self.assertEqual(type(uuid), UUID)
        s = str(uuid)

@@ -612,17 +650,20 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.UUIDNumToString(F.UUIDStringToNum(s)), s)

    def test_ip_funcs(self):
        self._test_func(F.IPv4NumToString(F.toUInt32(1)), "0.0.0.1")
        self._test_func(F.IPv4NumToStringClassC(F.toUInt32(1)), "0.0.0.xxx")
        self._test_func(F.IPv4StringToNum("0.0.0.17"), 17)
        self._test_func(F.IPv6NumToString(F.IPv4ToIPv6(F.IPv4StringToNum("192.168.0.1"))), "::ffff:192.168.0.1")
        self._test_func(F.IPv6NumToString(F.IPv6StringToNum("2a02:6b8::11")), "2a02:6b8::11")
        self._test_func(F.toIPv4("10.20.30.40"), IPv4Address("10.20.30.40"))
        self._test_func(F.toIPv6("2001:438:ffff::407d:1bc1"), IPv6Address("2001:438:ffff::407d:1bc1"))
        self._test_func(
            F.IPv4CIDRToRange(F.toIPv4("192.168.5.2"), 16), [IPv4Address("192.168.0.0"), IPv4Address("192.168.255.255")]
        )
        self._test_func(
            F.IPv6CIDRToRange(F.toIPv6("2001:0db8:0000:85a3:0000:0000:ac1f:8001"), 32),
            [IPv6Address("2001:db8::"), IPv6Address("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff")],
        )

    def test_aggregate_funcs(self):
        self._test_aggr(F.any(Person.first_name))

@@ -649,32 +690,32 @@ class FuncsTestCase(TestCaseWithData):
        self._test_aggr(F.varSamp(Person.height))

    def test_aggregate_funcs__or_default(self):
        self.database.raw("TRUNCATE TABLE person")
        self._test_aggr(F.countOrDefault(), 0)
        self._test_aggr(F.maxOrDefault(Person.height), 0)

    def test_aggregate_funcs__or_null(self):
        self.database.raw("TRUNCATE TABLE person")
        self._test_aggr(F.countOrNull(), None)
        self._test_aggr(F.maxOrNull(Person.height), None)

    def test_aggregate_funcs__if(self):
        self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > "H"))
        self._test_aggr(F.countIf(Person.last_name > "H"), 57)
        self._test_aggr(F.minIf(Person.height, Person.last_name > "H"), 1.6)

    def test_aggregate_funcs__or_default_if(self):
        self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > "Z"))
        self._test_aggr(F.countOrDefaultIf(Person.last_name > "Z"), 0)
        self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > "Z"), 0)

    def test_aggregate_funcs__or_null_if(self):
        self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > "Z"))
        self._test_aggr(F.countOrNullIf(Person.last_name > "Z"), None)
        self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > "Z"), None)

    def test_quantile_funcs(self):
        cond = Person.last_name > "H"
        weight_expr = F.toUInt32(F.round(Person.height))
        # Quantile
        self._test_aggr(F.quantile(0.9)(Person.height))

@@ -712,13 +753,13 @@ class FuncsTestCase(TestCaseWithData):
    def test_top_k_funcs(self):
        self._test_aggr(F.topK(3)(Person.height))
        self._test_aggr(F.topKOrDefault(3)(Person.height))
        self._test_aggr(F.topKIf(3)(Person.height, Person.last_name > "H"))
        self._test_aggr(F.topKOrDefaultIf(3)(Person.height, Person.last_name > "H"))
        weight_expr = F.toUInt32(F.round(Person.height))
        self._test_aggr(F.topKWeighted(3)(Person.height, weight_expr))
        self._test_aggr(F.topKWeightedOrDefault(3)(Person.height, weight_expr))
        self._test_aggr(F.topKWeightedIf(3)(Person.height, weight_expr, Person.last_name > "H"))
        self._test_aggr(F.topKWeightedOrDefaultIf(3)(Person.height, weight_expr, Person.last_name > "H"))

    def test_null_funcs(self):
        self._test_func(F.ifNull(17, 18), 17)
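
Note (not part of the diff): the F expressions exercised above can also drive queryset filtering directly; a minimal sketch, assuming the objects_in()/filter() API these tests rely on elsewhere, the Person fixture model, and the imports already present in this test module:

db = Database("test-db")
# F.lower(...) builds a server-side expression; comparing it yields a condition usable in filter()
queryset = Person.objects_in(db).filter(F.lower(Person.first_name) == "aline")
for person in queryset:
    print(person.first_name, person.height)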

View File

@@ -4,11 +4,10 @@ from clickhouse_orm import *
class IndexesTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)
        if self.database.server_version < (20, 1, 2, 4):
            raise unittest.SkipTest("ClickHouse version too old")

    def tearDown(self):
        self.database.drop_database()

@@ -29,4 +28,4 @@ class ModelWithIndexes(Model):
    i4 = Index(F.lower(f2), type=Index.tokenbf_v1(256, 2, 0), granularity=2)
    i5 = Index((F.toQuarter(date), f2), type=Index.bloom_filter(), granularity=3)
    engine = MergeTree("date", ("date",))

View File

@@ -9,17 +9,16 @@ from clickhouse_orm.engines import *
class InheritanceTestCase(unittest.TestCase):
    def assertFieldNames(self, model_class, names):
        self.assertEqual(names, list(model_class.fields()))

    def test_field_inheritance(self):
        self.assertFieldNames(ParentModel, ["date_field", "int_field"])
        self.assertFieldNames(Model1, ["date_field", "int_field", "string_field"])
        self.assertFieldNames(Model2, ["date_field", "int_field", "float_field"])

    def test_create_table_sql(self):
        default_db = Database("default")
        sql1 = ParentModel.create_table_sql(default_db)
        sql2 = Model1.create_table_sql(default_db)
        sql3 = Model2.create_table_sql(default_db)

@@ -28,11 +27,11 @@ class InheritanceTestCase(unittest.TestCase):
        self.assertNotEqual(sql2, sql3)

    def test_get_field(self):
        self.assertIsNotNone(ParentModel().get_field("date_field"))
        self.assertIsNone(ParentModel().get_field("string_field"))
        self.assertIsNotNone(Model1().get_field("date_field"))
        self.assertIsNotNone(Model1().get_field("string_field"))
        self.assertIsNone(Model1().get_field("float_field"))


class ParentModel(Model):

@@ -40,7 +39,7 @@ class ParentModel(Model):
    date_field = DateField()
    int_field = Int32Field()
    engine = MergeTree("date_field", ("int_field", "date_field"))


class Model1(ParentModel):

View File

@@ -7,54 +7,50 @@ from clickhouse_orm.engines import Memory
class IPFieldsTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()

    def test_ipv4_field(self):
        if self.database.server_version < (19, 17):
            raise unittest.SkipTest("ClickHouse version too old")
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = IPv4Field()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same ip)
        values = ["1.2.3.4", b"\x01\x02\x03\x04", 16909060, IPv4Address("1.2.3.4")]
        for index, value in enumerate(values):
            rec = TestModel(i=index, f=value)
            self.database.insert([rec])
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, IPv4Address(values[0]))
        # Check invalid values
        for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)

    def test_ipv6_field(self):
        if self.database.server_version < (19, 17):
            raise unittest.SkipTest("ClickHouse version too old")
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = IPv6Field()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same ip)
        values = [
            "2a02:e980:1e::1",
            b"*\x02\xe9\x80\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
            55842696359362256756849388082849382401,
            IPv6Address("2a02:e980:1e::1"),
        ]
        for index, value in enumerate(values):
            rec = TestModel(i=index, f=value)

@@ -62,7 +58,6 @@ class IPFieldsTest(unittest.TestCase):
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, IPv6Address(values[0]))
        # Check invalid values
        for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)

View File

@@ -1,4 +1,3 @@
import unittest
import json

@@ -6,9 +5,8 @@ from clickhouse_orm import database, engines, fields, models
class JoinTest(unittest.TestCase):
    def setUp(self):
        self.database = database.Database("test-db", log_statements=True)
        self.database.create_table(Foo)
        self.database.create_table(Bar)
        self.database.insert([Foo(id=i) for i in range(3)])

@@ -29,8 +27,16 @@ class JoinTest(unittest.TestCase):
        self.print_res("SELECT b FROM $db.{} ALL LEFT JOIN $db.{} USING id".format(Foo.table_name(), Bar.table_name()))

    def test_with_subquery(self):
        self.print_res(
            "SELECT b FROM {} ALL LEFT JOIN (SELECT * from {}) subquery USING id".format(
                Foo.table_name(), Bar.table_name()
            )
        )
        self.print_res(
            "SELECT b FROM $db.{} ALL LEFT JOIN (SELECT * from $db.{}) subquery USING id".format(
                Foo.table_name(), Bar.table_name()
            )
        )


class Foo(models.Model):

View File

@@ -9,24 +9,21 @@ from clickhouse_orm.funcs import F
class MaterializedFieldsTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithMaterializedFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        instance = ModelWithMaterializedFields(date_time_field="2016-08-30 11:00:00", int_field=-10, str_field="TEST")
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
        query = (
            "SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str, mat_func"
            " FROM $db.%s ORDER BY mat_date" % ModelWithMaterializedFields.table_name()
        )
        for model_cls in (ModelWithMaterializedFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)

@@ -41,7 +38,7 @@ class MaterializedFieldsTest(unittest.TestCase):
    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithMaterializedFields()
        for value in ("x", [date.today()], ["aaa"], [None]):
            with self.assertRaises(ValueError):
                instance.mat_date = value

@@ -51,10 +48,10 @@ class MaterializedFieldsTest(unittest.TestCase):
    def test_duplicate_default(self):
        with self.assertRaises(AssertionError):
            StringField(materialized="str_field", default="with default")
        with self.assertRaises(AssertionError):
            StringField(materialized="str_field", alias="str_field")

    def test_default_value(self):
        instance = ModelWithMaterializedFields()

@@ -66,9 +63,9 @@ class ModelWithMaterializedFields(Model):
    date_time_field = DateTimeField()
    str_field = StringField()
    mat_str = StringField(materialized="lower(str_field)")
    mat_int = Int32Field(materialized="abs(int_field)")
    mat_date = DateField(materialized=u"toDate(date_time_field)")
    mat_func = StringField(materialized=F.lower(str_field))
    engine = MergeTree("mat_date", ("mat_date",))
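
Note (not part of the diff): a minimal sketch of what the materialized fields above do at runtime, assuming a reachable ClickHouse server and the imports already present in this test module; mat_str is computed by the server from str_field on insert and has to be read back via an explicit column list:

db = Database("test-db")
db.create_table(ModelWithMaterializedFields)
db.insert([ModelWithMaterializedFields(date_time_field="2016-08-30 11:00:00", int_field=-10, str_field="TEST")])
query = "SELECT str_field, mat_str FROM $db.%s" % ModelWithMaterializedFields.table_name()
for row in db.select(query):
    print(row.str_field, row.mat_str)  # expected: TEST test (lower() is applied server-side)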

View File

@@ -7,20 +7,22 @@ from clickhouse_orm.engines import *
from clickhouse_orm.migrations import MigrationHistory
from enum import Enum

# Add tests to path so that migrations will be importable
import sys, os

sys.path.append(os.path.dirname(__file__))

import logging

logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logging.getLogger("requests").setLevel(logging.WARNING)


class MigrationsTestCase(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)
        self.database.drop_table(MigrationHistory)

    def tearDown(self):

@@ -35,123 +37,157 @@ class MigrationsTestCase(unittest.TestCase):
        return [(row.name, row.type) for row in self.database.select(query)]

    def get_table_def(self, model_class):
        return self.database.raw("SHOW CREATE TABLE $db.`%s`" % model_class.table_name())
    def test_migrations(self):
        # Creation and deletion of table
        self.database.migrate("tests.sample_migrations", 1)
        self.assertTrue(self.table_exists(Model1))
        self.database.migrate("tests.sample_migrations", 2)
        self.assertFalse(self.table_exists(Model1))
        self.database.migrate("tests.sample_migrations", 3)
        self.assertTrue(self.table_exists(Model1))
        # Adding, removing and altering simple fields
        self.assertEqual(self.get_table_fields(Model1), [("date", "Date"), ("f1", "Int32"), ("f2", "String")])
        self.database.migrate("tests.sample_migrations", 4)
        self.assertEqual(
            self.get_table_fields(Model2),
            [
                ("date", "Date"),
                ("f1", "Int32"),
                ("f3", "Float32"),
                ("f2", "String"),
                ("f4", "String"),
                ("f5", "Array(UInt64)"),
            ],
        )
        self.database.migrate("tests.sample_migrations", 5)
        self.assertEqual(
            self.get_table_fields(Model3), [("date", "Date"), ("f1", "Int64"), ("f3", "Float64"), ("f4", "String")]
        )
        # Altering enum fields
        self.database.migrate("tests.sample_migrations", 6)
        self.assertTrue(self.table_exists(EnumModel1))
        self.assertEqual(
            self.get_table_fields(EnumModel1), [("date", "Date"), ("f1", "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")]
        )
        self.database.migrate("tests.sample_migrations", 7)
        self.assertTrue(self.table_exists(EnumModel1))
        self.assertEqual(
            self.get_table_fields(EnumModel2),
            [("date", "Date"), ("f1", "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")],
        )
        # Materialized fields and alias fields
        self.database.migrate("tests.sample_migrations", 8)
        self.assertTrue(self.table_exists(MaterializedModel))
        self.assertEqual(self.get_table_fields(MaterializedModel), [("date_time", "DateTime"), ("date", "Date")])
        self.database.migrate("tests.sample_migrations", 9)
        self.assertTrue(self.table_exists(AliasModel))
        self.assertEqual(self.get_table_fields(AliasModel), [("date", "Date"), ("date_alias", "Date")])
        # Buffer models creation and alteration
        self.database.migrate("tests.sample_migrations", 10)
        self.assertTrue(self.table_exists(Model4))
        self.assertTrue(self.table_exists(Model4Buffer))
        self.assertEqual(self.get_table_fields(Model4), [("date", "Date"), ("f1", "Int32"), ("f2", "String")])
        self.assertEqual(self.get_table_fields(Model4Buffer), [("date", "Date"), ("f1", "Int32"), ("f2", "String")])
        self.database.migrate("tests.sample_migrations", 11)
        self.assertEqual(self.get_table_fields(Model4), [("date", "Date"), ("f3", "DateTime"), ("f2", "String")])
        self.assertEqual(self.get_table_fields(Model4Buffer), [("date", "Date"), ("f3", "DateTime"), ("f2", "String")])
        self.database.migrate("tests.sample_migrations", 12)
        self.assertEqual(self.database.count(Model3), 3)
        data = [item.f1 for item in self.database.select("SELECT f1 FROM $table ORDER BY f1", model_class=Model3)]
        self.assertListEqual(data, [1, 2, 3])
        self.database.migrate("tests.sample_migrations", 13)
        self.assertEqual(self.database.count(Model3), 4)
        data = [item.f1 for item in self.database.select("SELECT f1 FROM $table ORDER BY f1", model_class=Model3)]
        self.assertListEqual(data, [1, 2, 3, 4])
        self.database.migrate("tests.sample_migrations", 14)
        self.assertTrue(self.table_exists(MaterializedModel1))
        self.assertEqual(
            self.get_table_fields(MaterializedModel1),
            [("date_time", "DateTime"), ("int_field", "Int8"), ("date", "Date"), ("int_field_plus_one", "Int8")],
        )
        self.assertTrue(self.table_exists(AliasModel1))
        self.assertEqual(
            self.get_table_fields(AliasModel1),
            [("date", "Date"), ("int_field", "Int8"), ("date_alias", "Date"), ("int_field_plus_one", "Int8")],
        )
        # Codecs and low cardinality
        self.database.migrate("tests.sample_migrations", 15)
        self.assertTrue(self.table_exists(Model4_compressed))
        if self.database.has_low_cardinality_support:
            self.assertEqual(
                self.get_table_fields(Model2LowCardinality),
                [
                    ("date", "Date"),
                    ("f1", "LowCardinality(Int32)"),
                    ("f3", "LowCardinality(Float32)"),
                    ("f2", "LowCardinality(String)"),
                    ("f4", "LowCardinality(Nullable(String))"),
                    ("f5", "Array(LowCardinality(UInt64))"),
                ],
            )
        else:
            logging.warning("No support for low cardinality")
            self.assertEqual(
                self.get_table_fields(Model2),
                [
                    ("date", "Date"),
                    ("f1", "Int32"),
                    ("f3", "Float32"),
                    ("f2", "String"),
                    ("f4", "Nullable(String)"),
                    ("f5", "Array(UInt64)"),
                ],
            )
        if self.database.server_version >= (19, 14, 3, 3):
            # Creating constraints
            self.database.migrate("tests.sample_migrations", 16)
            self.assertTrue(self.table_exists(ModelWithConstraints))
            self.database.insert([ModelWithConstraints(f1=101, f2="a")])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=99, f2="a")])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=101, f2="x")])
            # Modifying constraints
            self.database.migrate("tests.sample_migrations", 17)
            self.database.insert([ModelWithConstraints(f1=99, f2="a")])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=101, f2="a")])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=99, f2="x")])
        if self.database.server_version >= (20, 1, 2, 4):
            # Creating indexes
            self.database.migrate("tests.sample_migrations", 18)
            self.assertTrue(self.table_exists(ModelWithIndex))
            self.assertIn("INDEX index ", self.get_table_def(ModelWithIndex))
            self.assertIn("INDEX another_index ", self.get_table_def(ModelWithIndex))
            # Modifying indexes
            self.database.migrate("tests.sample_migrations", 19)
            self.assertNotIn("INDEX index ", self.get_table_def(ModelWithIndex))
            self.assertIn("INDEX index2 ", self.get_table_def(ModelWithIndex))
            self.assertIn("INDEX another_index ", self.get_table_def(ModelWithIndex))
# Several different models with the same table name, to simulate a table that changes over time


class Model1(Model):
    date = DateField()
    f1 = Int32Field()
    f2 = StringField()
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "mig"


class Model2(Model):

@@ -161,99 +197,99 @@ class Model2(Model):
    f3 = Float32Field()
    f2 = StringField()
    f4 = StringField()
    f5 = ArrayField(UInt64Field())  # addition of an array field
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "mig"


class Model3(Model):
    date = DateField()
    f1 = Int64Field()  # changed from Int32
    f3 = Float64Field()  # changed from Float32
    f4 = StringField()
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "mig"


class EnumModel1(Model):
    date = DateField()
    f1 = Enum8Field(Enum("SomeEnum1", "dog cat cow"))
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "enum_mig"


class EnumModel2(Model):
    date = DateField()
    f1 = Enum16Field(Enum("SomeEnum2", "dog cat horse pig"))  # changed type and values
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "enum_mig"


class MaterializedModel(Model):
    date_time = DateTimeField()
    date = DateField(materialized="toDate(date_time)")
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "materalized_date"


class MaterializedModel1(Model):
    date_time = DateTimeField()
    date = DateField(materialized="toDate(date_time)")
    int_field = Int8Field()
    int_field_plus_one = Int8Field(materialized="int_field + 1")
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "materalized_date"


class AliasModel(Model):
    date = DateField()
    date_alias = DateField(alias="date")
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "alias_date"


class AliasModel1(Model):
    date = DateField()
    date_alias = DateField(alias="date")
    int_field = Int8Field()
    int_field_plus_one = Int8Field(alias="int_field + 1")
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "alias_date"


class Model4(Model):

@@ -262,11 +298,11 @@ class Model4(Model):
    f1 = Int32Field()
    f2 = StringField()
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "model4"


class Model4Buffer(BufferModel, Model4):

@@ -275,7 +311,7 @@ class Model4Buffer(BufferModel, Model4):
    @classmethod
    def table_name(cls):
        return "model4buffer"


class Model4_changed(Model):

@@ -284,11 +320,11 @@ class Model4_changed(Model):
    f3 = DateTimeField()
    f2 = StringField()
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "model4"


class Model4Buffer_changed(BufferModel, Model4_changed):

@@ -297,20 +333,20 @@ class Model4Buffer_changed(BufferModel, Model4_changed):
    @classmethod
    def table_name(cls):
        return "model4buffer"


class Model4_compressed(Model):
    date = DateField()
    f3 = DateTimeField(codec="Delta,ZSTD(10)")
    f2 = StringField(codec="LZ4HC")
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "model4"


class Model2LowCardinality(Model):

@@ -321,11 +357,11 @@ class Model2LowCardinality(Model):
    f4 = LowCardinalityField(NullableField(StringField()))
    f5 = ArrayField(LowCardinalityField(UInt64Field()))
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "mig"


class ModelWithConstraints(Model):

@@ -334,14 +370,14 @@ class ModelWithConstraints(Model):
    f1 = Int32Field()
    f2 = StringField()
    constraint = Constraint(f2.isIn(["a", "b", "c"]))  # check reserved keyword as constraint name
    f1_constraint = Constraint(f1 > 100)
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "modelwithconstraints"


class ModelWithConstraints2(Model):

@@ -350,14 +386,14 @@ class ModelWithConstraints2(Model):
    f1 = Int32Field()
    f2 = StringField()
    constraint = Constraint(f2.isIn(["a", "b", "c"]))
    f1_constraint_new = Constraint(f1 < 100)
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "modelwithconstraints"


class ModelWithIndex(Model):

@@ -369,11 +405,11 @@ class ModelWithIndex(Model):
    index = Index(f1, type=Index.minmax(), granularity=1)
    another_index = Index(f2, type=Index.set(0), granularity=1)
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "modelwithindex"


class ModelWithIndex2(Model):

@@ -385,9 +421,8 @@ class ModelWithIndex2(Model):
    index2 = Index(f1, type=Index.bloom_filter(), granularity=2)
    another_index = Index(f2, type=Index.set(0), granularity=1)
    engine = MergeTree("date", ("date",))

    @classmethod
    def table_name(cls):
        return "modelwithindex"
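
Note (not part of the diff): a minimal sketch of how the sample migrations exercised above are applied, assuming a reachable ClickHouse server and the tests package on sys.path as arranged in setUp:

from clickhouse_orm.database import Database

db = Database("test-db", log_statements=True)
db.migrate("tests.sample_migrations", 13)  # apply migrations up to and including number 13
print(db.count(Model3))  # the migrated "mig" table is now created and populated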

View File

@@ -9,13 +9,12 @@ from clickhouse_orm.funcs import F
class ModelTestCase(unittest.TestCase):
    def test_defaults(self):
        # Check that all fields have their explicit or implicit defaults
        instance = SimpleModel()
        self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
        self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
        self.assertEqual(instance.str_field, "dozo")
        self.assertEqual(instance.int_field, 17)
        self.assertEqual(instance.float_field, 0)
        self.assertEqual(instance.default_func, NO_VALUE)

@@ -25,9 +24,9 @@ class ModelTestCase(unittest.TestCase):
        kwargs = dict(
            date_field=datetime.date(1973, 12, 6),
            datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
            str_field="aloha",
            int_field=-50,
            float_field=3.14,
        )
        instance = SimpleModel(**kwargs)
        for name, value in kwargs.items():

@@ -36,12 +35,12 @@ class ModelTestCase(unittest.TestCase):
    def test_assignment_error(self):
        # Check non-existing field during construction
        with self.assertRaises(AttributeError):
            instance = SimpleModel(int_field=7450, pineapple="tasty")
        # Check invalid field values during construction
        with self.assertRaises(ValueError):
            instance = SimpleModel(int_field="nope")
        with self.assertRaises(ValueError):
            instance = SimpleModel(date_field="nope")
        # Check invalid field values during assignment
        instance = SimpleModel()
        with self.assertRaises(ValueError):

@@ -49,38 +48,43 @@ class ModelTestCase(unittest.TestCase):
    def test_string_conversion(self):
        # Check field conversion from string during construction
        instance = SimpleModel(date_field="1973-12-06", int_field="100", float_field="7")
        self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
        self.assertEqual(instance.int_field, 100)
        self.assertEqual(instance.float_field, 7)
        # Check field conversion from string during assignment
        instance.int_field = "99"
        self.assertEqual(instance.int_field, 99)
def test_to_dict(self): def test_to_dict(self):
instance = SimpleModel(date_field='1973-12-06', int_field='100', float_field='7') instance = SimpleModel(date_field="1973-12-06", int_field="100", float_field="7")
self.assertDictEqual(instance.to_dict(), {
"date_field": datetime.date(1973, 12, 6),
"int_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"alias_field": NO_VALUE,
"str_field": "dozo",
"default_func": NO_VALUE
})
self.assertDictEqual(instance.to_dict(include_readonly=False), {
"date_field": datetime.date(1973, 12, 6),
"int_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"str_field": "dozo",
"default_func": NO_VALUE
})
self.assertDictEqual( self.assertDictEqual(
instance.to_dict(include_readonly=False, field_names=('int_field', 'alias_field', 'datetime_field')), { instance.to_dict(),
{
"date_field": datetime.date(1973, 12, 6),
"int_field": 100, "int_field": 100,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc) "float_field": 7.0,
}) "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"alias_field": NO_VALUE,
"str_field": "dozo",
"default_func": NO_VALUE,
},
)
self.assertDictEqual(
instance.to_dict(include_readonly=False),
{
"date_field": datetime.date(1973, 12, 6),
"int_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"str_field": "dozo",
"default_func": NO_VALUE,
},
)
self.assertDictEqual(
instance.to_dict(include_readonly=False, field_names=("int_field", "alias_field", "datetime_field")),
{"int_field": 100, "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc)},
)
def test_field_name_in_error_message_for_invalid_value_in_constructor(self): def test_field_name_in_error_message_for_invalid_value_in_constructor(self):
bad_value = 1 bad_value = 1
@ -88,19 +92,17 @@ class ModelTestCase(unittest.TestCase):
SimpleModel(str_field=bad_value) SimpleModel(str_field=bad_value)
self.assertEqual( self.assertEqual(
"Invalid value for StringField: {} (field 'str_field')".format(repr(bad_value)), "Invalid value for StringField: {} (field 'str_field')".format(repr(bad_value)), str(cm.exception)
str(cm.exception)
) )
def test_field_name_in_error_message_for_invalid_value_in_assignment(self): def test_field_name_in_error_message_for_invalid_value_in_assignment(self):
instance = SimpleModel() instance = SimpleModel()
bad_value = 'foo' bad_value = "foo"
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
instance.float_field = bad_value instance.float_field = bad_value
self.assertEqual( self.assertEqual(
"Invalid value for Float32Field - {} (field 'float_field')".format(repr(bad_value)), "Invalid value for Float32Field - {} (field 'float_field')".format(repr(bad_value)), str(cm.exception)
str(cm.exception)
) )
@ -108,10 +110,10 @@ class SimpleModel(Model):
date_field = DateField() date_field = DateField()
datetime_field = DateTimeField() datetime_field = DateTimeField()
str_field = StringField(default='dozo') str_field = StringField(default="dozo")
int_field = Int32Field(default=17) int_field = Int32Field(default=17)
float_field = Float32Field() float_field = Float32Field()
alias_field = Float32Field(alias='float_field') alias_field = Float32Field(alias="float_field")
default_func = Float32Field(default=F.sqrt(float_field) + 17) default_func = Float32Field(default=F.sqrt(float_field) + 17)
engine = MergeTree('date_field', ('int_field', 'date_field')) engine = MergeTree("date_field", ("int_field", "date_field"))
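A small illustrative snippet, not from the commit, of the default handling, string conversion and to_dict() behaviour that ModelTestCase above asserts. The class name is illustrative; nothing here touches a Database, so no server is needed.

import datetime

from clickhouse_orm.models import Model
from clickhouse_orm.fields import DateField, StringField, Int32Field, Float32Field
from clickhouse_orm.engines import MergeTree


class MiniModel(Model):
    date_field = DateField()
    str_field = StringField(default="dozo")  # explicit default
    int_field = Int32Field(default=17)
    float_field = Float32Field()             # implicit default: 0

    engine = MergeTree("date_field", ("int_field", "date_field"))


# String inputs are converted to each field's Python type on construction
instance = MiniModel(date_field="1973-12-06", int_field="100")
assert instance.date_field == datetime.date(1973, 12, 6)
assert instance.int_field == 100
assert instance.str_field == "dozo"

# field_names / include_readonly control what ends up in the dict
print(instance.to_dict(field_names=("int_field", "str_field")))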

View File

@ -5,15 +5,14 @@ from time import sleep
class MutationsTestCase(TestCaseWithData): class MutationsTestCase(TestCaseWithData):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
if self.database.server_version < (18,): if self.database.server_version < (18,):
raise unittest.SkipTest('ClickHouse version too old') raise unittest.SkipTest("ClickHouse version too old")
self._insert_all() self._insert_all()
def _wait_for_mutations(self): def _wait_for_mutations(self):
sql = 'SELECT * FROM system.mutations WHERE is_done = 0' sql = "SELECT * FROM system.mutations WHERE is_done = 0"
while list(self.database.raw(sql)): while list(self.database.raw(sql)):
sleep(0.25) sleep(0.25)
@ -23,7 +22,7 @@ class MutationsTestCase(TestCaseWithData):
self.assertFalse(Person.objects_in(self.database)) self.assertFalse(Person.objects_in(self.database))
def test_delete_with_where_cond(self): def test_delete_with_where_cond(self):
cond = Person.first_name == 'Cassady' cond = Person.first_name == "Cassady"
self.assertTrue(Person.objects_in(self.database).filter(cond)) self.assertTrue(Person.objects_in(self.database).filter(cond))
Person.objects_in(self.database).filter(cond).delete() Person.objects_in(self.database).filter(cond).delete()
self._wait_for_mutations() self._wait_for_mutations()
@ -41,11 +40,12 @@ class MutationsTestCase(TestCaseWithData):
def test_update_all(self): def test_update_all(self):
Person.objects_in(self.database).update(height=0) Person.objects_in(self.database).update(height=0)
self._wait_for_mutations() self._wait_for_mutations()
for p in Person.objects_in(self.database): print(p.height) for p in Person.objects_in(self.database):
print(p.height)
self.assertFalse(Person.objects_in(self.database).exclude(height=0)) self.assertFalse(Person.objects_in(self.database).exclude(height=0))
def test_update_with_where_cond(self): def test_update_with_where_cond(self):
cond = Person.first_name == 'Cassady' cond = Person.first_name == "Cassady"
Person.objects_in(self.database).filter(cond).update(height=0) Person.objects_in(self.database).filter(cond).update(height=0)
self._wait_for_mutations() self._wait_for_mutations()
self.assertFalse(Person.objects_in(self.database).filter(cond).exclude(height=0)) self.assertFalse(Person.objects_in(self.database).filter(cond).exclude(height=0))
@ -71,9 +71,9 @@ class MutationsTestCase(TestCaseWithData):
base_query = Person.objects_in(self.database) base_query = Person.objects_in(self.database)
queries = [ queries = [
base_query[0:1], base_query[0:1],
base_query.limit_by(5, 'first_name'), base_query.limit_by(5, "first_name"),
base_query.distinct(), base_query.distinct(),
base_query.aggregate('first_name', count=F.count()) base_query.aggregate("first_name", count=F.count()),
] ]
for query in queries: for query in queries:
print(query) print(query)
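Below is a hedged sketch, separate from the diff, of the ALTER-based mutations exercised by MutationsTestCase above: update and delete through the queryset, then poll system.mutations until the mutation finishes. The model and database names are illustrative; a local ClickHouse server new enough for mutations (version 18 or later, per the test's own guard) is assumed.

from time import sleep

from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import DateField, StringField, Float32Field
from clickhouse_orm.engines import MergeTree


class MiniPerson(Model):
    # Cut-down version of the Person test model used throughout this commit
    first_name = StringField()
    birthday = DateField()
    height = Float32Field()

    engine = MergeTree("birthday", ("first_name", "birthday"))


def wait_for_mutations(database):
    # Mutations run asynchronously; poll system.mutations until none are pending
    while list(database.raw("SELECT * FROM system.mutations WHERE is_done = 0")):
        sleep(0.25)


db = Database("test-db", log_statements=True)  # assumes a local ClickHouse
db.create_table(MiniPerson)

qs = MiniPerson.objects_in(db)
qs.filter(first_name="Cassady").update(height=0)  # ALTER TABLE ... UPDATE
wait_for_mutations(db)
qs.exclude(height=0).delete()                     # ALTER TABLE ... DELETE
wait_for_mutations(db)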

View File

@ -11,9 +11,8 @@ from datetime import date, datetime
class NullableFieldsTest(unittest.TestCase): class NullableFieldsTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.database = Database('test-db', log_statements=True) self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithNullable) self.database.create_table(ModelWithNullable)
def tearDown(self): def tearDown(self):
@ -23,18 +22,20 @@ class NullableFieldsTest(unittest.TestCase):
f = NullableField(DateTimeField()) f = NullableField(DateTimeField())
epoch = datetime(1970, 1, 1, tzinfo=pytz.utc) epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
# Valid values # Valid values
for value in (date(1970, 1, 1), for value in (
datetime(1970, 1, 1), date(1970, 1, 1),
epoch, datetime(1970, 1, 1),
epoch.astimezone(pytz.timezone('US/Eastern')), epoch,
epoch.astimezone(pytz.timezone('Asia/Jerusalem')), epoch.astimezone(pytz.timezone("US/Eastern")),
'1970-01-01 00:00:00', epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
'1970-01-17 00:00:17', "1970-01-01 00:00:00",
'0000-00-00 00:00:00', "1970-01-17 00:00:17",
0, "0000-00-00 00:00:00",
'\\N'): 0,
"\\N",
):
dt = f.to_python(value, pytz.utc) dt = f.to_python(value, pytz.utc)
if value == '\\N': if value == "\\N":
self.assertIsNone(dt) self.assertIsNone(dt)
else: else:
self.assertTrue(dt.tzinfo) self.assertTrue(dt.tzinfo)
@ -42,32 +43,32 @@ class NullableFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc) dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2) self.assertEqual(dt, dt2)
# Invalid values # Invalid values
for value in ('nope', '21/7/1999', 0.5): for value in ("nope", "21/7/1999", 0.5):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f.to_python(value, pytz.utc) f.to_python(value, pytz.utc)
def test_nullable_uint8_field(self): def test_nullable_uint8_field(self):
f = NullableField(UInt8Field()) f = NullableField(UInt8Field())
# Valid values # Valid values
for value in (17, '17', 17.0, '\\N'): for value in (17, "17", 17.0, "\\N"):
python_value = f.to_python(value, pytz.utc) python_value = f.to_python(value, pytz.utc)
if value == '\\N': if value == "\\N":
self.assertIsNone(python_value) self.assertIsNone(python_value)
self.assertEqual(value, f.to_db_string(python_value)) self.assertEqual(value, f.to_db_string(python_value))
else: else:
self.assertEqual(python_value, 17) self.assertEqual(python_value, 17)
# Invalid values # Invalid values
for value in ('nope', date.today()): for value in ("nope", date.today()):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f.to_python(value, pytz.utc) f.to_python(value, pytz.utc)
def test_nullable_string_field(self): def test_nullable_string_field(self):
f = NullableField(StringField()) f = NullableField(StringField())
# Valid values # Valid values
for value in ('\\\\N', 'N', 'some text', '\\N'): for value in ("\\\\N", "N", "some text", "\\N"):
python_value = f.to_python(value, pytz.utc) python_value = f.to_python(value, pytz.utc)
if value == '\\N': if value == "\\N":
self.assertIsNone(python_value) self.assertIsNone(python_value)
self.assertEqual(value, f.to_db_string(python_value)) self.assertEqual(value, f.to_db_string(python_value))
else: else:
@ -91,12 +92,16 @@ class NullableFieldsTest(unittest.TestCase):
def _insert_sample_data(self): def _insert_sample_data(self):
dt = date(1970, 1, 1) dt = date(1970, 1, 1)
self.database.insert([ self.database.insert(
ModelWithNullable(date_field='2016-08-30', null_str='', null_int=42, null_date=dt), [
ModelWithNullable(date_field='2016-08-30', null_str='nothing', null_int=None, null_date=None), ModelWithNullable(date_field="2016-08-30", null_str="", null_int=42, null_date=dt),
ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=42, null_date=dt), ModelWithNullable(date_field="2016-08-30", null_str="nothing", null_int=None, null_date=None),
ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=None, null_date=None, null_default=None) ModelWithNullable(date_field="2016-08-31", null_str=None, null_int=42, null_date=dt),
]) ModelWithNullable(
date_field="2016-08-31", null_str=None, null_int=None, null_date=None, null_default=None
),
]
)
def _assert_sample_data(self, results): def _assert_sample_data(self, results):
for r in results: for r in results:
@ -110,7 +115,7 @@ class NullableFieldsTest(unittest.TestCase):
self.assertEqual(results[0].null_materialized, 420) self.assertEqual(results[0].null_materialized, 420)
self.assertEqual(results[0].null_date, dt) self.assertEqual(results[0].null_date, dt)
self.assertIsNone(results[1].null_date) self.assertIsNone(results[1].null_date)
self.assertEqual(results[1].null_str, 'nothing') self.assertEqual(results[1].null_str, "nothing")
self.assertIsNone(results[1].null_date) self.assertIsNone(results[1].null_date)
self.assertIsNone(results[2].null_str) self.assertIsNone(results[2].null_str)
self.assertEqual(results[2].null_date, dt) self.assertEqual(results[2].null_date, dt)
@ -128,14 +133,14 @@ class NullableFieldsTest(unittest.TestCase):
def test_insert_and_select(self): def test_insert_and_select(self):
self._insert_sample_data() self._insert_sample_data()
fields = comma_join(ModelWithNullable.fields().keys()) fields = comma_join(ModelWithNullable.fields().keys())
query = 'SELECT %s from $table ORDER BY date_field' % fields query = "SELECT %s from $table ORDER BY date_field" % fields
results = list(self.database.select(query, ModelWithNullable)) results = list(self.database.select(query, ModelWithNullable))
self._assert_sample_data(results) self._assert_sample_data(results)
def test_ad_hoc_model(self): def test_ad_hoc_model(self):
self._insert_sample_data() self._insert_sample_data()
fields = comma_join(ModelWithNullable.fields().keys()) fields = comma_join(ModelWithNullable.fields().keys())
query = 'SELECT %s from $db.modelwithnullable ORDER BY date_field' % fields query = "SELECT %s from $db.modelwithnullable ORDER BY date_field" % fields
results = list(self.database.select(query)) results = list(self.database.select(query))
self._assert_sample_data(results) self._assert_sample_data(results)
@ -143,11 +148,11 @@ class NullableFieldsTest(unittest.TestCase):
class ModelWithNullable(Model): class ModelWithNullable(Model):
date_field = DateField() date_field = DateField()
null_str = NullableField(StringField(), extra_null_values={''}) null_str = NullableField(StringField(), extra_null_values={""})
null_int = NullableField(Int32Field()) null_int = NullableField(Int32Field())
null_date = NullableField(DateField()) null_date = NullableField(DateField())
null_default = NullableField(Int32Field(), default=7) null_default = NullableField(Int32Field(), default=7)
null_alias = NullableField(Int32Field(), alias='null_int/2') null_alias = NullableField(Int32Field(), alias="null_int/2")
null_materialized = NullableField(Int32Field(), alias='null_int*10') null_materialized = NullableField(Int32Field(), alias="null_int*10")
engine = MergeTree('date_field', ('date_field',)) engine = MergeTree("date_field", ("date_field",))
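A short illustration (not in the commit) of the NullableField behaviour covered above: extra_null_values treating the empty string as NULL, a nullable field with a default, and an insert/select round trip. Names are illustrative and a local ClickHouse server is assumed.

from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import DateField, StringField, Int32Field, NullableField
from clickhouse_orm.engines import MergeTree


class MiniNullable(Model):
    date_field = DateField()
    null_str = NullableField(StringField(), extra_null_values={""})  # "" is also treated as NULL
    null_int = NullableField(Int32Field())
    null_default = NullableField(Int32Field(), default=7)

    engine = MergeTree("date_field", ("date_field",))


db = Database("test-db", log_statements=True)  # assumes a local ClickHouse
db.create_table(MiniNullable)
db.insert(
    [
        MiniNullable(date_field="2016-08-30", null_str="", null_int=42),
        MiniNullable(date_field="2016-08-31", null_str=None, null_int=None),
    ]
)
for row in db.select("SELECT * FROM $table ORDER BY date_field", MiniNullable):
    print(row.null_str, row.null_int, row.null_default)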

View File

@ -9,12 +9,11 @@ from enum import Enum
from decimal import Decimal from decimal import Decimal
from logging import getLogger from logging import getLogger
logger = getLogger('tests')
logger = getLogger("tests")
class QuerySetTestCase(TestCaseWithData): class QuerySetTestCase(TestCaseWithData):
def setUp(self): def setUp(self):
super(QuerySetTestCase, self).setUp() super(QuerySetTestCase, self).setUp()
self.database.insert(self._sample_data()) self.database.insert(self._sample_data())
@ -24,7 +23,7 @@ class QuerySetTestCase(TestCaseWithData):
count = 0 count = 0
for instance in qs: for instance in qs:
count += 1 count += 1
logger.info('\t[%d]\t%s' % (count, instance.to_dict())) logger.info("\t[%d]\t%s" % (count, instance.to_dict()))
self.assertEqual(count, expected_count) self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count) self.assertEqual(qs.count(), expected_count)
@ -32,8 +31,8 @@ class QuerySetTestCase(TestCaseWithData):
# We can't distinguish prewhere and where results, it affects performance only. # We can't distinguish prewhere and where results, it affects performance only.
        # So let's check that prewhere acts like where does # So let's check that prewhere acts like where does
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self.assertTrue(qs.filter(first_name='Connor', prewhere=True)) self.assertTrue(qs.filter(first_name="Connor", prewhere=True))
self.assertFalse(qs.filter(first_name='Willy', prewhere=True)) self.assertFalse(qs.filter(first_name="Willy", prewhere=True))
def test_no_filtering(self): def test_no_filtering(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
@ -41,8 +40,8 @@ class QuerySetTestCase(TestCaseWithData):
def test_truthiness(self): def test_truthiness(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self.assertTrue(qs.filter(first_name='Connor')) self.assertTrue(qs.filter(first_name="Connor"))
self.assertFalse(qs.filter(first_name='Willy')) self.assertFalse(qs.filter(first_name="Willy"))
def test_filter_null_value(self): def test_filter_null_value(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
@ -53,81 +52,96 @@ class QuerySetTestCase(TestCaseWithData):
def test_filter_string_field(self): def test_filter_string_field(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(first_name='Ciaran'), 2) self._test_qs(qs.filter(first_name="Ciaran"), 2)
self._test_qs(qs.filter(first_name='ciaran'), 0) # case sensitive self._test_qs(qs.filter(first_name="ciaran"), 0) # case sensitive
self._test_qs(qs.filter(first_name__iexact='ciaran'), 2) # case insensitive self._test_qs(qs.filter(first_name__iexact="ciaran"), 2) # case insensitive
self._test_qs(qs.filter(first_name__gt='Whilemina'), 4) self._test_qs(qs.filter(first_name__gt="Whilemina"), 4)
self._test_qs(qs.filter(first_name__gte='Whilemina'), 5) self._test_qs(qs.filter(first_name__gte="Whilemina"), 5)
self._test_qs(qs.filter(first_name__lt='Adam'), 1) self._test_qs(qs.filter(first_name__lt="Adam"), 1)
self._test_qs(qs.filter(first_name__lte='Adam'), 2) self._test_qs(qs.filter(first_name__lte="Adam"), 2)
self._test_qs(qs.filter(first_name__in=('Connor', 'Courtney')), 3) # in tuple self._test_qs(qs.filter(first_name__in=("Connor", "Courtney")), 3) # in tuple
self._test_qs(qs.filter(first_name__in=['Connor', 'Courtney']), 3) # in list self._test_qs(qs.filter(first_name__in=["Connor", "Courtney"]), 3) # in list
self._test_qs(qs.filter(first_name__in="'Connor', 'Courtney'"), 3) # in string self._test_qs(qs.filter(first_name__in="'Connor', 'Courtney'"), 3) # in string
self._test_qs(qs.filter(first_name__not_in="'Connor', 'Courtney'"), 97) self._test_qs(qs.filter(first_name__not_in="'Connor', 'Courtney'"), 97)
self._test_qs(qs.filter(first_name__contains='sh'), 3) # case sensitive self._test_qs(qs.filter(first_name__contains="sh"), 3) # case sensitive
self._test_qs(qs.filter(first_name__icontains='sh'), 6) # case insensitive self._test_qs(qs.filter(first_name__icontains="sh"), 6) # case insensitive
self._test_qs(qs.filter(first_name__startswith='le'), 0) # case sensitive self._test_qs(qs.filter(first_name__startswith="le"), 0) # case sensitive
self._test_qs(qs.filter(first_name__istartswith='Le'), 2) # case insensitive self._test_qs(qs.filter(first_name__istartswith="Le"), 2) # case insensitive
self._test_qs(qs.filter(first_name__istartswith=''), 100) # empty prefix self._test_qs(qs.filter(first_name__istartswith=""), 100) # empty prefix
self._test_qs(qs.filter(first_name__endswith='IA'), 0) # case sensitive self._test_qs(qs.filter(first_name__endswith="IA"), 0) # case sensitive
self._test_qs(qs.filter(first_name__iendswith='ia'), 3) # case insensitive self._test_qs(qs.filter(first_name__iendswith="ia"), 3) # case insensitive
self._test_qs(qs.filter(first_name__iendswith=''), 100) # empty suffix self._test_qs(qs.filter(first_name__iendswith=""), 100) # empty suffix
def test_filter_with_q_objects(self): def test_filter_with_q_objects(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(Q(first_name='Ciaran')), 2) self._test_qs(qs.filter(Q(first_name="Ciaran")), 2)
self._test_qs(qs.filter(Q(first_name='Ciaran') | Q(first_name='Chelsea')), 3) self._test_qs(qs.filter(Q(first_name="Ciaran") | Q(first_name="Chelsea")), 3)
self._test_qs(qs.filter(Q(first_name__in=['Warren', 'Whilemina', 'Whitney']) & Q(height__gte=1.7)), 3) self._test_qs(qs.filter(Q(first_name__in=["Warren", "Whilemina", "Whitney"]) & Q(height__gte=1.7)), 3)
self._test_qs(qs.filter((Q(first_name__in=['Warren', 'Whilemina', 'Whitney']) & Q(height__gte=1.7) | self._test_qs(
(Q(first_name__in=['Victoria', 'Victor', 'Venus']) & Q(height__lt=1.7)))), 4) qs.filter(
self._test_qs(qs.filter(Q(first_name='Elton') & ~Q(last_name='Smith')), 1) (
Q(first_name__in=["Warren", "Whilemina", "Whitney"]) & Q(height__gte=1.7)
| (Q(first_name__in=["Victoria", "Victor", "Venus"]) & Q(height__lt=1.7))
)
),
4,
)
self._test_qs(qs.filter(Q(first_name="Elton") & ~Q(last_name="Smith")), 1)
        # Check operator precedence # Check operator precedence
self._test_qs(qs.filter(first_name='Cassady').filter(Q(last_name='Knapp') | Q(last_name='Rogers') | Q(last_name='Gregory')), 2) self._test_qs(
self._test_qs(qs.filter(Q(first_name='Cassady') & Q(last_name='Knapp') | Q(first_name='Beatrice') & Q(last_name='Gregory')), 2) qs.filter(first_name="Cassady").filter(
self._test_qs(qs.filter(Q(first_name='Courtney') | Q(first_name='Cassady') & Q(last_name='Knapp')), 3) Q(last_name="Knapp") | Q(last_name="Rogers") | Q(last_name="Gregory")
),
2,
)
self._test_qs(
qs.filter(
Q(first_name="Cassady") & Q(last_name="Knapp") | Q(first_name="Beatrice") & Q(last_name="Gregory")
),
2,
)
self._test_qs(qs.filter(Q(first_name="Courtney") | Q(first_name="Cassady") & Q(last_name="Knapp")), 3)
def test_filter_unicode_string(self): def test_filter_unicode_string(self):
self.database.insert([ self.database.insert([Person(first_name=u"דונלד", last_name=u"דאק")])
Person(first_name=u'דונלד', last_name=u'דאק')
])
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(first_name=u'דונלד'), 1) self._test_qs(qs.filter(first_name=u"דונלד"), 1)
def test_filter_float_field(self): def test_filter_float_field(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__gt=2), 0) self._test_qs(qs.filter(height__gt=2), 0)
self._test_qs(qs.filter(height__lt=1.61), 4) self._test_qs(qs.filter(height__lt=1.61), 4)
self._test_qs(qs.filter(height__lt='1.61'), 4) self._test_qs(qs.filter(height__lt="1.61"), 4)
self._test_qs(qs.exclude(height__lt='1.61'), 96) self._test_qs(qs.exclude(height__lt="1.61"), 96)
self._test_qs(qs.filter(height__gt=0), 100) self._test_qs(qs.filter(height__gt=0), 100)
self._test_qs(qs.exclude(height__gt=0), 0) self._test_qs(qs.exclude(height__gt=0), 0)
def test_filter_date_field(self): def test_filter_date_field(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(birthday='1970-12-02'), 1) self._test_qs(qs.filter(birthday="1970-12-02"), 1)
self._test_qs(qs.filter(birthday__eq='1970-12-02'), 1) self._test_qs(qs.filter(birthday__eq="1970-12-02"), 1)
self._test_qs(qs.filter(birthday__ne='1970-12-02'), 99) self._test_qs(qs.filter(birthday__ne="1970-12-02"), 99)
self._test_qs(qs.filter(birthday=date(1970, 12, 2)), 1) self._test_qs(qs.filter(birthday=date(1970, 12, 2)), 1)
self._test_qs(qs.filter(birthday__lte=date(1970, 12, 2)), 3) self._test_qs(qs.filter(birthday__lte=date(1970, 12, 2)), 3)
def test_mutiple_filter(self): def test_mutiple_filter(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
# Single filter call with multiple conditions is ANDed # Single filter call with multiple conditions is ANDed
self._test_qs(qs.filter(first_name='Ciaran', last_name='Carver'), 1) self._test_qs(qs.filter(first_name="Ciaran", last_name="Carver"), 1)
# Separate filter calls are also ANDed # Separate filter calls are also ANDed
self._test_qs(qs.filter(first_name='Ciaran').filter(last_name='Carver'), 1) self._test_qs(qs.filter(first_name="Ciaran").filter(last_name="Carver"), 1)
self._test_qs(qs.filter(birthday='1970-12-02').filter(birthday='1986-01-07'), 0) self._test_qs(qs.filter(birthday="1970-12-02").filter(birthday="1986-01-07"), 0)
def test_multiple_exclude(self): def test_multiple_exclude(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
# Single exclude call with multiple conditions is ANDed # Single exclude call with multiple conditions is ANDed
self._test_qs(qs.exclude(first_name='Ciaran', last_name='Carver'), 99) self._test_qs(qs.exclude(first_name="Ciaran", last_name="Carver"), 99)
# Separate exclude calls are ORed # Separate exclude calls are ORed
self._test_qs(qs.exclude(first_name='Ciaran').exclude(last_name='Carver'), 98) self._test_qs(qs.exclude(first_name="Ciaran").exclude(last_name="Carver"), 98)
self._test_qs(qs.exclude(birthday='1970-12-02').exclude(birthday='1986-01-07'), 98) self._test_qs(qs.exclude(birthday="1970-12-02").exclude(birthday="1986-01-07"), 98)
def test_only(self): def test_only(self):
qs = Person.objects_in(self.database).only('first_name', 'last_name') qs = Person.objects_in(self.database).only("first_name", "last_name")
for person in qs: for person in qs:
self.assertTrue(person.first_name) self.assertTrue(person.first_name)
self.assertTrue(person.last_name) self.assertTrue(person.last_name)
@ -136,46 +150,50 @@ class QuerySetTestCase(TestCaseWithData):
def test_order_by(self): def test_order_by(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self.assertFalse('ORDER BY' in qs.as_sql()) self.assertFalse("ORDER BY" in qs.as_sql())
self.assertFalse(qs.order_by_as_sql()) self.assertFalse(qs.order_by_as_sql())
person = list(qs.order_by('first_name', 'last_name'))[0] person = list(qs.order_by("first_name", "last_name"))[0]
self.assertEqual(person.first_name, 'Abdul') self.assertEqual(person.first_name, "Abdul")
person = list(qs.order_by('-first_name', '-last_name'))[0] person = list(qs.order_by("-first_name", "-last_name"))[0]
self.assertEqual(person.first_name, 'Yolanda') self.assertEqual(person.first_name, "Yolanda")
person = list(qs.order_by('height'))[0] person = list(qs.order_by("height"))[0]
self.assertEqual(person.height, 1.59) self.assertEqual(person.height, 1.59)
person = list(qs.order_by('-height'))[0] person = list(qs.order_by("-height"))[0]
self.assertEqual(person.height, 1.8) self.assertEqual(person.height, 1.8)
def test_in_subquery(self): def test_in_subquery(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__in='SELECT max(height) FROM $table'), 2) self._test_qs(qs.filter(height__in="SELECT max(height) FROM $table"), 2)
self._test_qs(qs.filter(first_name__in=qs.only('last_name')), 2) self._test_qs(qs.filter(first_name__in=qs.only("last_name")), 2)
self._test_qs(qs.filter(first_name__not_in=qs.only('last_name')), 98) self._test_qs(qs.filter(first_name__not_in=qs.only("last_name")), 98)
def _insert_sample_model(self): def _insert_sample_model(self):
self.database.create_table(SampleModel) self.database.create_table(SampleModel)
now = datetime.now() now = datetime.now()
self.database.insert([ self.database.insert(
SampleModel(timestamp=now, num=1, color=Color.red), [
SampleModel(timestamp=now, num=2, color=Color.red), SampleModel(timestamp=now, num=1, color=Color.red),
SampleModel(timestamp=now, num=3, color=Color.blue), SampleModel(timestamp=now, num=2, color=Color.red),
SampleModel(timestamp=now, num=4, color=Color.white), SampleModel(timestamp=now, num=3, color=Color.blue),
]) SampleModel(timestamp=now, num=4, color=Color.white),
]
)
def _insert_sample_collapsing_model(self): def _insert_sample_collapsing_model(self):
self.database.create_table(SampleCollapsingModel) self.database.create_table(SampleCollapsingModel)
now = datetime.now() now = datetime.now()
self.database.insert([ self.database.insert(
SampleCollapsingModel(timestamp=now, num=1, color=Color.red), [
SampleCollapsingModel(timestamp=now, num=2, color=Color.red), SampleCollapsingModel(timestamp=now, num=1, color=Color.red),
SampleCollapsingModel(timestamp=now, num=2, color=Color.red, sign=-1), SampleCollapsingModel(timestamp=now, num=2, color=Color.red),
SampleCollapsingModel(timestamp=now, num=2, color=Color.green), SampleCollapsingModel(timestamp=now, num=2, color=Color.red, sign=-1),
SampleCollapsingModel(timestamp=now, num=3, color=Color.white), SampleCollapsingModel(timestamp=now, num=2, color=Color.green),
SampleCollapsingModel(timestamp=now, num=4, color=Color.white, sign=1), SampleCollapsingModel(timestamp=now, num=3, color=Color.white),
SampleCollapsingModel(timestamp=now, num=4, color=Color.white, sign=-1), SampleCollapsingModel(timestamp=now, num=4, color=Color.white, sign=1),
SampleCollapsingModel(timestamp=now, num=4, color=Color.blue, sign=1), SampleCollapsingModel(timestamp=now, num=4, color=Color.white, sign=-1),
]) SampleCollapsingModel(timestamp=now, num=4, color=Color.blue, sign=1),
]
)
def test_filter_enum_field(self): def test_filter_enum_field(self):
self._insert_sample_model() self._insert_sample_model()
@ -184,7 +202,7 @@ class QuerySetTestCase(TestCaseWithData):
self._test_qs(qs.exclude(color=Color.white), 3) self._test_qs(qs.exclude(color=Color.white), 3)
# Different ways to specify blue # Different ways to specify blue
self._test_qs(qs.filter(color__gt=Color.blue), 1) self._test_qs(qs.filter(color__gt=Color.blue), 1)
self._test_qs(qs.filter(color__gt='blue'), 1) self._test_qs(qs.filter(color__gt="blue"), 1)
self._test_qs(qs.filter(color__gt=2), 1) self._test_qs(qs.filter(color__gt=2), 1)
def test_filter_int_field(self): def test_filter_int_field(self):
@ -199,7 +217,7 @@ class QuerySetTestCase(TestCaseWithData):
self._test_qs(qs.filter(num__in=range(1, 4)), 3) self._test_qs(qs.filter(num__in=range(1, 4)), 3)
def test_slicing(self): def test_slicing(self):
db = Database('system') db = Database("system")
numbers = list(range(100)) numbers = list(range(100))
qs = Numbers.objects_in(db) qs = Numbers.objects_in(db)
self.assertEqual(qs[0].number, numbers[0]) self.assertEqual(qs[0].number, numbers[0])
@ -211,7 +229,7 @@ class QuerySetTestCase(TestCaseWithData):
self.assertEqual([row.number for row in qs[10:10]], numbers[10:10]) self.assertEqual([row.number for row in qs[10:10]], numbers[10:10])
def test_invalid_slicing(self): def test_invalid_slicing(self):
db = Database('system') db = Database("system")
qs = Numbers.objects_in(db) qs = Numbers.objects_in(db)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
qs[3:10:2] qs[3:10:2]
@ -223,7 +241,7 @@ class QuerySetTestCase(TestCaseWithData):
qs[50:1] qs[50:1]
def test_pagination(self): def test_pagination(self):
qs = Person.objects_in(self.database).order_by('first_name', 'last_name') qs = Person.objects_in(self.database).order_by("first_name", "last_name")
# Try different page sizes # Try different page sizes
for page_size in (1, 2, 7, 10, 30, 100, 150): for page_size in (1, 2, 7, 10, 30, 100, 150):
# Iterate over pages and collect all instances # Iterate over pages and collect all instances
@ -241,31 +259,30 @@ class QuerySetTestCase(TestCaseWithData):
self.assertEqual(len(instances), len(data)) self.assertEqual(len(instances), len(data))
def test_pagination_last_page(self): def test_pagination_last_page(self):
qs = Person.objects_in(self.database).order_by('first_name', 'last_name') qs = Person.objects_in(self.database).order_by("first_name", "last_name")
# Try different page sizes # Try different page sizes
for page_size in (1, 2, 7, 10, 30, 100, 150): for page_size in (1, 2, 7, 10, 30, 100, 150):
# Ask for the last page in two different ways and verify equality # Ask for the last page in two different ways and verify equality
page_a = qs.paginate(-1, page_size) page_a = qs.paginate(-1, page_size)
page_b = qs.paginate(page_a.pages_total, page_size) page_b = qs.paginate(page_a.pages_total, page_size)
self.assertEqual(page_a[1:], page_b[1:]) self.assertEqual(page_a[1:], page_b[1:])
self.assertEqual([obj.to_tsv() for obj in page_a.objects], self.assertEqual([obj.to_tsv() for obj in page_a.objects], [obj.to_tsv() for obj in page_b.objects])
[obj.to_tsv() for obj in page_b.objects])
def test_pagination_invalid_page(self): def test_pagination_invalid_page(self):
qs = Person.objects_in(self.database).order_by('first_name', 'last_name') qs = Person.objects_in(self.database).order_by("first_name", "last_name")
for page_num in (0, -2, -100): for page_num in (0, -2, -100):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
qs.paginate(page_num, 100) qs.paginate(page_num, 100)
def test_pagination_with_conditions(self): def test_pagination_with_conditions(self):
qs = Person.objects_in(self.database).order_by('first_name', 'last_name').filter(first_name__lt='Ava') qs = Person.objects_in(self.database).order_by("first_name", "last_name").filter(first_name__lt="Ava")
page = qs.paginate(1, 100) page = qs.paginate(1, 100)
self.assertEqual(page.number_of_objects, 10) self.assertEqual(page.number_of_objects, 10)
def test_distinct(self): def test_distinct(self):
qs = Person.objects_in(self.database).distinct() qs = Person.objects_in(self.database).distinct()
self._test_qs(qs, 100) self._test_qs(qs, 100)
self._test_qs(qs.only('first_name'), 94) self._test_qs(qs.only("first_name"), 94)
def test_materialized_field(self): def test_materialized_field(self):
self._insert_sample_model() self._insert_sample_model()
@ -291,31 +308,31 @@ class QuerySetTestCase(TestCaseWithData):
Person.objects_in(self.database).final() Person.objects_in(self.database).final()
self._insert_sample_collapsing_model() self._insert_sample_collapsing_model()
res = list(SampleCollapsingModel.objects_in(self.database).final().order_by('num')) res = list(SampleCollapsingModel.objects_in(self.database).final().order_by("num"))
self.assertEqual(4, len(res)) self.assertEqual(4, len(res))
for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)): for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)):
self.assertEqual(exp_color, item.color) self.assertEqual(exp_color, item.color)
def test_mixed_filter(self): def test_mixed_filter(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
qs = qs.filter(Q(first_name='a'), F('greater', Person.height, 1.7), last_name='b') qs = qs.filter(Q(first_name="a"), F("greater", Person.height, 1.7), last_name="b")
self.assertEqual(qs.conditions_as_sql(), self.assertEqual(
"(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')") qs.conditions_as_sql(), "(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')"
)
def test_invalid_filter(self): def test_invalid_filter(self):
qs = Person.objects_in(self.database) qs = Person.objects_in(self.database)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
qs.filter('foo') qs.filter("foo")
class AggregateTestCase(TestCaseWithData): class AggregateTestCase(TestCaseWithData):
def setUp(self): def setUp(self):
super(AggregateTestCase, self).setUp() super(AggregateTestCase, self).setUp()
self.database.insert(self._sample_data()) self.database.insert(self._sample_data())
def test_aggregate_no_grouping(self): def test_aggregate_no_grouping(self):
qs = Person.objects_in(self.database).aggregate(average_height='avg(height)', count='count()') qs = Person.objects_in(self.database).aggregate(average_height="avg(height)", count="count()")
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
for row in qs: for row in qs:
@ -331,14 +348,22 @@ class AggregateTestCase(TestCaseWithData):
def test_aggregate_with_filter(self): def test_aggregate_with_filter(self):
# When filter comes before aggregate # When filter comes before aggregate
qs = Person.objects_in(self.database).filter(first_name='Warren').aggregate(average_height='avg(height)', count='count()') qs = (
Person.objects_in(self.database)
.filter(first_name="Warren")
.aggregate(average_height="avg(height)", count="count()")
)
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
for row in qs: for row in qs:
self.assertAlmostEqual(row.average_height, 1.675, places=4) self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2) self.assertEqual(row.count, 2)
# When filter comes after aggregate # When filter comes after aggregate
qs = Person.objects_in(self.database).aggregate(average_height='avg(height)', count='count()').filter(first_name='Warren') qs = (
Person.objects_in(self.database)
.aggregate(average_height="avg(height)", count="count()")
.filter(first_name="Warren")
)
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
for row in qs: for row in qs:
@ -347,14 +372,22 @@ class AggregateTestCase(TestCaseWithData):
def test_aggregate_with_filter__funcs(self): def test_aggregate_with_filter__funcs(self):
# When filter comes before aggregate # When filter comes before aggregate
qs = Person.objects_in(self.database).filter(Person.first_name=='Warren').aggregate(average_height=F.avg(Person.height), count=F.count()) qs = (
Person.objects_in(self.database)
.filter(Person.first_name == "Warren")
.aggregate(average_height=F.avg(Person.height), count=F.count())
)
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
for row in qs: for row in qs:
self.assertAlmostEqual(row.average_height, 1.675, places=4) self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2) self.assertEqual(row.count, 2)
# When filter comes after aggregate # When filter comes after aggregate
qs = Person.objects_in(self.database).aggregate(average_height=F.avg(Person.height), count=F.count()).filter(Person.first_name=='Warren') qs = (
Person.objects_in(self.database)
.aggregate(average_height=F.avg(Person.height), count=F.count())
.filter(Person.first_name == "Warren")
)
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
for row in qs: for row in qs:
@ -362,7 +395,7 @@ class AggregateTestCase(TestCaseWithData):
self.assertEqual(row.count, 2) self.assertEqual(row.count, 2)
def test_aggregate_with_implicit_grouping(self): def test_aggregate_with_implicit_grouping(self):
qs = Person.objects_in(self.database).aggregate('first_name', average_height='avg(height)', count='count()') qs = Person.objects_in(self.database).aggregate("first_name", average_height="avg(height)", count="count()")
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 94) self.assertEqual(qs.count(), 94)
total = 0 total = 0
@ -373,7 +406,11 @@ class AggregateTestCase(TestCaseWithData):
self.assertEqual(total, 100) self.assertEqual(total, 100)
def test_aggregate_with_explicit_grouping(self): def test_aggregate_with_explicit_grouping(self):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 7) self.assertEqual(qs.count(), 7)
total = 0 total = 0
@ -382,24 +419,40 @@ class AggregateTestCase(TestCaseWithData):
self.assertEqual(total, 100) self.assertEqual(total, 100)
def test_aggregate_with_order_by(self): def test_aggregate_with_order_by(self):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
days = [row.weekday for row in qs.order_by('weekday')] Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
days = [row.weekday for row in qs.order_by("weekday")]
self.assertEqual(days, list(range(1, 8))) self.assertEqual(days, list(range(1, 8)))
def test_aggregate_with_indexing(self): def test_aggregate_with_indexing(self):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
total = 0 total = 0
for i in range(7): for i in range(7):
total += qs[i].count total += qs[i].count
self.assertEqual(total, 100) self.assertEqual(total, 100)
def test_aggregate_with_slicing(self): def test_aggregate_with_slicing(self):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
total = sum(row.count for row in qs[:3]) + sum(row.count for row in qs[3:]) total = sum(row.count for row in qs[:3]) + sum(row.count for row in qs[3:])
self.assertEqual(total, 100) self.assertEqual(total, 100)
def test_aggregate_with_pagination(self): def test_aggregate_with_pagination(self):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
total = 0 total = 0
page_num = 1 page_num = 1
while True: while True:
@ -413,7 +466,9 @@ class AggregateTestCase(TestCaseWithData):
def test_aggregate_with_wrong_grouping(self): def test_aggregate_with_wrong_grouping(self):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('first_name') Person.objects_in(self.database).aggregate(weekday="toDayOfWeek(birthday)", count="count()").group_by(
"first_name"
)
def test_aggregate_with_no_calculated_fields(self): def test_aggregate_with_no_calculated_fields(self):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -422,31 +477,41 @@ class AggregateTestCase(TestCaseWithData):
def test_aggregate_with_only(self): def test_aggregate_with_only(self):
# Cannot put only() after aggregate() # Cannot put only() after aggregate()
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').only('weekday') Person.objects_in(self.database).aggregate(weekday="toDayOfWeek(birthday)", count="count()").only("weekday")
# When only() comes before aggregate(), it gets overridden # When only() comes before aggregate(), it gets overridden
qs = Person.objects_in(self.database).only('last_name').aggregate(average_height='avg(height)', count='count()') qs = Person.objects_in(self.database).only("last_name").aggregate(average_height="avg(height)", count="count()")
self.assertTrue('last_name' not in qs.as_sql()) self.assertTrue("last_name" not in qs.as_sql())
def test_aggregate_on_aggregate(self): def test_aggregate_on_aggregate(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').aggregate(s='sum(height)') Person.objects_in(self.database).aggregate(weekday="toDayOfWeek(birthday)", count="count()").aggregate(
s="sum(height)"
)
def test_filter_on_calculated_field(self): def test_filter_on_calculated_field(self):
# This is currently not supported, so we expect it to fail # This is currently not supported, so we expect it to fail
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') qs = (
Person.objects_in(self.database)
.aggregate(weekday="toDayOfWeek(birthday)", count="count()")
.group_by("weekday")
)
qs = qs.filter(weekday=1) qs = qs.filter(weekday=1)
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
def test_aggregate_with_distinct(self): def test_aggregate_with_distinct(self):
# In this case distinct has no effect # In this case distinct has no effect
qs = Person.objects_in(self.database).aggregate(average_height='avg(height)').distinct() qs = Person.objects_in(self.database).aggregate(average_height="avg(height)").distinct()
print(qs.as_sql()) print(qs.as_sql())
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
def test_aggregate_with_totals(self): def test_aggregate_with_totals(self):
qs = Person.objects_in(self.database).aggregate('first_name', count='count()').\ qs = (
with_totals().order_by('-count')[:5] Person.objects_in(self.database)
.aggregate("first_name", count="count()")
.with_totals()
.order_by("-count")[:5]
)
print(qs.as_sql()) print(qs.as_sql())
result = list(qs) result = list(qs)
self.assertEqual(len(result), 6) self.assertEqual(len(result), 6)
@ -460,61 +525,68 @@ class AggregateTestCase(TestCaseWithData):
the__number = Int32Field() the__number = Int32Field()
the__next__number = Int32Field() the__next__number = Int32Field()
engine = Memory() engine = Memory()
qs = Mdl.objects_in(self.database).filter(the__number=1) qs = Mdl.objects_in(self.database).filter(the__number=1)
self.assertEqual(qs.conditions_as_sql(), 'the__number = 1') self.assertEqual(qs.conditions_as_sql(), "the__number = 1")
qs = Mdl.objects_in(self.database).filter(the__number__gt=1) qs = Mdl.objects_in(self.database).filter(the__number__gt=1)
self.assertEqual(qs.conditions_as_sql(), 'the__number > 1') self.assertEqual(qs.conditions_as_sql(), "the__number > 1")
qs = Mdl.objects_in(self.database).filter(the__next__number=1) qs = Mdl.objects_in(self.database).filter(the__next__number=1)
self.assertEqual(qs.conditions_as_sql(), 'the__next__number = 1') self.assertEqual(qs.conditions_as_sql(), "the__next__number = 1")
qs = Mdl.objects_in(self.database).filter(the__next__number__gt=1) qs = Mdl.objects_in(self.database).filter(the__next__number__gt=1)
self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1') self.assertEqual(qs.conditions_as_sql(), "the__next__number > 1")
def test_limit_by(self): def test_limit_by(self):
if self.database.server_version < (19, 17): if self.database.server_version < (19, 17):
raise unittest.SkipTest('ClickHouse version too old') raise unittest.SkipTest("ClickHouse version too old")
# Test without offset # Test without offset
qs = Person.objects_in(self.database).aggregate('first_name', 'last_name', 'height', n='count()').\ qs = (
order_by('first_name', '-height').limit_by(1, 'first_name') Person.objects_in(self.database)
.aggregate("first_name", "last_name", "height", n="count()")
.order_by("first_name", "-height")
.limit_by(1, "first_name")
)
self.assertEqual(qs.count(), 94) self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen') self.assertEqual(list(qs)[89].last_name, "Bowen")
# Test with funcs and fields # Test with funcs and fields
qs = Person.objects_in(self.database).aggregate(Person.first_name, Person.last_name, Person.height, n=F.count()).\ qs = (
order_by(Person.first_name, '-height').limit_by(1, F.upper(Person.first_name)) Person.objects_in(self.database)
.aggregate(Person.first_name, Person.last_name, Person.height, n=F.count())
.order_by(Person.first_name, "-height")
.limit_by(1, F.upper(Person.first_name))
)
self.assertEqual(qs.count(), 94) self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen') self.assertEqual(list(qs)[89].last_name, "Bowen")
# Test with limit and offset, also mixing LIMIT with LIMIT BY # Test with limit and offset, also mixing LIMIT with LIMIT BY
qs = Person.objects_in(self.database).filter(height__gt=1.67).order_by('height', 'first_name') qs = Person.objects_in(self.database).filter(height__gt=1.67).order_by("height", "first_name")
limited_qs = qs.limit_by((0, 3), 'height') limited_qs = qs.limit_by((0, 3), "height")
self.assertEqual([p.first_name for p in limited_qs[:3]], ['Amanda', 'Buffy', 'Dora']) self.assertEqual([p.first_name for p in limited_qs[:3]], ["Amanda", "Buffy", "Dora"])
limited_qs = qs.limit_by((3, 3), 'height') limited_qs = qs.limit_by((3, 3), "height")
self.assertEqual([p.first_name for p in limited_qs[:3]], ['Elton', 'Josiah', 'Macaulay']) self.assertEqual([p.first_name for p in limited_qs[:3]], ["Elton", "Josiah", "Macaulay"])
limited_qs = qs.limit_by((6, 3), 'height') limited_qs = qs.limit_by((6, 3), "height")
self.assertEqual([p.first_name for p in limited_qs[:3]], ['Norman', 'Octavius', 'Oliver']) self.assertEqual([p.first_name for p in limited_qs[:3]], ["Norman", "Octavius", "Oliver"])
Color = Enum('Color', u'red blue green yellow brown white black') Color = Enum("Color", u"red blue green yellow brown white black")
class SampleModel(Model): class SampleModel(Model):
timestamp = DateTimeField() timestamp = DateTimeField()
materialized_date = DateField(materialized='toDate(timestamp)') materialized_date = DateField(materialized="toDate(timestamp)")
num = Int32Field() num = Int32Field()
color = Enum8Field(Color) color = Enum8Field(Color)
num_squared = Int32Field(alias='num*num') num_squared = Int32Field(alias="num*num")
engine = MergeTree('materialized_date', ('materialized_date',)) engine = MergeTree("materialized_date", ("materialized_date",))
class SampleCollapsingModel(SampleModel): class SampleCollapsingModel(SampleModel):
sign = Int8Field(default=1) sign = Int8Field(default=1)
engine = CollapsingMergeTree('materialized_date', ('num',), 'sign') engine = CollapsingMergeTree("materialized_date", ("num",), "sign")
class Numbers(Model): class Numbers(Model):
number = UInt64Field() number = UInt64Field()
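A minimal sketch, outside the commit, of the queryset features exercised above (indexing and slicing, keyword lookups, Q objects and F expressions), run against system.numbers so no table setup is needed. Import paths are assumed to mirror the rest of these tests, and a local ClickHouse server is assumed.

from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import UInt64Field
from clickhouse_orm.funcs import F
from clickhouse_orm.query import Q


class Numbers(Model):
    # system.numbers exposes a single `number` column, as in the slicing tests above
    number = UInt64Field()


db = Database("system")  # assumes a local ClickHouse
qs = Numbers.objects_in(db)

print(qs[0].number)                  # indexing issues a LIMIT 1 query with an offset
print([r.number for r in qs[3:10]])  # slicing issues LIMIT/OFFSET

# Keyword lookups, Q objects and F expressions can be mixed in one filter();
# the generated condition can be inspected without hitting the server.
filtered = qs.filter(Q(number__lt=5) | Q(number__gt=90), F("greater", Numbers.number, 0))
print(filtered.conditions_as_sql())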

View File

@ -5,7 +5,6 @@ from .base_test_with_data import *
class ReadonlyTestCase(TestCaseWithData): class ReadonlyTestCase(TestCaseWithData):
def _test_readonly_db(self, username): def _test_readonly_db(self, username):
self._insert_and_check(self._sample_data(), len(data)) self._insert_and_check(self._sample_data(), len(data))
orig_database = self.database orig_database = self.database
@ -16,7 +15,7 @@ class ReadonlyTestCase(TestCaseWithData):
self._check_db_readonly_err(cm.exception) self._check_db_readonly_err(cm.exception)
self.assertEqual(self.database.count(Person), 100) self.assertEqual(self.database.count(Person), 100)
list(self.database.select('SELECT * from $table', Person)) list(self.database.select("SELECT * from $table", Person))
with self.assertRaises(ServerError) as cm: with self.assertRaises(ServerError) as cm:
self.database.drop_table(Person) self.database.drop_table(Person)
self._check_db_readonly_err(cm.exception, drop_table=True) self._check_db_readonly_err(cm.exception, drop_table=True)
@ -25,9 +24,11 @@ class ReadonlyTestCase(TestCaseWithData):
self.database.drop_database() self.database.drop_database()
self._check_db_readonly_err(cm.exception, drop_table=True) self._check_db_readonly_err(cm.exception, drop_table=True)
except ServerError as e: except ServerError as e:
if e.code == 192 and e.message.startswith('Unknown user'): # ClickHouse version < 20.3 if e.code == 192 and e.message.startswith("Unknown user"): # ClickHouse version < 20.3
raise unittest.SkipTest('Database user "%s" is not defined' % username) raise unittest.SkipTest('Database user "%s" is not defined' % username)
elif e.code == 516 and e.message.startswith('readonly: Authentication failed'): # ClickHouse version >= 20.3 elif e.code == 516 and e.message.startswith(
"readonly: Authentication failed"
): # ClickHouse version >= 20.3
raise unittest.SkipTest('Database user "%s" is not defined' % username) raise unittest.SkipTest('Database user "%s" is not defined' % username)
else: else:
raise raise
@ -38,20 +39,20 @@ class ReadonlyTestCase(TestCaseWithData):
self.assertEqual(exc.code, 164) self.assertEqual(exc.code, 164)
print(exc.message) print(exc.message)
if self.database.server_version >= (20, 3): if self.database.server_version >= (20, 3):
self.assertTrue('Cannot execute query in readonly mode' in exc.message) self.assertTrue("Cannot execute query in readonly mode" in exc.message)
elif drop_table: elif drop_table:
self.assertTrue(exc.message.startswith('Cannot drop table in readonly mode')) self.assertTrue(exc.message.startswith("Cannot drop table in readonly mode"))
else: else:
self.assertTrue(exc.message.startswith('Cannot insert into table in readonly mode')) self.assertTrue(exc.message.startswith("Cannot insert into table in readonly mode"))
def test_readonly_db_with_default_user(self): def test_readonly_db_with_default_user(self):
self._test_readonly_db('default') self._test_readonly_db("default")
def test_readonly_db_with_readonly_user(self): def test_readonly_db_with_readonly_user(self):
self._test_readonly_db('readonly') self._test_readonly_db("readonly")
def test_insert_readonly(self): def test_insert_readonly(self):
m = ReadOnlyModel(name='readonly') m = ReadOnlyModel(name="readonly")
self.database.create_table(ReadOnlyModel) self.database.create_table(ReadOnlyModel)
with self.assertRaises(DatabaseException): with self.assertRaises(DatabaseException):
self.database.insert([m]) self.database.insert([m])
@ -64,8 +65,8 @@ class ReadonlyTestCase(TestCaseWithData):
def test_nonexisting_readonly_database(self): def test_nonexisting_readonly_database(self):
with self.assertRaises(DatabaseException) as cm: with self.assertRaises(DatabaseException) as cm:
db = Database('dummy', readonly=True) db = Database("dummy", readonly=True)
self.assertEqual(str(cm.exception), 'Database does not exist, and cannot be created under readonly connection') self.assertEqual(str(cm.exception), "Database does not exist, and cannot be created under readonly connection")
class ReadOnlyModel(Model): class ReadOnlyModel(Model):
@ -73,4 +74,4 @@ class ReadOnlyModel(Model):
name = StringField() name = StringField()
date = DateField() date = DateField()
engine = MergeTree('date', ('name',)) engine = MergeTree("date", ("name",))
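An illustrative sketch, not taken from the diff, of the readonly behaviour asserted above: a readonly connection will not create a missing database, and writes or DDL over it fail with a ServerError. The username keyword and the "readonly" server user are assumptions taken from the test; whether the second part runs depends on how the server is configured.

from clickhouse_orm.database import Database, DatabaseException, ServerError

# A readonly connection refuses to create a database that does not exist
try:
    Database("dummy", readonly=True)
except DatabaseException as e:
    print(e)  # "Database does not exist, and cannot be created under readonly connection"

# Against an existing database, reads work but writes and DDL are rejected
db = Database("test-db", username="readonly", readonly=True)  # assumes this user exists
try:
    db.drop_database()
except ServerError as e:
    print(e.code, e.message)  # e.g. code 164, "Cannot drop ... in readonly mode"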

View File

@ -4,28 +4,36 @@ from clickhouse_orm.database import ServerError
class ServerErrorTest(unittest.TestCase):
    def test_old_format(self):
        code, msg = ServerError.get_error_code_msg(
            "Code: 81, e.displayText() = DB::Exception: Database db_not_here doesn't exist, e.what() = DB::Exception (from [::1]:33458)"
        )
        self.assertEqual(code, 81)
        self.assertEqual(msg, "Database db_not_here doesn't exist")

        code, msg = ServerError.get_error_code_msg(
            "Code: 161, e.displayText() = DB::Exception: Limit for number of columns to read exceeded. Requested: 11, maximum: 1, e.what() = DB::Exception\n"
        )
        self.assertEqual(code, 161)
        self.assertEqual(msg, "Limit for number of columns to read exceeded. Requested: 11, maximum: 1")

    def test_new_format(self):
        code, msg = ServerError.get_error_code_msg(
            "Code: 164, e.displayText() = DB::Exception: Cannot drop table in readonly mode"
        )
        self.assertEqual(code, 164)
        self.assertEqual(msg, "Cannot drop table in readonly mode")

        code, msg = ServerError.get_error_code_msg(
            "Code: 48, e.displayText() = DB::Exception: Method write is not supported by storage Merge"
        )
        self.assertEqual(code, 48)
        self.assertEqual(msg, "Method write is not supported by storage Merge")

        code, msg = ServerError.get_error_code_msg(
            "Code: 60, e.displayText() = DB::Exception: Table default.zuzu doesn't exist.\n"
        )
        self.assertEqual(code, 60)
        self.assertEqual(msg, "Table default.zuzu doesn't exist.")

View File

@ -9,11 +9,20 @@ class SimpleFieldsTest(unittest.TestCase):
    epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
    # Valid values
    dates = [
        date(1970, 1, 1),
        datetime(1970, 1, 1),
        epoch,
        epoch.astimezone(pytz.timezone("US/Eastern")),
        epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
        "1970-01-01 00:00:00",
        "1970-01-17 00:00:17",
        "0000-00-00 00:00:00",
        0,
        "2017-07-26T08:31:05",
        "2017-07-26T08:31:05Z",
        "2017-07-26 08:31",
        "2017-07-26T13:31:05+05",
        "2017-07-26 13:31:05+0500",
    ]

    def test_datetime_field(self):

@ -25,8 +34,7 @@ class SimpleFieldsTest(unittest.TestCase):
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            self.assertEqual(dt, dt2)
        # Invalid values
        for value in ("nope", "21/7/1999", 0.5, "2017-01 15:06:00", "2017-01-01X15:06:00", "2017-13-01T15:06:00"):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

@ -35,10 +43,16 @@ class SimpleFieldsTest(unittest.TestCase):
        # Valid values
        for value in self.dates + [
            datetime(1970, 1, 1, microsecond=100000),
            pytz.timezone("US/Eastern").localize(datetime(1970, 1, 1, microsecond=100000)),
            "1970-01-01 00:00:00.1",
            "1970-01-17 00:00:17.1",
            "0000-00-00 00:00:00.1",
            0.1,
            "2017-07-26T08:31:05.1",
            "2017-07-26T08:31:05.1Z",
            "2017-07-26 08:31.1",
            "2017-07-26T13:31:05.1+05",
            "2017-07-26 13:31:05.1+0500",
        ]:
            dt = f.to_python(value, pytz.utc)
            self.assertTrue(dt.tzinfo)

@ -46,8 +60,7 @@ class SimpleFieldsTest(unittest.TestCase):
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            self.assertEqual(dt, dt2)
        # Invalid values
        for value in ("nope", "21/7/1999", "2017-01 15:06:00", "2017-01-01X15:06:00", "2017-13-01T15:06:00"):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

@ -56,21 +69,21 @@ class SimpleFieldsTest(unittest.TestCase):
            f = DateTime64Field(precision=precision, timezone=pytz.utc)
            dt = f.to_python(datetime(2000, 1, 1, microsecond=123456), pytz.utc)
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            m = round(123456, precision - 6)  # round rightmost microsecond digits according to precision
            self.assertEqual(dt2, dt.replace(microsecond=m))

    def test_date_field(self):
        f = DateField()
        epoch = date(1970, 1, 1)
        # Valid values
        for value in (datetime(1970, 1, 1), epoch, "1970-01-01", "0000-00-00", 0):
            d = f.to_python(value, pytz.utc)
            self.assertEqual(d, epoch)
            # Verify that conversion to and from db string does not change value
            d2 = f.to_python(f.to_db_string(d, quote=False), pytz.utc)
            self.assertEqual(d, d2)
        # Invalid values
        for value in ("nope", "21/7/1999", 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check

@ -81,29 +94,29 @@ class SimpleFieldsTest(unittest.TestCase):
    def test_date_field_timezone(self):
        # Verify that conversion of timezone-aware datetime is correct
        f = DateField()
        dt = datetime(2017, 10, 5, tzinfo=pytz.timezone("Asia/Jerusalem"))
        self.assertEqual(f.to_python(dt, pytz.utc), date(2017, 10, 4))

    def test_datetime_field_timezone(self):
        # Verify that conversion of timezone-aware datetime is correct
        f = DateTimeField()
        utc_value = datetime(2017, 7, 26, 8, 31, 5, tzinfo=pytz.UTC)
        for value in (
            "2017-07-26T08:31:05",
            "2017-07-26T08:31:05Z",
            "2017-07-26T11:31:05+03",
            "2017-07-26 11:31:05+0300",
            "2017-07-26T03:31:05-0500",
        ):
            self.assertEqual(f.to_python(value, pytz.utc), utc_value)

    def test_uint8_field(self):
        f = UInt8Field()
        # Valid values
        for value in (17, "17", 17.0):
            self.assertEqual(f.to_python(value, pytz.utc), 17)
        # Invalid values
        for value in ("nope", date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check
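As a usage note, the round-trip these assertions rely on can be shown as a minimal stand-alone sketch; it uses DateTimeField directly, exactly as the tests above do, and proves nothing beyond what those assertions already state.

from datetime import datetime
import pytz
from clickhouse_orm.fields import DateTimeField

f = DateTimeField()
# Many input formats parse to the same timezone-aware instant...
dt = f.to_python("2017-07-26T11:31:05+03", pytz.utc)
assert dt == datetime(2017, 7, 26, 8, 31, 5, tzinfo=pytz.UTC)
# ...and serializing to a DB string and parsing it back is lossless.
assert f.to_python(f.to_db_string(dt, quote=False), pytz.utc) == dt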

View File

@ -11,9 +11,8 @@ from clickhouse_orm.system_models import SystemPart
class SystemTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()

@ -34,10 +33,10 @@ class SystemTest(unittest.TestCase):
class SystemPartTest(unittest.TestCase):

    BACKUP_DIRS = ["/var/lib/clickhouse/shadow", "/opt/clickhouse/shadow/"]

    def setUp(self):
        self.database = Database("test-db", log_statements=True)
        self.database.create_table(TestTable)
        self.database.create_table(CustomPartitionedTable)
        self.database.insert([TestTable(date_field=date.today())])

@ -51,7 +50,7 @@ class SystemPartTest(unittest.TestCase):
            if os.path.exists(dir):
                _, dirnames, _ = next(os.walk(dir))
                return dirnames
        raise unittest.SkipTest("Cannot find backups dir")

    def test_is_read_only(self):
        self.assertTrue(SystemPart.is_read_only())

@ -109,20 +108,20 @@ class SystemPartTest(unittest.TestCase):
    def test_query(self):
        SystemPart.objects_in(self.database).count()
        list(SystemPart.objects_in(self.database).filter(table="testtable"))


class TestTable(Model):
    date_field = DateField()

    engine = MergeTree("date_field", ("date_field",))


class CustomPartitionedTable(Model):
    date_field = DateField()
    group_field = UInt32Field()

    engine = MergeTree(order_by=("date_field", "group_field"), partition_key=("toYYYYMM(date_field)", "group_field"))


class SystemTestModel(Model):

View File

@ -7,29 +7,29 @@ from clickhouse_orm.engines import Memory
class UUIDFieldsTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()

    def test_uuid_field(self):
        if self.database.server_version < (18, 1):
            raise unittest.SkipTest("ClickHouse version too old")

        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = UUIDField()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same UUID)
        values = [
            "12345678-1234-5678-1234-567812345678",
            "{12345678-1234-5678-1234-567812345678}",
            "12345678123456781234567812345678",
            "urn:uuid:12345678-1234-5678-1234-567812345678",
            b"\x12\x34\x56\x78" * 4,
            (0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678),
            0x12345678123456781234567812345678,
            UUID(int=0x12345678123456781234567812345678),

@ -40,7 +40,6 @@ class UUIDFieldsTest(unittest.TestCase):
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, UUID(values[0]))
        # Check invalid values
        for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)
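All of the values above denote the same UUID. A minimal sketch of that conversion follows, reusing the test's field types; ExampleModel is a hypothetical name, and conversion-on-assignment is inferred from the ValueError check above rather than stated by the commit.

from uuid import UUID
from clickhouse_orm.models import Model
from clickhouse_orm.fields import Int16Field, UUIDField
from clickhouse_orm.engines import Memory

class ExampleModel(Model):
    i = Int16Field()
    f = UUIDField()
    engine = Memory()

# Valid inputs are converted to uuid.UUID when the instance is created;
# invalid ones (e.g. "zzz") raise ValueError, as the test asserts.
m = ExampleModel(i=1, f="urn:uuid:12345678-1234-5678-1234-567812345678")
assert m.f == UUID("12345678-1234-5678-1234-567812345678")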