Support for model constraints

This commit is contained in:
Itai Shirav 2020-06-06 11:07:25 +03:00
parent 393209e624
commit ffd9bab0ef
5 changed files with 174 additions and 16 deletions

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from itertools import chain
from logging import getLogger from logging import getLogger
import pytz import pytz
@ -14,6 +15,31 @@ from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm') logger = getLogger('clickhouse_orm')
class Constraint():
'''
Defines a model constraint.
'''
name = None # this is set by the parent model
parent = None # this is set by the parent model
def __init__(self, expr):
'''
Initializer. Expects an expression that ClickHouse will verify when inserting data.
'''
self.expr = expr
def create_table_sql(self):
'''
Returns the SQL statement for defining this constraint during table creation.
'''
return 'CONSTRAINT `%s` CHECK %s' % (self.name, self.expr)
def str(self):
return self.create_table_sql()
class ModelBase(type): class ModelBase(type):
''' '''
A metaclass for ORM models. It adds the _fields list to model classes. A metaclass for ORM models. It adds the _fields list to model classes.
@ -22,18 +48,21 @@ class ModelBase(type):
ad_hoc_model_cache = {} ad_hoc_model_cache = {}
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
# Collect fields from parent classes # Collect fields and constraints from parent classes
base_fields = dict() fields = dict()
constraints = dict()
for base in bases: for base in bases:
if isinstance(base, ModelBase): if isinstance(base, ModelBase):
base_fields.update(base._fields) fields.update(base._fields)
constraints.update(base._constraints)
fields = base_fields # Build a list of (name, field) tuples, in the order they were listed in the class
# Build a list of fields, in the order they were listed in the class
fields.update({n: f for n, f in attrs.items() if isinstance(f, Field)}) fields.update({n: f for n, f in attrs.items() if isinstance(f, Field)})
fields = sorted(fields.items(), key=lambda item: item[1].creation_counter) fields = sorted(fields.items(), key=lambda item: item[1].creation_counter)
# Build a list of constraints
constraints.update({n: c for n, c in attrs.items() if isinstance(c, Constraint)})
# Build a dictionary of default values # Build a dictionary of default values
defaults = {} defaults = {}
has_funcs_as_defaults = False has_funcs_as_defaults = False
@ -49,16 +78,17 @@ class ModelBase(type):
attrs = dict( attrs = dict(
attrs, attrs,
_fields=OrderedDict(fields), _fields=OrderedDict(fields),
_constraints=constraints,
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
_defaults=defaults, _defaults=defaults,
_has_funcs_as_defaults=has_funcs_as_defaults _has_funcs_as_defaults=has_funcs_as_defaults
) )
model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs) model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
# Let each field know its parent and its own name # Let each field and constraint know its parent and its own name
for n, f in fields: for n, obj in chain(fields, constraints.items()):
setattr(f, 'parent', model) setattr(obj, 'parent', model)
setattr(f, 'name', n) setattr(obj, 'name', n)
return model return model
@ -222,17 +252,27 @@ class Model(metaclass=ModelBase):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' '''
Returns the SQL command for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' '''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = [] cols = []
for name, field in cls.fields().items(): for name, field in cls.fields().items():
cols.append(' %s %s' % (name, field.get_sql(db=db))) cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols)) parts.append(',\n'.join(cols))
parts.append(cls._constraints_sql())
parts.append(')') parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts) return '\n'.join(parts)
@classmethod
def _constraints_sql(cls):
'''
Returns this model's contraints as SQL.
'''
if not cls._constraints:
return ''
return ',' + ',\n'.join(c.create_table_sql() for c in cls._constraints.values())
@classmethod @classmethod
def drop_table_sql(cls, db): def drop_table_sql(cls, db):
''' '''
@ -348,7 +388,7 @@ class BufferModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' '''
Returns the SQL command for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' '''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name, parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name,
cls.engine.main_model.table_name())] cls.engine.main_model.table_name())]
@ -370,6 +410,9 @@ class MergeModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
'''
Returns the SQL statement for creating a table for this model.
'''
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = [] cols = []
@ -377,6 +420,7 @@ class MergeModel(Model):
if name != '_table': if name != '_table':
cols.append(' %s %s' % (name, field.get_sql(db=db))) cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols)) parts.append(',\n'.join(cols))
parts.append(cls._constraints_sql())
parts.append(')') parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts) return '\n'.join(parts)
@ -386,10 +430,14 @@ class MergeModel(Model):
class DistributedModel(Model): class DistributedModel(Model):
""" """
Model for Distributed engine Model class for use with a `Distributed` engine.
""" """
def set_database(self, db): def set_database(self, db):
'''
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
'''
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed" assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db) res = super(DistributedModel, self).set_database(db)
return res return res
@ -447,6 +495,9 @@ class DistributedModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
'''
Returns the SQL statement for creating a table for this model.
'''
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
cls.fix_engine_table() cls.fix_engine_table()
@ -459,4 +510,4 @@ class DistributedModel(Model):
# Expose only relevant classes in import * # Expose only relevant classes in import *
__all__ = get_subclass_names(locals(), Model) __all__ = get_subclass_names(locals(), (Model, Constraint))

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(ModelWithConstraints)
]

