diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py index 0ef6b57..49dcdaa 100644 --- a/src/infi/clickhouse_orm/database.py +++ b/src/infi/clickhouse_orm/database.py @@ -116,7 +116,7 @@ class Database(object): ''' Creates a table for the given model class, if it does not exist already. ''' - if model_class.system: + if model_class.is_system_model(): raise DatabaseException("You can't create system table") if getattr(model_class, 'engine') is None: raise DatabaseException("%s class must define an engine" % model_class.__name__) @@ -126,7 +126,7 @@ class Database(object): ''' Drops the database table of the given model class, if it exists. ''' - if model_class.system: + if model_class.is_system_model(): raise DatabaseException("You can't drop system table") self._send(model_class.drop_table_sql(self)) @@ -146,7 +146,7 @@ class Database(object): return # model_instances is empty model_class = first_instance.__class__ - if first_instance.readonly or first_instance.system: + if first_instance.is_read_only() or first_instance.is_system_model(): raise DatabaseException("You can't insert into read only and system tables") fields_list = ','.join( diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index f5e7377..c2ab40d 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -22,7 +22,6 @@ class ModelBase(type): ad_hoc_model_cache = {} def __new__(cls, name, bases, attrs): - new_cls = super(ModelBase, cls).__new__(cls, str(name), bases, attrs) # Collect fields from parent classes base_fields = dict() for base in bases: @@ -35,9 +34,12 @@ class ModelBase(type): fields.update({n: f for n, f in iteritems(attrs) if isinstance(f, Field)}) fields = sorted(iteritems(fields), key=lambda item: item[1].creation_counter) - setattr(new_cls, '_fields', OrderedDict(fields)) - setattr(new_cls, '_writable_fields', OrderedDict([f for f in fields if not f[1].readonly])) - return new_cls + attrs = dict( + attrs, + _fields=OrderedDict(fields), + _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), + ) + return super(ModelBase, cls).__new__(cls, str(name), bases, attrs) @classmethod def create_ad_hoc_model(cls, fields, model_name='AdHocModel'): @@ -99,10 +101,12 @@ class Model(with_metaclass(ModelBase)): engine = None # Insert operations are restricted for read only models - readonly = False + _readonly = False # Create table, drop table, insert operations are restricted for system models - system = False + _system = False + + _database = None def __init__(self, **kwargs): ''' @@ -198,11 +202,10 @@ class Model(with_metaclass(ModelBase)): return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name()) @classmethod - def from_tsv(cls, line, field_names=None, timezone_in_use=pytz.utc, database=None): + def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None): ''' Create a model instance from a tab-separated line. The line may or may not include a newline. The `field_names` list must match the fields defined in the model, but does not have to include all of them. - If omitted, it is assumed to be the names of all fields in the model, in order of definition. - `line`: the TSV-formatted data. - `field_names`: names of the model fields in the data. @@ -210,7 +213,6 @@ class Model(with_metaclass(ModelBase)): - `database`: if given, sets the database that this instance belongs to. ''' from six import next - field_names = field_names or list(cls.fields()) values = iter(parse_tsv(line)) kwargs = {} for name in field_names: @@ -265,6 +267,20 @@ class Model(with_metaclass(ModelBase)): # noinspection PyProtectedMember,PyUnresolvedReferences return cls._writable_fields if writable else cls._fields + @classmethod + def is_read_only(cls): + ''' + Returns true if the model is marked as read only. + ''' + return cls._readonly + + @classmethod + def is_system_model(cls): + ''' + Returns true if the model represents a system table. + ''' + return cls._system + class BufferModel(Model): diff --git a/tests/test_database.py b/tests/test_database.py index 77bf1f0..331bcef 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -164,3 +164,15 @@ class DatabaseTestCase(TestCaseWithData): self.database.create_table(EnginelessModel) self.assertEqual(cm.exception.message, 'EnginelessModel class must define an engine') + def test_potentially_problematic_field_names(self): + class Model1(Model): + system = StringField() + readonly = StringField() + engine = Memory() + instance = Model1(system='s', readonly='r') + self.assertEquals(instance.to_dict(), dict(system='s', readonly='r')) + self.database.create_table(Model1) + self.database.insert([instance]) + instance = Model1.objects_in(self.database)[0] + self.assertEquals(instance.to_dict(), dict(system='s', readonly='r')) + diff --git a/tests/test_readonly.py b/tests/test_readonly.py index 62a2bf5..73c7d26 100644 --- a/tests/test_readonly.py +++ b/tests/test_readonly.py @@ -48,6 +48,7 @@ class ReadonlyTestCase(TestCaseWithData): def test_insert_readonly(self): m = ReadOnlyModel(name='readonly') + self.database.create_table(ReadOnlyModel) with self.assertRaises(DatabaseException): self.database.insert([m]) @@ -59,7 +60,7 @@ class ReadonlyTestCase(TestCaseWithData): class ReadOnlyModel(Model): - readonly = True + _readonly = True name = StringField() date = DateField() diff --git a/tests/test_system_models.py b/tests/test_system_models.py index b49cc52..b9576ac 100644 --- a/tests/test_system_models.py +++ b/tests/test_system_models.py @@ -116,4 +116,4 @@ class CustomPartitionedTable(Model): class SystemTestModel(Model): - system = True + _system = True