Field add db_column attribute

This commit is contained in:
sswest 2022-05-28 16:01:50 +08:00
parent 42678f06b9
commit f1c9562260
4 changed files with 52 additions and 39 deletions

View File

@ -53,8 +53,9 @@ class PointField(Field):
class_default = Point(0, 0)
db_type = 'Point'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
super().__init__(default, alias, materialized, readonly, codec)
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None):
super().__init__(default, alias, materialized, readonly, codec, db_column)
self.inner_field = Float64Field()
def to_python(self, value, timezone_in_use):

View File

@ -26,7 +26,8 @@ class Field(FunctionOperatorsMixin):
class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None):
assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
@ -38,6 +39,8 @@ class Field(FunctionOperatorsMixin):
readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given"
assert db_column is None or isinstance(db_column, str) and db_column != "", \
"db_column field must be string, if given"
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
@ -46,6 +49,7 @@ class Field(FunctionOperatorsMixin):
self.materialized = materialized
self.readonly = bool(self.alias or self.materialized or readonly)
self.codec = codec
self.db_column = db_column
def __str__(self):
return self.name
@ -151,10 +155,11 @@ class StringField(Field):
class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
self._length = length
self.db_type = 'FixedString(%d)' % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use):
value = super(FixedStringField, self).to_python(value, timezone_in_use)
@ -199,8 +204,8 @@ class DateTimeField(Field):
db_type = 'DateTime'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None):
super().__init__(default, alias, materialized, readonly, codec)
db_column=None, timezone=None):
super().__init__(default, alias, materialized, readonly, codec, db_column)
# assert not timezone, 'Temporarily field timezone is not supported'
if timezone:
timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone)
@ -248,8 +253,8 @@ class DateTime64Field(DateTimeField):
db_type = 'DateTime64'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None, precision=6):
super().__init__(default, alias, materialized, readonly, codec, timezone)
db_column=None, timezone=None, precision=6):
super().__init__(default, alias, materialized, readonly, codec, timezone, db_column)
assert precision is None or isinstance(precision, int), 'Precision must be int type'
self.precision = precision
@ -361,9 +366,9 @@ class Int64Field(BaseIntField):
class BaseFloatField(Field):
'''
"""
Abstract base class for all float-type fields.
'''
"""
def to_python(self, value, timezone_in_use):
try:
@ -391,7 +396,7 @@ class DecimalField(Field):
"""
def __init__(self, precision, scale, default=None, alias=None, materialized=None,
readonly=None):
readonly=None, db_column=None):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
self.precision = precision
@ -402,7 +407,7 @@ class DecimalField(Field):
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
self.min_value = -self.max_value
super(DecimalField, self).__init__(default, alias, materialized, readonly)
super(DecimalField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use):
if not isinstance(value, Decimal):
@ -428,22 +433,25 @@ class DecimalField(Field):
class Decimal32Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
super().__init__(9, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal32(%d)' % scale
class Decimal64Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
super().__init__(18, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal64(%d)' % scale
class Decimal128Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
def __init__(self, scale, default=None, alias=None, materialized=None,
readonly=None, db_column=None):
super().__init__(38, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal128(%d)' % scale
@ -453,11 +461,11 @@ class BaseEnumField(Field):
"""
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None,
codec=None):
codec=None, db_column=None):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec)
super().__init__(default, alias, materialized, readonly, codec, db_column)
def to_python(self, value, timezone_in_use):
if isinstance(value, self.enum_cls):
@ -514,13 +522,13 @@ class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None):
codec=None, db_column=None):
assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \
"Multidimensional array fields are not supported by the ORM"
self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column)
def to_python(self, value, timezone_in_use):
if isinstance(value, str):

View File

@ -3,7 +3,7 @@ import sys
from collections import OrderedDict
from itertools import chain
from logging import getLogger
from typing import TypeVar
from typing import TypeVar, Dict
import pytz
@ -120,9 +120,9 @@ class Index:
class ModelBase(type):
'''
"""
A metaclass for ORM models. It adds the _fields list to model classes.
'''
"""
ad_hoc_model_cache = {}
@ -141,7 +141,7 @@ class ModelBase(type):
# Add fields, constraints and indexes from this class
for n, obj in attrs.items():
if isinstance(obj, Field):
fields[n] = obj
fields[obj.db_column or n] = obj
elif isinstance(obj, Constraint):
constraints[n] = obj
elif isinstance(obj, Index):
@ -201,6 +201,8 @@ class ModelBase(type):
@classmethod
def create_ad_hoc_field(cls, db_type):
import clickhouse_orm.fields as orm_fields
import clickhouse_orm.contrib.geo.fields as geo_fields
# Enums
if db_type.startswith('Enum'):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
@ -247,9 +249,9 @@ class ModelBase(type):
return orm_fields.LowCardinalityField(inner_field)
# Simple fields
name = db_type + 'Field'
if not hasattr(orm_fields, name):
if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)):
raise NotImplementedError('No field class for %s' % db_type)
return getattr(orm_fields, name)()
return getattr(orm_fields, name, getattr(geo_fields, name))()
class Model(metaclass=ModelBase):
@ -273,6 +275,8 @@ class Model(metaclass=ModelBase):
_database = None
_fields: Dict[str, Field]
def __init__(self, **kwargs):
"""
Creates a model instance, using keyword arguments as field values.

View File

@ -43,7 +43,7 @@ class SimpleOperator(Operator):
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value])
return ' '.join([field.name, self._sql_operator, value])
class InOperator(Operator):
@ -63,7 +63,7 @@ class InOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value)
return '%s IN (%s)' % (field.name, value)
class GlobalInOperator(Operator):
@ -77,7 +77,7 @@ class GlobalInOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s GLOBAL IN (%s)' % (field_name, value)
return '%s GLOBAL IN (%s)' % (field.name, value)
class LikeOperator(Operator):
@ -96,9 +96,9 @@ class LikeOperator(Operator):
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
pattern = self._pattern.format(value)
if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern)
return '%s LIKE \'%s\'' % (field.name, pattern)
else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern)
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
class IExactOperator(Operator):
@ -109,7 +109,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
class NotOperator(Operator):
@ -142,11 +142,11 @@ class BetweenOperator(Operator):
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
str(value[1])) > 0 else None
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
return '%s BETWEEN %s AND %s' % (field.name, value0, value1)
if value0 and not value1:
return ' '.join([field_name, '>=', value0])
return ' '.join([field.name, '>=', value0])
if value1 and not value0:
return ' '.join([field_name, '<=', value1])
return ' '.join([field.name, '<=', value1])
# Define the set of builtin operators