mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-11 03:46:34 +03:00
Support model class inheritance
This commit is contained in:
parent
e469ed4e10
commit
6e786d75e9
|
@ -10,8 +10,13 @@ class ModelBase(type):
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
|
new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
|
||||||
|
# Collect fields from parent classes
|
||||||
|
base_fields = []
|
||||||
|
for base in bases:
|
||||||
|
if isinstance(base, ModelBase):
|
||||||
|
base_fields += base._fields
|
||||||
# Build a list of fields, in the order they were listed in the class
|
# Build a list of fields, in the order they were listed in the class
|
||||||
fields = [item for item in attrs.items() if isinstance(item[1], Field)]
|
fields = base_fields + [item for item in attrs.items() if isinstance(item[1], Field)]
|
||||||
fields.sort(key=lambda item: item[1].creation_counter)
|
fields.sort(key=lambda item: item[1].creation_counter)
|
||||||
setattr(new_cls, '_fields', fields)
|
setattr(new_cls, '_fields', fields)
|
||||||
return new_cls
|
return new_cls
|
||||||
|
|
52
tests/test_inheritance.py
Normal file
52
tests/test_inheritance.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import unittest
|
||||||
|
import datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
from infi.clickhouse_orm.models import Model
|
||||||
|
from infi.clickhouse_orm.fields import *
|
||||||
|
from infi.clickhouse_orm.engines import *
|
||||||
|
|
||||||
|
|
||||||
|
class InheritanceTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def assertFieldNames(self, model_class, names):
|
||||||
|
self.assertEquals(names, [name for name, field in model_class._fields])
|
||||||
|
|
||||||
|
def test_field_inheritance(self):
|
||||||
|
self.assertFieldNames(ParentModel, ['date_field', 'int_field'])
|
||||||
|
self.assertFieldNames(Model1, ['date_field', 'int_field', 'string_field'])
|
||||||
|
self.assertFieldNames(Model2, ['date_field', 'int_field', 'float_field'])
|
||||||
|
|
||||||
|
def test_create_table_sql(self):
|
||||||
|
sql1 = ParentModel.create_table_sql('default')
|
||||||
|
sql2 = Model1.create_table_sql('default')
|
||||||
|
sql3 = Model2.create_table_sql('default')
|
||||||
|
self.assertNotEqual(sql1, sql2)
|
||||||
|
self.assertNotEqual(sql1, sql3)
|
||||||
|
self.assertNotEqual(sql2, sql3)
|
||||||
|
|
||||||
|
def test_get_field(self):
|
||||||
|
self.assertIsNotNone(ParentModel().get_field('date_field'))
|
||||||
|
self.assertIsNone(ParentModel().get_field('string_field'))
|
||||||
|
self.assertIsNotNone(Model1().get_field('date_field'))
|
||||||
|
self.assertIsNotNone(Model1().get_field('string_field'))
|
||||||
|
self.assertIsNone(Model1().get_field('float_field'))
|
||||||
|
|
||||||
|
|
||||||
|
class ParentModel(Model):
|
||||||
|
|
||||||
|
date_field = DateField()
|
||||||
|
int_field = Int32Field()
|
||||||
|
|
||||||
|
engine = MergeTree('date_field', ('int_field', 'date_field'))
|
||||||
|
|
||||||
|
|
||||||
|
class Model1(ParentModel):
|
||||||
|
|
||||||
|
string_field = StringField()
|
||||||
|
|
||||||
|
|
||||||
|
class Model2(ParentModel):
|
||||||
|
|
||||||
|
float_field = Float32Field()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user