mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-25 02:03:46 +03:00
Support for model constraints
This commit is contained in:
parent
393209e624
commit
ffd9bab0ef
|
@ -1,6 +1,7 @@
|
|||
from __future__ import unicode_literals
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
from logging import getLogger
|
||||
|
||||
import pytz
|
||||
|
@ -14,6 +15,31 @@ from .engines import Merge, Distributed
|
|||
logger = getLogger('clickhouse_orm')
|
||||
|
||||
|
||||
|
||||
class Constraint():
|
||||
'''
|
||||
Defines a model constraint.
|
||||
'''
|
||||
|
||||
name = None # this is set by the parent model
|
||||
parent = None # this is set by the parent model
|
||||
|
||||
def __init__(self, expr):
|
||||
'''
|
||||
Initializer. Expects an expression that ClickHouse will verify when inserting data.
|
||||
'''
|
||||
self.expr = expr
|
||||
|
||||
def create_table_sql(self):
|
||||
'''
|
||||
Returns the SQL statement for defining this constraint during table creation.
|
||||
'''
|
||||
return 'CONSTRAINT `%s` CHECK %s' % (self.name, self.expr)
|
||||
|
||||
def str(self):
|
||||
return self.create_table_sql()
|
||||
|
||||
|
||||
class ModelBase(type):
|
||||
'''
|
||||
A metaclass for ORM models. It adds the _fields list to model classes.
|
||||
|
@ -22,18 +48,21 @@ class ModelBase(type):
|
|||
ad_hoc_model_cache = {}
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
# Collect fields from parent classes
|
||||
base_fields = dict()
|
||||
# Collect fields and constraints from parent classes
|
||||
fields = dict()
|
||||
constraints = dict()
|
||||
for base in bases:
|
||||
if isinstance(base, ModelBase):
|
||||
base_fields.update(base._fields)
|
||||
fields.update(base._fields)
|
||||
constraints.update(base._constraints)
|
||||
|
||||
fields = base_fields
|
||||
|
||||
# Build a list of fields, in the order they were listed in the class
|
||||
# Build a list of (name, field) tuples, in the order they were listed in the class
|
||||
fields.update({n: f for n, f in attrs.items() if isinstance(f, Field)})
|
||||
fields = sorted(fields.items(), key=lambda item: item[1].creation_counter)
|
||||
|
||||
# Build a list of constraints
|
||||
constraints.update({n: c for n, c in attrs.items() if isinstance(c, Constraint)})
|
||||
|
||||
# Build a dictionary of default values
|
||||
defaults = {}
|
||||
has_funcs_as_defaults = False
|
||||
|
@ -49,16 +78,17 @@ class ModelBase(type):
|
|||
attrs = dict(
|
||||
attrs,
|
||||
_fields=OrderedDict(fields),
|
||||
_constraints=constraints,
|
||||
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
|
||||
_defaults=defaults,
|
||||
_has_funcs_as_defaults=has_funcs_as_defaults
|
||||
)
|
||||
model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
|
||||
|
||||
# Let each field know its parent and its own name
|
||||
for n, f in fields:
|
||||
setattr(f, 'parent', model)
|
||||
setattr(f, 'name', n)
|
||||
# Let each field and constraint know its parent and its own name
|
||||
for n, obj in chain(fields, constraints.items()):
|
||||
setattr(obj, 'parent', model)
|
||||
setattr(obj, 'name', n)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -222,17 +252,27 @@ class Model(metaclass=ModelBase):
|
|||
@classmethod
|
||||
def create_table_sql(cls, db):
|
||||
'''
|
||||
Returns the SQL command for creating a table for this model.
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
'''
|
||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
|
||||
cols = []
|
||||
for name, field in cls.fields().items():
|
||||
cols.append(' %s %s' % (name, field.get_sql(db=db)))
|
||||
parts.append(',\n'.join(cols))
|
||||
parts.append(cls._constraints_sql())
|
||||
parts.append(')')
|
||||
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
|
||||
return '\n'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def _constraints_sql(cls):
|
||||
'''
|
||||
Returns this model's contraints as SQL.
|
||||
'''
|
||||
if not cls._constraints:
|
||||
return ''
|
||||
return ',' + ',\n'.join(c.create_table_sql() for c in cls._constraints.values())
|
||||
|
||||
@classmethod
|
||||
def drop_table_sql(cls, db):
|
||||
'''
|
||||
|
@ -348,7 +388,7 @@ class BufferModel(Model):
|
|||
@classmethod
|
||||
def create_table_sql(cls, db):
|
||||
'''
|
||||
Returns the SQL command for creating a table for this model.
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
'''
|
||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name,
|
||||
cls.engine.main_model.table_name())]
|
||||
|
@ -370,6 +410,9 @@ class MergeModel(Model):
|
|||
|
||||
@classmethod
|
||||
def create_table_sql(cls, db):
|
||||
'''
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
'''
|
||||
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
|
||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
|
||||
cols = []
|
||||
|
@ -377,6 +420,7 @@ class MergeModel(Model):
|
|||
if name != '_table':
|
||||
cols.append(' %s %s' % (name, field.get_sql(db=db)))
|
||||
parts.append(',\n'.join(cols))
|
||||
parts.append(cls._constraints_sql())
|
||||
parts.append(')')
|
||||
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
|
||||
return '\n'.join(parts)
|
||||
|
@ -386,10 +430,14 @@ class MergeModel(Model):
|
|||
|
||||
class DistributedModel(Model):
|
||||
"""
|
||||
Model for Distributed engine
|
||||
Model class for use with a `Distributed` engine.
|
||||
"""
|
||||
|
||||
def set_database(self, db):
|
||||
'''
|
||||
Sets the `Database` that this model instance belongs to.
|
||||
This is done automatically when the instance is read from the database or written to it.
|
||||
'''
|
||||
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
|
||||
res = super(DistributedModel, self).set_database(db)
|
||||
return res
|
||||
|
@ -447,6 +495,9 @@ class DistributedModel(Model):
|
|||
|
||||
@classmethod
|
||||
def create_table_sql(cls, db):
|
||||
'''
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
'''
|
||||
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
|
||||
|
||||
cls.fix_engine_table()
|
||||
|
@ -459,4 +510,4 @@ class DistributedModel(Model):
|
|||
|
||||
|
||||
# Expose only relevant classes in import *
|
||||
__all__ = get_subclass_names(locals(), Model)
|
||||
__all__ = get_subclass_names(locals(), (Model, Constraint))
|
||||
|
|
6
tests/sample_migrations/0016.py
Normal file
6
tests/sample_migrations/0016.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from infi.clickhouse_orm import migrations
|
||||
from ..test_migrations import *
|
||||
|
||||
operations = [
|
||||
migrations.CreateTable(ModelWithConstraints)
|
||||
]
|
6
tests/sample_migrations/0017.py
Normal file
6
tests/sample_migrations/0017.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from infi.clickhouse_orm import migrations
|
||||
from ..test_migrations import *
|
||||
|
||||
operations = [
|
||||
migrations.AlterConstraints(ModelWithConstraints2)
|
||||
]
|
45
tests/test_constraints.py
Normal file
45
tests/test_constraints.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import unittest
|
||||
|
||||
from infi.clickhouse_orm import *
|
||||
from .base_test_with_data import Person
|
||||
|
||||
|
||||
class ArrayFieldsTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.database = Database('test-db', log_statements=True)
|
||||
self.database.create_table(PersonWithConstraints)
|
||||
|
||||
def tearDown(self):
|
||||
self.database.drop_database()
|
||||
|
||||
def test_insert_valid_values(self):
|
||||
self.database.insert([
|
||||
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)
|
||||
])
|
||||
|
||||
def test_insert_invalid_values(self):
|
||||
if self.database.server_version < (19, 14, 3, 3):
|
||||
raise unittest.SkipTest('ClickHouse version too old')
|
||||
|
||||
with self.assertRaises(ServerError) as e:
|
||||
self.database.insert([
|
||||
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)
|
||||
])
|
||||
self.assertEqual(e.code, 469)
|
||||
self.assertTrue('Constraint `birthday_in_the_past`' in e.message)
|
||||
|
||||
with self.assertRaises(ServerError) as e:
|
||||
self.database.insert([
|
||||
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)
|
||||
])
|
||||
self.assertEqual(e.code, 469)
|
||||
self.assertTrue('Constraint `max_height`' in e.message)
|
||||
|
||||
|
||||
class PersonWithConstraints(Person):
|
||||
|
||||
birthday_in_the_past = Constraint(Person.birthday <= F.today())
|
||||
max_height = Constraint(Person.height <= 2.75)
|
||||
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import unicode_literals
|
||||
import unittest
|
||||
|
||||
from infi.clickhouse_orm.database import Database
|
||||
from infi.clickhouse_orm.models import Model, BufferModel
|
||||
from infi.clickhouse_orm.database import Database, ServerError
|
||||
from infi.clickhouse_orm.models import Model, BufferModel, Constraint
|
||||
from infi.clickhouse_orm.fields import *
|
||||
from infi.clickhouse_orm.engines import *
|
||||
from infi.clickhouse_orm.migrations import MigrationHistory
|
||||
|
@ -94,6 +94,7 @@ class MigrationsTestCase(unittest.TestCase):
|
|||
self.assertTrue(self.tableExists(AliasModel1))
|
||||
self.assertEqual(self.getTableFields(AliasModel1),
|
||||
[('date', 'Date'), ('int_field', 'Int8'), ('date_alias', 'Date'), ('int_field_plus_one', 'Int8')])
|
||||
# Codecs and low cardinality
|
||||
self.database.migrate('tests.sample_migrations', 15)
|
||||
self.assertTrue(self.tableExists(Model4_compressed))
|
||||
if self.database.has_low_cardinality_support:
|
||||
|
@ -106,6 +107,22 @@ class MigrationsTestCase(unittest.TestCase):
|
|||
[('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'Nullable(String)'),
|
||||
('f5', 'Array(UInt64)')])
|
||||
|
||||
if self.database.server_version >= (19, 14, 3, 3):
|
||||
# Adding constraints
|
||||
self.database.migrate('tests.sample_migrations', 16)
|
||||
self.assertTrue(self.tableExists(ModelWithConstraints))
|
||||
self.database.insert([ModelWithConstraints(f1=101, f2='a')])
|
||||
with self.assertRaises(ServerError):
|
||||
self.database.insert([ModelWithConstraints(f1=99, f2='a')])
|
||||
with self.assertRaises(ServerError):
|
||||
self.database.insert([ModelWithConstraints(f1=101, f2='x')])
|
||||
# Modifying constraints
|
||||
self.database.migrate('tests.sample_migrations', 17)
|
||||
self.database.insert([ModelWithConstraints(f1=99, f2='a')])
|
||||
with self.assertRaises(ServerError):
|
||||
self.database.insert([ModelWithConstraints(f1=101, f2='a')])
|
||||
with self.assertRaises(ServerError):
|
||||
self.database.insert([ModelWithConstraints(f1=99, f2='x')])
|
||||
|
||||
# Several different models with the same table name, to simulate a table that changes over time
|
||||
|
||||
|
@ -294,3 +311,36 @@ class Model2LowCardinality(Model):
|
|||
@classmethod
|
||||
def table_name(cls):
|
||||
return 'mig'
|
||||
|
||||
|
||||
class ModelWithConstraints(Model):
|
||||
|
||||
date = DateField()
|
||||
f1 = Int32Field()
|
||||
f2 = StringField()
|
||||
|
||||
constraint = Constraint(f2.isIn(['a', 'b', 'c'])) # check reserved keyword as constraint name
|
||||
f1_constraint = Constraint(f1 > 100)
|
||||
|
||||
engine = MergeTree('date', ('date',))
|
||||
|
||||
@classmethod
|
||||
def table_name(cls):
|
||||
return 'modelwithconstraints'
|
||||
|
||||
|
||||
class ModelWithConstraints2(Model):
|
||||
|
||||
date = DateField()
|
||||
f1 = Int32Field()
|
||||
f2 = StringField()
|
||||
|
||||
constraint = Constraint(f2.isIn(['a', 'b', 'c']))
|
||||
f1_constraint_new = Constraint(f1 < 100)
|
||||
|
||||
engine = MergeTree('date', ('date',))
|
||||
|
||||
@classmethod
|
||||
def table_name(cls):
|
||||
return 'modelwithconstraints'
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user