Support model class inheritance

This commit is contained in:
Itai Shirav 2016-06-29 14:52:55 +03:00
parent e469ed4e10
commit 6e786d75e9
2 changed files with 58 additions and 1 deletions

View File

@ -10,8 +10,13 @@ class ModelBase(type):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
# Collect fields from parent classes
base_fields = []
for base in bases:
if isinstance(base, ModelBase):
base_fields += base._fields
# Build a list of fields, in the order they were listed in the class # Build a list of fields, in the order they were listed in the class
fields = [item for item in attrs.items() if isinstance(item[1], Field)] fields = base_fields + [item for item in attrs.items() if isinstance(item[1], Field)]
fields.sort(key=lambda item: item[1].creation_counter) fields.sort(key=lambda item: item[1].creation_counter)
setattr(new_cls, '_fields', fields) setattr(new_cls, '_fields', fields)
return new_cls return new_cls

52
tests/test_inheritance.py Normal file
View File

@ -0,0 +1,52 @@
import unittest
import datetime
import pytz
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class InheritanceTestCase(unittest.TestCase):
def assertFieldNames(self, model_class, names):
self.assertEquals(names, [name for name, field in model_class._fields])
def test_field_inheritance(self):
self.assertFieldNames(ParentModel, ['date_field', 'int_field'])
self.assertFieldNames(Model1, ['date_field', 'int_field', 'string_field'])
self.assertFieldNames(Model2, ['date_field', 'int_field', 'float_field'])
def test_create_table_sql(self):
sql1 = ParentModel.create_table_sql('default')
sql2 = Model1.create_table_sql('default')
sql3 = Model2.create_table_sql('default')
self.assertNotEqual(sql1, sql2)
self.assertNotEqual(sql1, sql3)
self.assertNotEqual(sql2, sql3)
def test_get_field(self):
self.assertIsNotNone(ParentModel().get_field('date_field'))
self.assertIsNone(ParentModel().get_field('string_field'))
self.assertIsNotNone(Model1().get_field('date_field'))
self.assertIsNotNone(Model1().get_field('string_field'))
self.assertIsNone(Model1().get_field('float_field'))
class ParentModel(Model):
date_field = DateField()
int_field = Int32Field()
engine = MergeTree('date_field', ('int_field', 'date_field'))
class Model1(ParentModel):
string_field = StringField()
class Model2(ParentModel):
float_field = Float32Field()