From 6e786d75e9d88b558afb8afa2d72069c4c12577b Mon Sep 17 00:00:00 2001 From: Itai Shirav Date: Wed, 29 Jun 2016 14:52:55 +0300 Subject: [PATCH] Support model class inheritance --- src/infi/clickhouse_orm/models.py | 7 ++++- tests/test_inheritance.py | 52 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/test_inheritance.py diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index 13dffa4..ece5557 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -10,8 +10,13 @@ class ModelBase(type): def __new__(cls, name, bases, attrs): new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) + # Collect fields from parent classes + base_fields = [] + for base in bases: + if isinstance(base, ModelBase): + base_fields += base._fields # Build a list of fields, in the order they were listed in the class - fields = [item for item in attrs.items() if isinstance(item[1], Field)] + fields = base_fields + [item for item in attrs.items() if isinstance(item[1], Field)] fields.sort(key=lambda item: item[1].creation_counter) setattr(new_cls, '_fields', fields) return new_cls diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py new file mode 100644 index 0000000..08bc084 --- /dev/null +++ b/tests/test_inheritance.py @@ -0,0 +1,52 @@ +import unittest +import datetime +import pytz + +from infi.clickhouse_orm.models import Model +from infi.clickhouse_orm.fields import * +from infi.clickhouse_orm.engines import * + + +class InheritanceTestCase(unittest.TestCase): + + def assertFieldNames(self, model_class, names): + self.assertEquals(names, [name for name, field in model_class._fields]) + + def test_field_inheritance(self): + self.assertFieldNames(ParentModel, ['date_field', 'int_field']) + self.assertFieldNames(Model1, ['date_field', 'int_field', 'string_field']) + self.assertFieldNames(Model2, ['date_field', 'int_field', 'float_field']) + + def test_create_table_sql(self): + sql1 = ParentModel.create_table_sql('default') + sql2 = Model1.create_table_sql('default') + sql3 = Model2.create_table_sql('default') + self.assertNotEqual(sql1, sql2) + self.assertNotEqual(sql1, sql3) + self.assertNotEqual(sql2, sql3) + + def test_get_field(self): + self.assertIsNotNone(ParentModel().get_field('date_field')) + self.assertIsNone(ParentModel().get_field('string_field')) + self.assertIsNotNone(Model1().get_field('date_field')) + self.assertIsNotNone(Model1().get_field('string_field')) + self.assertIsNone(Model1().get_field('float_field')) + + +class ParentModel(Model): + + date_field = DateField() + int_field = Int32Field() + + engine = MergeTree('date_field', ('int_field', 'date_field')) + + +class Model1(ParentModel): + + string_field = StringField() + + +class Model2(ParentModel): + + float_field = Float32Field() +