diff --git a/src/clickhouse_orm/contrib/geo/fields.py b/src/clickhouse_orm/contrib/geo/fields.py index 1c6a3d1..9eecd0a 100644 --- a/src/clickhouse_orm/contrib/geo/fields.py +++ b/src/clickhouse_orm/contrib/geo/fields.py @@ -53,8 +53,9 @@ class PointField(Field): class_default = Point(0, 0) db_type = 'Point' - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): - super().__init__(default, alias, materialized, readonly, codec) + def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, + db_column=None): + super().__init__(default, alias, materialized, readonly, codec, db_column) self.inner_field = Float64Field() def to_python(self, value, timezone_in_use): diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py index c2c9194..964345a 100644 --- a/src/clickhouse_orm/fields.py +++ b/src/clickhouse_orm/fields.py @@ -26,7 +26,8 @@ class Field(FunctionOperatorsMixin): class_default = 0 # should be overridden by concrete subclasses db_type = None # should be overridden by concrete subclasses - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): + def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, + db_column=None): assert [default, alias, materialized].count(None) >= 2, \ "Only one of default, alias and materialized parameters can be given" assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \ @@ -38,6 +39,8 @@ class Field(FunctionOperatorsMixin): readonly) is bool, "readonly parameter must be bool if given" assert codec is None or isinstance(codec, str) and codec != "", \ "Codec field must be string, if given" + assert db_column is None or isinstance(db_column, str) and db_column != "", \ + "db_column field must be string, if given" self.creation_counter = Field.creation_counter Field.creation_counter += 1 @@ -46,6 +49,7 @@ class Field(FunctionOperatorsMixin): self.materialized = materialized self.readonly = bool(self.alias or self.materialized or readonly) self.codec = codec + self.db_column = db_column def __str__(self): return self.name @@ -151,10 +155,11 @@ class StringField(Field): class FixedStringField(StringField): - def __init__(self, length, default=None, alias=None, materialized=None, readonly=None): + def __init__(self, length, default=None, alias=None, materialized=None, readonly=None, + db_column=None): self._length = length self.db_type = 'FixedString(%d)' % length - super(FixedStringField, self).__init__(default, alias, materialized, readonly) + super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column) def to_python(self, value, timezone_in_use): value = super(FixedStringField, self).to_python(value, timezone_in_use) @@ -199,8 +204,8 @@ class DateTimeField(Field): db_type = 'DateTime' def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - timezone=None): - super().__init__(default, alias, materialized, readonly, codec) + db_column=None, timezone=None): + super().__init__(default, alias, materialized, readonly, codec, db_column) # assert not timezone, 'Temporarily field timezone is not supported' if timezone: timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone) @@ -248,8 +253,8 @@ class DateTime64Field(DateTimeField): db_type = 'DateTime64' def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - timezone=None, precision=6): - super().__init__(default, alias, materialized, readonly, codec, timezone) + db_column=None, timezone=None, precision=6): + super().__init__(default, alias, materialized, readonly, codec, timezone, db_column) assert precision is None or isinstance(precision, int), 'Precision must be int type' self.precision = precision @@ -361,9 +366,9 @@ class Int64Field(BaseIntField): class BaseFloatField(Field): - ''' + """ Abstract base class for all float-type fields. - ''' + """ def to_python(self, value, timezone_in_use): try: @@ -391,7 +396,7 @@ class DecimalField(Field): """ def __init__(self, precision, scale, default=None, alias=None, materialized=None, - readonly=None): + readonly=None, db_column=None): assert 1 <= precision <= 38, 'Precision must be between 1 and 38' assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' self.precision = precision @@ -402,7 +407,7 @@ class DecimalField(Field): self.exp = Decimal(10) ** -self.scale # for rounding to the required scale self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp self.min_value = -self.max_value - super(DecimalField, self).__init__(default, alias, materialized, readonly) + super(DecimalField, self).__init__(default, alias, materialized, readonly, db_column) def to_python(self, value, timezone_in_use): if not isinstance(value, Decimal): @@ -428,22 +433,25 @@ class DecimalField(Field): class Decimal32Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): - super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly) + def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None, + db_column=None): + super().__init__(9, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal32(%d)' % scale class Decimal64Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): - super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly) + def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None, + db_column=None): + super().__init__(18, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal64(%d)' % scale class Decimal128Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): - super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly) + def __init__(self, scale, default=None, alias=None, materialized=None, + readonly=None, db_column=None): + super().__init__(38, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal128(%d)' % scale @@ -453,11 +461,11 @@ class BaseEnumField(Field): """ def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, - codec=None): + codec=None, db_column=None): self.enum_cls = enum_cls if default is None: default = list(enum_cls)[0] - super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec) + super().__init__(default, alias, materialized, readonly, codec, db_column) def to_python(self, value, timezone_in_use): if isinstance(value, self.enum_cls): @@ -514,13 +522,13 @@ class ArrayField(Field): class_default = [] def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, - codec=None): + codec=None, db_column=None): assert isinstance(inner_field, Field), \ "The first argument of ArrayField must be a Field instance" assert not isinstance(inner_field, ArrayField), \ "Multidimensional array fields are not supported by the ORM" self.inner_field = inner_field - super(ArrayField, self).__init__(default, alias, materialized, readonly, codec) + super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column) def to_python(self, value, timezone_in_use): if isinstance(value, str): diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py index a9f7eb8..4a3d3ec 100644 --- a/src/clickhouse_orm/models.py +++ b/src/clickhouse_orm/models.py @@ -3,7 +3,7 @@ import sys from collections import OrderedDict from itertools import chain from logging import getLogger -from typing import TypeVar +from typing import TypeVar, Dict import pytz @@ -120,9 +120,9 @@ class Index: class ModelBase(type): - ''' + """ A metaclass for ORM models. It adds the _fields list to model classes. - ''' + """ ad_hoc_model_cache = {} @@ -141,7 +141,7 @@ class ModelBase(type): # Add fields, constraints and indexes from this class for n, obj in attrs.items(): if isinstance(obj, Field): - fields[n] = obj + fields[obj.db_column or n] = obj elif isinstance(obj, Constraint): constraints[n] = obj elif isinstance(obj, Index): @@ -201,6 +201,8 @@ class ModelBase(type): @classmethod def create_ad_hoc_field(cls, db_type): import clickhouse_orm.fields as orm_fields + import clickhouse_orm.contrib.geo.fields as geo_fields + # Enums if db_type.startswith('Enum'): return orm_fields.BaseEnumField.create_ad_hoc_field(db_type) @@ -229,7 +231,7 @@ class ModelBase(type): return orm_fields.ArrayField(inner_field) # FixedString if db_type.startswith('FixedString'): - length = int(db_type[12 : -1]) + length = int(db_type[12:-1]) return orm_fields.FixedStringField(length) # Decimal / Decimal32 / Decimal64 / Decimal128 if db_type.startswith('Decimal'): @@ -247,9 +249,9 @@ class ModelBase(type): return orm_fields.LowCardinalityField(inner_field) # Simple fields name = db_type + 'Field' - if not hasattr(orm_fields, name): + if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)): raise NotImplementedError('No field class for %s' % db_type) - return getattr(orm_fields, name)() + return getattr(orm_fields, name, getattr(geo_fields, name))() class Model(metaclass=ModelBase): @@ -273,6 +275,8 @@ class Model(metaclass=ModelBase): _database = None + _fields: Dict[str, Field] + def __init__(self, **kwargs): """ Creates a model instance, using keyword arguments as field values. diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py index 41b180a..a842d1e 100644 --- a/src/clickhouse_orm/query.py +++ b/src/clickhouse_orm/query.py @@ -43,7 +43,7 @@ class SimpleOperator(Operator): value = self._value_to_sql(field, value) if value == '\\N' and self._sql_for_null is not None: return ' '.join([field_name, self._sql_for_null]) - return ' '.join([field_name, self._sql_operator, value]) + return ' '.join([field.name, self._sql_operator, value]) class InOperator(Operator): @@ -63,7 +63,7 @@ class InOperator(Operator): pass else: value = comma_join([self._value_to_sql(field, v) for v in value]) - return '%s IN (%s)' % (field_name, value) + return '%s IN (%s)' % (field.name, value) class GlobalInOperator(Operator): @@ -77,7 +77,7 @@ class GlobalInOperator(Operator): pass else: value = comma_join([self._value_to_sql(field, v) for v in value]) - return '%s GLOBAL IN (%s)' % (field_name, value) + return '%s GLOBAL IN (%s)' % (field.name, value) class LikeOperator(Operator): @@ -96,9 +96,9 @@ class LikeOperator(Operator): value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') pattern = self._pattern.format(value) if self._case_sensitive: - return '%s LIKE \'%s\'' % (field_name, pattern) + return '%s LIKE \'%s\'' % (field.name, pattern) else: - return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern) + return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern) class IExactOperator(Operator): @@ -109,7 +109,7 @@ class IExactOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) - return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) + return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value) class NotOperator(Operator): @@ -142,11 +142,11 @@ class BetweenOperator(Operator): value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len( str(value[1])) > 0 else None if value0 and value1: - return '%s BETWEEN %s AND %s' % (field_name, value0, value1) + return '%s BETWEEN %s AND %s' % (field.name, value0, value1) if value0 and not value1: - return ' '.join([field_name, '>=', value0]) + return ' '.join([field.name, '>=', value0]) if value1 and not value0: - return ' '.join([field_name, '<=', value1]) + return ' '.join([field.name, '<=', value1]) # Define the set of builtin operators