From ffd9bab0ef58c5670f3a39c978525f5f32a00742 Mon Sep 17 00:00:00 2001 From: Itai Shirav Date: Sat, 6 Jun 2020 11:07:25 +0300 Subject: [PATCH] Support for model constraints --- src/infi/clickhouse_orm/models.py | 79 +++++++++++++++++++++++++------ tests/sample_migrations/0016.py | 6 +++ tests/sample_migrations/0017.py | 6 +++ tests/test_constraints.py | 45 ++++++++++++++++++ tests/test_migrations.py | 54 ++++++++++++++++++++- 5 files changed, 174 insertions(+), 16 deletions(-) create mode 100644 tests/sample_migrations/0016.py create mode 100644 tests/sample_migrations/0017.py create mode 100644 tests/test_constraints.py diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index e4766e5..506185b 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import sys from collections import OrderedDict +from itertools import chain from logging import getLogger import pytz @@ -14,6 +15,31 @@ from .engines import Merge, Distributed logger = getLogger('clickhouse_orm') + +class Constraint(): + ''' + Defines a model constraint. + ''' + + name = None # this is set by the parent model + parent = None # this is set by the parent model + + def __init__(self, expr): + ''' + Initializer. Expects an expression that ClickHouse will verify when inserting data. + ''' + self.expr = expr + + def create_table_sql(self): + ''' + Returns the SQL statement for defining this constraint during table creation. + ''' + return 'CONSTRAINT `%s` CHECK %s' % (self.name, self.expr) + + def __str__(self): + return self.create_table_sql() + + class ModelBase(type): ''' A metaclass for ORM models. It adds the _fields list to model classes.
@@ -22,18 +48,21 @@ class ModelBase(type): ad_hoc_model_cache = {} def __new__(cls, name, bases, attrs): - # Collect fields from parent classes - base_fields = dict() + # Collect fields and constraints from parent classes + fields = dict() + constraints = dict() for base in bases: if isinstance(base, ModelBase): - base_fields.update(base._fields) + fields.update(base._fields) + constraints.update(base._constraints) - fields = base_fields - - # Build a list of fields, in the order they were listed in the class + # Build a list of (name, field) tuples, in the order they were listed in the class fields.update({n: f for n, f in attrs.items() if isinstance(f, Field)}) fields = sorted(fields.items(), key=lambda item: item[1].creation_counter) + # Build a list of constraints + constraints.update({n: c for n, c in attrs.items() if isinstance(c, Constraint)}) + # Build a dictionary of default values defaults = {} has_funcs_as_defaults = False @@ -49,16 +78,17 @@ class ModelBase(type): attrs = dict( attrs, _fields=OrderedDict(fields), + _constraints=constraints, _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), _defaults=defaults, _has_funcs_as_defaults=has_funcs_as_defaults ) model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs) - # Let each field know its parent and its own name - for n, f in fields: - setattr(f, 'parent', model) - setattr(f, 'name', n) + # Let each field and constraint know its parent and its own name + for n, obj in chain(fields, constraints.items()): + setattr(obj, 'parent', model) + setattr(obj, 'name', n) return model @@ -222,17 +252,27 @@ class Model(metaclass=ModelBase): @classmethod def create_table_sql(cls, db): ''' - Returns the SQL command for creating a table for this model. + Returns the SQL statement for creating a table for this model. 
''' parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] cols = [] for name, field in cls.fields().items(): cols.append(' %s %s' % (name, field.get_sql(db=db))) parts.append(',\n'.join(cols)) + parts.append(cls._constraints_sql()) parts.append(')') parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) return '\n'.join(parts) + @classmethod + def _constraints_sql(cls): + ''' + Returns this model's constraints as SQL. + ''' + if not cls._constraints: + return '' + return ',' + ',\n'.join(c.create_table_sql() for c in cls._constraints.values()) + @classmethod def drop_table_sql(cls, db): ''' @@ -348,7 +388,7 @@ class BufferModel(Model): @classmethod def create_table_sql(cls, db): ''' - Returns the SQL command for creating a table for this model. + Returns the SQL statement for creating a table for this model. ''' parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())] @@ -370,6 +410,9 @@ class MergeModel(Model): @classmethod def create_table_sql(cls, db): + ''' + Returns the SQL statement for creating a table for this model. + ''' assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] cols = [] @@ -377,6 +420,7 @@ class MergeModel(Model): if name != '_table': cols.append(' %s %s' % (name, field.get_sql(db=db))) parts.append(',\n'.join(cols)) + parts.append(cls._constraints_sql()) parts.append(')') parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) return '\n'.join(parts) @@ -386,10 +430,14 @@ class MergeModel(Model): class DistributedModel(Model): """ - Model for Distributed engine + Model class for use with a `Distributed` engine. """ def set_database(self, db): + ''' + Sets the `Database` that this model instance belongs to. + This is done automatically when the instance is read from the database or written to it.
+ ''' assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed" res = super(DistributedModel, self).set_database(db) return res @@ -447,6 +495,9 @@ class DistributedModel(Model): @classmethod def create_table_sql(cls, db): + ''' + Returns the SQL statement for creating a table for this model. + ''' assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" cls.fix_engine_table() @@ -459,4 +510,4 @@ class DistributedModel(Model): # Expose only relevant classes in import * -__all__ = get_subclass_names(locals(), Model) +__all__ = get_subclass_names(locals(), (Model, Constraint)) diff --git a/tests/sample_migrations/0016.py b/tests/sample_migrations/0016.py new file mode 100644 index 0000000..6f0f814 --- /dev/null +++ b/tests/sample_migrations/0016.py @@ -0,0 +1,6 @@ +from infi.clickhouse_orm import migrations +from ..test_migrations import * + +operations = [ + migrations.CreateTable(ModelWithConstraints) +] diff --git a/tests/sample_migrations/0017.py b/tests/sample_migrations/0017.py new file mode 100644 index 0000000..4151189 --- /dev/null +++ b/tests/sample_migrations/0017.py @@ -0,0 +1,6 @@ +from infi.clickhouse_orm import migrations +from ..test_migrations import * + +operations = [ + migrations.AlterConstraints(ModelWithConstraints2) +] diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 0000000..a14ed6c --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,45 @@ +import unittest + +from infi.clickhouse_orm import * +from .base_test_with_data import Person + + +class ArrayFieldsTest(unittest.TestCase): + + def setUp(self): + self.database = Database('test-db', log_statements=True) + self.database.create_table(PersonWithConstraints) + + def tearDown(self): + self.database.drop_database() + + def test_insert_valid_values(self): + self.database.insert([ + PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66) 
+ ]) + + def test_insert_invalid_values(self): + if self.database.server_version < (19, 14, 3, 3): + raise unittest.SkipTest('ClickHouse version too old') + + with self.assertRaises(ServerError) as e: + self.database.insert([ + PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66) + ]) + self.assertEqual(e.exception.code, 469) + self.assertIn('Constraint `birthday_in_the_past`', e.exception.message) + + with self.assertRaises(ServerError) as e: + self.database.insert([ + PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3) + ]) + self.assertEqual(e.exception.code, 469) + self.assertIn('Constraint `max_height`', e.exception.message) + + +class PersonWithConstraints(Person): + + birthday_in_the_past = Constraint(Person.birthday <= F.today()) + max_height = Constraint(Person.height <= 2.75) + + diff --git a/tests/test_migrations.py b/tests/test_migrations.py index f92b1e9..e8d7776 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals import unittest -from infi.clickhouse_orm.database import Database -from infi.clickhouse_orm.models import Model, BufferModel +from infi.clickhouse_orm.database import Database, ServerError +from infi.clickhouse_orm.models import Model, BufferModel, Constraint from infi.clickhouse_orm.fields import * from infi.clickhouse_orm.engines import * from infi.clickhouse_orm.migrations import MigrationHistory @@ -94,6 +94,7 @@ class MigrationsTestCase(unittest.TestCase): self.assertTrue(self.tableExists(AliasModel1)) self.assertEqual(self.getTableFields(AliasModel1), [('date', 'Date'), ('int_field', 'Int8'), ('date_alias', 'Date'), ('int_field_plus_one', 'Int8')]) + # Codecs and low cardinality self.database.migrate('tests.sample_migrations', 15) self.assertTrue(self.tableExists(Model4_compressed)) if self.database.has_low_cardinality_support: @@ -106,6 +107,22 @@ class MigrationsTestCase(unittest.TestCase): [('date', 'Date'),
('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'Nullable(String)'), ('f5', 'Array(UInt64)')]) + if self.database.server_version >= (19, 14, 3, 3): + # Adding constraints + self.database.migrate('tests.sample_migrations', 16) + self.assertTrue(self.tableExists(ModelWithConstraints)) + self.database.insert([ModelWithConstraints(f1=101, f2='a')]) + with self.assertRaises(ServerError): + self.database.insert([ModelWithConstraints(f1=99, f2='a')]) + with self.assertRaises(ServerError): + self.database.insert([ModelWithConstraints(f1=101, f2='x')]) + # Modifying constraints + self.database.migrate('tests.sample_migrations', 17) + self.database.insert([ModelWithConstraints(f1=99, f2='a')]) + with self.assertRaises(ServerError): + self.database.insert([ModelWithConstraints(f1=101, f2='a')]) + with self.assertRaises(ServerError): + self.database.insert([ModelWithConstraints(f1=99, f2='x')]) # Several different models with the same table name, to simulate a table that changes over time @@ -294,3 +311,36 @@ class Model2LowCardinality(Model): @classmethod def table_name(cls): return 'mig' + + +class ModelWithConstraints(Model): + + date = DateField() + f1 = Int32Field() + f2 = StringField() + + constraint = Constraint(f2.isIn(['a', 'b', 'c'])) # check reserved keyword as constraint name + f1_constraint = Constraint(f1 > 100) + + engine = MergeTree('date', ('date',)) + + @classmethod + def table_name(cls): + return 'modelwithconstraints' + + +class ModelWithConstraints2(Model): + + date = DateField() + f1 = Int32Field() + f2 = StringField() + + constraint = Constraint(f2.isIn(['a', 'b', 'c'])) + f1_constraint_new = Constraint(f1 < 100) + + engine = MergeTree('date', ('date',)) + + @classmethod + def table_name(cls): + return 'modelwithconstraints' +