Field add db_column attribute

This commit is contained in:
sswest 2022-05-28 16:01:50 +08:00
parent 42678f06b9
commit f1c9562260
4 changed files with 52 additions and 39 deletions

View File

@ -53,8 +53,9 @@ class PointField(Field):
class_default = Point(0, 0) class_default = Point(0, 0)
db_type = 'Point' db_type = 'Point'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
super().__init__(default, alias, materialized, readonly, codec) db_column=None):
super().__init__(default, alias, materialized, readonly, codec, db_column)
self.inner_field = Float64Field() self.inner_field = Float64Field()
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):

View File

@ -26,7 +26,8 @@ class Field(FunctionOperatorsMixin):
class_default = 0 # should be overridden by concrete subclasses class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None):
assert [default, alias, materialized].count(None) >= 2, \ assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given" "Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \ assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
@ -38,6 +39,8 @@ class Field(FunctionOperatorsMixin):
readonly) is bool, "readonly parameter must be bool if given" readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \ assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given" "Codec field must be string, if given"
assert db_column is None or isinstance(db_column, str) and db_column != "", \
"db_column field must be string, if given"
self.creation_counter = Field.creation_counter self.creation_counter = Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
@ -46,6 +49,7 @@ class Field(FunctionOperatorsMixin):
self.materialized = materialized self.materialized = materialized
self.readonly = bool(self.alias or self.materialized or readonly) self.readonly = bool(self.alias or self.materialized or readonly)
self.codec = codec self.codec = codec
self.db_column = db_column
def __str__(self): def __str__(self):
return self.name return self.name
@ -151,10 +155,11 @@ class StringField(Field):
class FixedStringField(StringField): class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None): def __init__(self, length, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
self._length = length self._length = length
self.db_type = 'FixedString(%d)' % length self.db_type = 'FixedString(%d)' % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly) super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
value = super(FixedStringField, self).to_python(value, timezone_in_use) value = super(FixedStringField, self).to_python(value, timezone_in_use)
@ -199,8 +204,8 @@ class DateTimeField(Field):
db_type = 'DateTime' db_type = 'DateTime'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None): db_column=None, timezone=None):
super().__init__(default, alias, materialized, readonly, codec) super().__init__(default, alias, materialized, readonly, codec, db_column)
# assert not timezone, 'Temporarily field timezone is not supported' # assert not timezone, 'Temporarily field timezone is not supported'
if timezone: if timezone:
timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone) timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone)
@ -248,8 +253,8 @@ class DateTime64Field(DateTimeField):
db_type = 'DateTime64' db_type = 'DateTime64'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None, precision=6): db_column=None, timezone=None, precision=6):
super().__init__(default, alias, materialized, readonly, codec, timezone) super().__init__(default, alias, materialized, readonly, codec, timezone, db_column)
assert precision is None or isinstance(precision, int), 'Precision must be int type' assert precision is None or isinstance(precision, int), 'Precision must be int type'
self.precision = precision self.precision = precision
@ -361,9 +366,9 @@ class Int64Field(BaseIntField):
class BaseFloatField(Field): class BaseFloatField(Field):
''' """
Abstract base class for all float-type fields. Abstract base class for all float-type fields.
''' """
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
try: try:
@ -391,7 +396,7 @@ class DecimalField(Field):
""" """
def __init__(self, precision, scale, default=None, alias=None, materialized=None, def __init__(self, precision, scale, default=None, alias=None, materialized=None,
readonly=None): readonly=None, db_column=None):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38' assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
self.precision = precision self.precision = precision
@ -402,7 +407,7 @@ class DecimalField(Field):
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
self.min_value = -self.max_value self.min_value = -self.max_value
super(DecimalField, self).__init__(default, alias, materialized, readonly) super(DecimalField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if not isinstance(value, Decimal): if not isinstance(value, Decimal):
@ -428,22 +433,25 @@ class DecimalField(Field):
class Decimal32Field(DecimalField): class Decimal32Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly) db_column=None):
super().__init__(9, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal32(%d)' % scale self.db_type = 'Decimal32(%d)' % scale
class Decimal64Field(DecimalField): class Decimal64Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly) db_column=None):
super().__init__(18, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal64(%d)' % scale self.db_type = 'Decimal64(%d)' % scale
class Decimal128Field(DecimalField): class Decimal128Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): def __init__(self, scale, default=None, alias=None, materialized=None,
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly) readonly=None, db_column=None):
super().__init__(38, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal128(%d)' % scale self.db_type = 'Decimal128(%d)' % scale
@ -453,11 +461,11 @@ class BaseEnumField(Field):
""" """
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None,
codec=None): codec=None, db_column=None):
self.enum_cls = enum_cls self.enum_cls = enum_cls
if default is None: if default is None:
default = list(enum_cls)[0] default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec) super().__init__(default, alias, materialized, readonly, codec, db_column)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if isinstance(value, self.enum_cls): if isinstance(value, self.enum_cls):
@ -514,13 +522,13 @@ class ArrayField(Field):
class_default = [] class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None): codec=None, db_column=None):
assert isinstance(inner_field, Field), \ assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance" "The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \ assert not isinstance(inner_field, ArrayField), \
"Multidimensional array fields are not supported by the ORM" "Multidimensional array fields are not supported by the ORM"
self.inner_field = inner_field self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec) super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if isinstance(value, str): if isinstance(value, str):

View File

@ -3,7 +3,7 @@ import sys
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
from logging import getLogger from logging import getLogger
from typing import TypeVar from typing import TypeVar, Dict
import pytz import pytz
@ -120,9 +120,9 @@ class Index:
class ModelBase(type): class ModelBase(type):
''' """
A metaclass for ORM models. It adds the _fields list to model classes. A metaclass for ORM models. It adds the _fields list to model classes.
''' """
ad_hoc_model_cache = {} ad_hoc_model_cache = {}
@ -141,7 +141,7 @@ class ModelBase(type):
# Add fields, constraints and indexes from this class # Add fields, constraints and indexes from this class
for n, obj in attrs.items(): for n, obj in attrs.items():
if isinstance(obj, Field): if isinstance(obj, Field):
fields[n] = obj fields[obj.db_column or n] = obj
elif isinstance(obj, Constraint): elif isinstance(obj, Constraint):
constraints[n] = obj constraints[n] = obj
elif isinstance(obj, Index): elif isinstance(obj, Index):
@ -201,6 +201,8 @@ class ModelBase(type):
@classmethod @classmethod
def create_ad_hoc_field(cls, db_type): def create_ad_hoc_field(cls, db_type):
import clickhouse_orm.fields as orm_fields import clickhouse_orm.fields as orm_fields
import clickhouse_orm.contrib.geo.fields as geo_fields
# Enums # Enums
if db_type.startswith('Enum'): if db_type.startswith('Enum'):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type) return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
@ -229,7 +231,7 @@ class ModelBase(type):
return orm_fields.ArrayField(inner_field) return orm_fields.ArrayField(inner_field)
# FixedString # FixedString
if db_type.startswith('FixedString'): if db_type.startswith('FixedString'):
length = int(db_type[12 : -1]) length = int(db_type[12:-1])
return orm_fields.FixedStringField(length) return orm_fields.FixedStringField(length)
# Decimal / Decimal32 / Decimal64 / Decimal128 # Decimal / Decimal32 / Decimal64 / Decimal128
if db_type.startswith('Decimal'): if db_type.startswith('Decimal'):
@ -247,9 +249,9 @@ class ModelBase(type):
return orm_fields.LowCardinalityField(inner_field) return orm_fields.LowCardinalityField(inner_field)
# Simple fields # Simple fields
name = db_type + 'Field' name = db_type + 'Field'
if not hasattr(orm_fields, name): if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)):
raise NotImplementedError('No field class for %s' % db_type) raise NotImplementedError('No field class for %s' % db_type)
return getattr(orm_fields, name)() return getattr(orm_fields, name, getattr(geo_fields, name))()
class Model(metaclass=ModelBase): class Model(metaclass=ModelBase):
@ -273,6 +275,8 @@ class Model(metaclass=ModelBase):
_database = None _database = None
_fields: Dict[str, Field]
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Creates a model instance, using keyword arguments as field values. Creates a model instance, using keyword arguments as field values.

View File

@ -43,7 +43,7 @@ class SimpleOperator(Operator):
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None: if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null]) return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value]) return ' '.join([field.name, self._sql_operator, value])
class InOperator(Operator): class InOperator(Operator):
@ -63,7 +63,7 @@ class InOperator(Operator):
pass pass
else: else:
value = comma_join([self._value_to_sql(field, v) for v in value]) value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value) return '%s IN (%s)' % (field.name, value)
class GlobalInOperator(Operator): class GlobalInOperator(Operator):
@ -77,7 +77,7 @@ class GlobalInOperator(Operator):
pass pass
else: else:
value = comma_join([self._value_to_sql(field, v) for v in value]) value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s GLOBAL IN (%s)' % (field_name, value) return '%s GLOBAL IN (%s)' % (field.name, value)
class LikeOperator(Operator): class LikeOperator(Operator):
@ -96,9 +96,9 @@ class LikeOperator(Operator):
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
pattern = self._pattern.format(value) pattern = self._pattern.format(value)
if self._case_sensitive: if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern) return '%s LIKE \'%s\'' % (field.name, pattern)
else: else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern) return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
class IExactOperator(Operator): class IExactOperator(Operator):
@ -109,7 +109,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
class NotOperator(Operator): class NotOperator(Operator):
@ -142,11 +142,11 @@ class BetweenOperator(Operator):
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len( value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
str(value[1])) > 0 else None str(value[1])) > 0 else None
if value0 and value1: if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1) return '%s BETWEEN %s AND %s' % (field.name, value0, value1)
if value0 and not value1: if value0 and not value1:
return ' '.join([field_name, '>=', value0]) return ' '.join([field.name, '>=', value0])
if value1 and not value0: if value1 and not value0:
return ' '.join([field_name, '<=', value1]) return ' '.join([field.name, '<=', value1])
# Define the set of builtin operators # Define the set of builtin operators