View File

@ -0,0 +1,6 @@
from infi.clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterConstraints(ModelWithConstraints2)
]

45
tests/test_constraints.py Normal file
View File

@ -0,0 +1,45 @@
import unittest
from infi.clickhouse_orm import *
from .base_test_with_data import Person
class ArrayFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database.create_table(PersonWithConstraints)
def tearDown(self):
self.database.drop_database()
def test_insert_valid_values(self):
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)
])
def test_insert_invalid_values(self):
if self.database.server_version < (19, 14, 3, 3):
raise unittest.SkipTest('ClickHouse version too old')
with self.assertRaises(ServerError) as e:
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)
])
self.assertEqual(e.code, 469)
self.assertTrue('Constraint `birthday_in_the_past`' in e.message)
with self.assertRaises(ServerError) as e:
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)
])
self.assertEqual(e.code, 469)
self.assertTrue('Constraint `max_height`' in e.message)
class PersonWithConstraints(Person):
birthday_in_the_past = Constraint(Person.birthday <= F.today())
max_height = Constraint(Person.height <= 2.75)

View File

@ -1,8 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import unittest import unittest
from infi.clickhouse_orm.database import Database from infi.clickhouse_orm.database import Database, ServerError
from infi.clickhouse_orm.models import Model, BufferModel from infi.clickhouse_orm.models import Model, BufferModel, Constraint
from infi.clickhouse_orm.fields import * from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import * from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.migrations import MigrationHistory from infi.clickhouse_orm.migrations import MigrationHistory
@ -94,6 +94,7 @@ class MigrationsTestCase(unittest.TestCase):
self.assertTrue(self.tableExists(AliasModel1)) self.assertTrue(self.tableExists(AliasModel1))
self.assertEqual(self.getTableFields(AliasModel1), self.assertEqual(self.getTableFields(AliasModel1),
[('date', 'Date'), ('int_field', 'Int8'), ('date_alias', 'Date'), ('int_field_plus_one', 'Int8')]) [('date', 'Date'), ('int_field', 'Int8'), ('date_alias', 'Date'), ('int_field_plus_one', 'Int8')])
# Codecs and low cardinality
self.database.migrate('tests.sample_migrations', 15) self.database.migrate('tests.sample_migrations', 15)
self.assertTrue(self.tableExists(Model4_compressed)) self.assertTrue(self.tableExists(Model4_compressed))
if self.database.has_low_cardinality_support: if self.database.has_low_cardinality_support:
@ -106,6 +107,22 @@ class MigrationsTestCase(unittest.TestCase):
[('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'Nullable(String)'), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'Nullable(String)'),
('f5', 'Array(UInt64)')]) ('f5', 'Array(UInt64)')])
if self.database.server_version >= (19, 14, 3, 3):
# Adding constraints
self.database.migrate('tests.sample_migrations', 16)
self.assertTrue(self.tableExists(ModelWithConstraints))
self.database.insert([ModelWithConstraints(f1=101, f2='a')])
with self.assertRaises(ServerError):
self.database.insert([ModelWithConstraints(f1=99, f2='a')])
with self.assertRaises(ServerError):
self.database.insert([ModelWithConstraints(f1=101, f2='x')])
# Modifying constraints
self.database.migrate('tests.sample_migrations', 17)
self.database.insert([ModelWithConstraints(f1=99, f2='a')])
with self.assertRaises(ServerError):
self.database.insert([ModelWithConstraints(f1=101, f2='a')])
with self.assertRaises(ServerError):
self.database.insert([ModelWithConstraints(f1=99, f2='x')])
# Several different models with the same table name, to simulate a table that changes over time # Several different models with the same table name, to simulate a table that changes over time
@ -294,3 +311,36 @@ class Model2LowCardinality(Model):
@classmethod @classmethod
def table_name(cls): def table_name(cls):
return 'mig' return 'mig'
class ModelWithConstraints(Model):
date = DateField()
f1 = Int32Field()
f2 = StringField()
constraint = Constraint(f2.isIn(['a', 'b', 'c'])) # check reserved keyword as constraint name
f1_constraint = Constraint(f1 > 100)
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'modelwithconstraints'
class ModelWithConstraints2(Model):
date = DateField()
f1 = Int32Field()
f2 = StringField()
constraint = Constraint(f2.isIn(['a', 'b', 'c']))
f1_constraint_new = Constraint(f1 < 100)
engine = MergeTree('date', ('date',))
@classmethod
def table_name(cls):
return 'modelwithconstraints'