mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-03 11:40:20 +03:00
Chore: fix linting on fields.py
This commit is contained in:
parent
ce68a8f55b
commit
aab92d88aa
|
@ -1,39 +1,45 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import iso8601
|
|
||||||
import pytz
|
|
||||||
from calendar import timegm
|
from calendar import timegm
|
||||||
from decimal import Decimal, localcontext
|
from decimal import Decimal, localcontext
|
||||||
from uuid import UUID
|
|
||||||
from logging import getLogger
|
|
||||||
from pytz import BaseTzInfo
|
|
||||||
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
|
|
||||||
from .funcs import F, FunctionOperatorsMixin
|
|
||||||
from ipaddress import IPv4Address, IPv6Address
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
|
from logging import getLogger
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
logger = getLogger('clickhouse_orm')
|
import iso8601
|
||||||
|
import pytz
|
||||||
|
from pytz import BaseTzInfo
|
||||||
|
|
||||||
|
from .funcs import F, FunctionOperatorsMixin
|
||||||
|
from .utils import comma_join, escape, get_subclass_names, parse_array, string_or_func
|
||||||
|
|
||||||
|
logger = getLogger("clickhouse_orm")
|
||||||
|
|
||||||
|
|
||||||
class Field(FunctionOperatorsMixin):
|
class Field(FunctionOperatorsMixin):
|
||||||
'''
|
"""
|
||||||
Abstract base class for all field types.
|
Abstract base class for all field types.
|
||||||
'''
|
"""
|
||||||
name = None # this is set by the parent model
|
|
||||||
parent = None # this is set by the parent model
|
name = None # this is set by the parent model
|
||||||
creation_counter = 0 # used for keeping the model fields ordered
|
parent = None # this is set by the parent model
|
||||||
class_default = 0 # should be overridden by concrete subclasses
|
creation_counter = 0 # used for keeping the model fields ordered
|
||||||
db_type = None # should be overridden by concrete subclasses
|
class_default = 0 # should be overridden by concrete subclasses
|
||||||
|
db_type = None # should be overridden by concrete subclasses
|
||||||
|
|
||||||
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
||||||
assert [default, alias, materialized].count(None) >= 2, \
|
assert [default, alias, materialized].count(
|
||||||
"Only one of default, alias and materialized parameters can be given"
|
None
|
||||||
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\
|
) >= 2, "Only one of default, alias and materialized parameters can be given"
|
||||||
"Alias parameter must be a string or function object, if given"
|
assert (
|
||||||
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\
|
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
|
||||||
"Materialized parameter must be a string or function object, if given"
|
), "Alias parameter must be a string or function object, if given"
|
||||||
|
assert (
|
||||||
|
materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
|
||||||
|
), "Materialized parameter must be a string or function object, if given"
|
||||||
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
|
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
|
||||||
assert codec is None or isinstance(codec, str) and codec != "", \
|
assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
|
||||||
"Codec field must be string, if given"
|
|
||||||
|
|
||||||
self.creation_counter = Field.creation_counter
|
self.creation_counter = Field.creation_counter
|
||||||
Field.creation_counter += 1
|
Field.creation_counter += 1
|
||||||
|
@ -47,49 +53,51 @@ class Field(FunctionOperatorsMixin):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s>' % self.__class__.__name__
|
return "<%s>" % self.__class__.__name__
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
'''
|
"""
|
||||||
Converts the input value into the expected Python data type, raising ValueError if the
|
Converts the input value into the expected Python data type, raising ValueError if the
|
||||||
data can't be converted. Returns the converted value. Subclasses should override this.
|
data can't be converted. Returns the converted value. Subclasses should override this.
|
||||||
The timezone_in_use parameter should be consulted when parsing datetime fields.
|
The timezone_in_use parameter should be consulted when parsing datetime fields.
|
||||||
'''
|
"""
|
||||||
return value # pragma: no cover
|
return value # pragma: no cover
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
'''
|
"""
|
||||||
Called after to_python to validate that the value is suitable for the field's database type.
|
Called after to_python to validate that the value is suitable for the field's database type.
|
||||||
Subclasses should override this.
|
Subclasses should override this.
|
||||||
'''
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _range_check(self, value, min_value, max_value):
|
def _range_check(self, value, min_value, max_value):
|
||||||
'''
|
"""
|
||||||
Utility method to check that the given value is between min_value and max_value.
|
Utility method to check that the given value is between min_value and max_value.
|
||||||
'''
|
"""
|
||||||
if value < min_value or value > max_value:
|
if value < min_value or value > max_value:
|
||||||
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
|
raise ValueError(
|
||||||
|
"%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
|
||||||
|
)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
'''
|
"""
|
||||||
Returns the field's value prepared for writing to the database.
|
Returns the field's value prepared for writing to the database.
|
||||||
When quote is true, strings are surrounded by single quotes.
|
When quote is true, strings are surrounded by single quotes.
|
||||||
'''
|
"""
|
||||||
return escape(value, quote)
|
return escape(value, quote)
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
'''
|
"""
|
||||||
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
|
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
|
||||||
|
|
||||||
- `with_default_expression`: If True, adds default value to sql.
|
- `with_default_expression`: If True, adds default value to sql.
|
||||||
It doesn't affect fields with alias and materialized values.
|
It doesn't affect fields with alias and materialized values.
|
||||||
- `db`: Database, used for checking supported features.
|
- `db`: Database, used for checking supported features.
|
||||||
'''
|
"""
|
||||||
sql = self.db_type
|
sql = self.db_type
|
||||||
args = self.get_db_type_args()
|
args = self.get_db_type_args()
|
||||||
if args:
|
if args:
|
||||||
sql += '(%s)' % comma_join(args)
|
sql += "(%s)" % comma_join(args)
|
||||||
if with_default_expression:
|
if with_default_expression:
|
||||||
sql += self._extra_params(db)
|
sql += self._extra_params(db)
|
||||||
return sql
|
return sql
|
||||||
|
@ -99,18 +107,18 @@ class Field(FunctionOperatorsMixin):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _extra_params(self, db):
|
def _extra_params(self, db):
|
||||||
sql = ''
|
sql = ""
|
||||||
if self.alias:
|
if self.alias:
|
||||||
sql += ' ALIAS %s' % string_or_func(self.alias)
|
sql += " ALIAS %s" % string_or_func(self.alias)
|
||||||
elif self.materialized:
|
elif self.materialized:
|
||||||
sql += ' MATERIALIZED %s' % string_or_func(self.materialized)
|
sql += " MATERIALIZED %s" % string_or_func(self.materialized)
|
||||||
elif isinstance(self.default, F):
|
elif isinstance(self.default, F):
|
||||||
sql += ' DEFAULT %s' % self.default.to_sql()
|
sql += " DEFAULT %s" % self.default.to_sql()
|
||||||
elif self.default:
|
elif self.default:
|
||||||
default = self.to_db_string(self.default)
|
default = self.to_db_string(self.default)
|
||||||
sql += ' DEFAULT %s' % default
|
sql += " DEFAULT %s" % default
|
||||||
if self.codec and db and db.has_codec_support:
|
if self.codec and db and db.has_codec_support:
|
||||||
sql += ' CODEC(%s)' % self.codec
|
sql += " CODEC(%s)" % self.codec
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
def isinstance(self, types):
|
def isinstance(self, types):
|
||||||
|
@ -124,43 +132,42 @@ class Field(FunctionOperatorsMixin):
|
||||||
"""
|
"""
|
||||||
if isinstance(self, types):
|
if isinstance(self, types):
|
||||||
return True
|
return True
|
||||||
inner_field = getattr(self, 'inner_field', None)
|
inner_field = getattr(self, "inner_field", None)
|
||||||
while inner_field:
|
while inner_field:
|
||||||
if isinstance(inner_field, types):
|
if isinstance(inner_field, types):
|
||||||
return True
|
return True
|
||||||
inner_field = getattr(inner_field, 'inner_field', None)
|
inner_field = getattr(inner_field, "inner_field", None)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class StringField(Field):
|
class StringField(Field):
|
||||||
|
|
||||||
class_default = ''
|
class_default = ""
|
||||||
db_type = 'String'
|
db_type = "String"
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
return value.decode('UTF-8')
|
return value.decode("UTF-8")
|
||||||
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
|
|
||||||
class FixedStringField(StringField):
|
class FixedStringField(StringField):
|
||||||
|
|
||||||
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
|
||||||
self._length = length
|
self._length = length
|
||||||
self.db_type = 'FixedString(%d)' % length
|
self.db_type = "FixedString(%d)" % length
|
||||||
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
|
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
value = super(FixedStringField, self).to_python(value, timezone_in_use)
|
value = super(FixedStringField, self).to_python(value, timezone_in_use)
|
||||||
return value.rstrip('\0')
|
return value.rstrip("\0")
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = value.encode('UTF-8')
|
value = value.encode("UTF-8")
|
||||||
if len(value) > self._length:
|
if len(value) > self._length:
|
||||||
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length))
|
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
|
||||||
|
|
||||||
|
|
||||||
class DateField(Field):
|
class DateField(Field):
|
||||||
|
@ -168,7 +175,7 @@ class DateField(Field):
|
||||||
min_value = datetime.date(1970, 1, 1)
|
min_value = datetime.date(1970, 1, 1)
|
||||||
max_value = datetime.date(2105, 12, 31)
|
max_value = datetime.date(2105, 12, 31)
|
||||||
class_default = min_value
|
class_default = min_value
|
||||||
db_type = 'Date'
|
db_type = "Date"
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
|
@ -178,10 +185,10 @@ class DateField(Field):
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return DateField.class_default + datetime.timedelta(days=value)
|
return DateField.class_default + datetime.timedelta(days=value)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if value == '0000-00-00':
|
if value == "0000-00-00":
|
||||||
return DateField.min_value
|
return DateField.min_value
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
self._range_check(value, DateField.min_value, DateField.max_value)
|
self._range_check(value, DateField.min_value, DateField.max_value)
|
||||||
|
@ -193,10 +200,9 @@ class DateField(Field):
|
||||||
class DateTimeField(Field):
|
class DateTimeField(Field):
|
||||||
|
|
||||||
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
||||||
db_type = 'DateTime'
|
db_type = "DateTime"
|
||||||
|
|
||||||
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
|
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
|
||||||
timezone=None):
|
|
||||||
super().__init__(default, alias, materialized, readonly, codec)
|
super().__init__(default, alias, materialized, readonly, codec)
|
||||||
# assert not timezone, 'Temporarily field timezone is not supported'
|
# assert not timezone, 'Temporarily field timezone is not supported'
|
||||||
if timezone:
|
if timezone:
|
||||||
|
@ -217,7 +223,7 @@ class DateTimeField(Field):
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if value == '0000-00-00 00:00:00':
|
if value == "0000-00-00 00:00:00":
|
||||||
return self.class_default
|
return self.class_default
|
||||||
if len(value) == 10:
|
if len(value) == 10:
|
||||||
try:
|
try:
|
||||||
|
@ -235,19 +241,20 @@ class DateTimeField(Field):
|
||||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||||
dt = timezone_in_use.localize(dt)
|
dt = timezone_in_use.localize(dt)
|
||||||
return dt
|
return dt
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape('%010d' % timegm(value.utctimetuple()), quote)
|
return escape("%010d" % timegm(value.utctimetuple()), quote)
|
||||||
|
|
||||||
|
|
||||||
class DateTime64Field(DateTimeField):
|
class DateTime64Field(DateTimeField):
|
||||||
db_type = 'DateTime64'
|
db_type = "DateTime64"
|
||||||
|
|
||||||
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
|
def __init__(
|
||||||
timezone=None, precision=6):
|
self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
|
||||||
|
):
|
||||||
super().__init__(default, alias, materialized, readonly, codec, timezone)
|
super().__init__(default, alias, materialized, readonly, codec, timezone)
|
||||||
assert precision is None or isinstance(precision, int), 'Precision must be int type'
|
assert precision is None or isinstance(precision, int), "Precision must be int type"
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
|
|
||||||
def get_db_type_args(self):
|
def get_db_type_args(self):
|
||||||
|
@ -263,11 +270,10 @@ class DateTime64Field(DateTimeField):
|
||||||
Returns string in 0000000000.000000 format, where remainder digits count is equal to precision
|
Returns string in 0000000000.000000 format, where remainder digits count is equal to precision
|
||||||
"""
|
"""
|
||||||
return escape(
|
return escape(
|
||||||
'{timestamp:0{width}.{precision}f}'.format(
|
"{timestamp:0{width}.{precision}f}".format(
|
||||||
timestamp=value.timestamp(),
|
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
|
||||||
width=11 + self.precision,
|
),
|
||||||
precision=self.precision),
|
quote,
|
||||||
quote
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
|
@ -277,8 +283,8 @@ class DateTime64Field(DateTimeField):
|
||||||
if isinstance(value, (int, float)):
|
if isinstance(value, (int, float)):
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
left_part = value.split('.')[0]
|
left_part = value.split(".")[0]
|
||||||
if left_part == '0000-00-00 00:00:00':
|
if left_part == "0000-00-00 00:00:00":
|
||||||
return self.class_default
|
return self.class_default
|
||||||
if len(left_part) == 10:
|
if len(left_part) == 10:
|
||||||
try:
|
try:
|
||||||
|
@ -290,14 +296,15 @@ class DateTime64Field(DateTimeField):
|
||||||
|
|
||||||
|
|
||||||
class BaseIntField(Field):
|
class BaseIntField(Field):
|
||||||
'''
|
"""
|
||||||
Abstract base class for all integer-type fields.
|
Abstract base class for all integer-type fields.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
# There's no need to call escape since numbers do not contain
|
# There's no need to call escape since numbers do not contain
|
||||||
|
@ -311,69 +318,69 @@ class BaseIntField(Field):
|
||||||
class UInt8Field(BaseIntField):
|
class UInt8Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2**8 - 1
|
max_value = 2 ** 8 - 1
|
||||||
db_type = 'UInt8'
|
db_type = "UInt8"
|
||||||
|
|
||||||
|
|
||||||
class UInt16Field(BaseIntField):
|
class UInt16Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2**16 - 1
|
max_value = 2 ** 16 - 1
|
||||||
db_type = 'UInt16'
|
db_type = "UInt16"
|
||||||
|
|
||||||
|
|
||||||
class UInt32Field(BaseIntField):
|
class UInt32Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2**32 - 1
|
max_value = 2 ** 32 - 1
|
||||||
db_type = 'UInt32'
|
db_type = "UInt32"
|
||||||
|
|
||||||
|
|
||||||
class UInt64Field(BaseIntField):
|
class UInt64Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2**64 - 1
|
max_value = 2 ** 64 - 1
|
||||||
db_type = 'UInt64'
|
db_type = "UInt64"
|
||||||
|
|
||||||
|
|
||||||
class Int8Field(BaseIntField):
|
class Int8Field(BaseIntField):
|
||||||
|
|
||||||
min_value = -2**7
|
min_value = -(2 ** 7)
|
||||||
max_value = 2**7 - 1
|
max_value = 2 ** 7 - 1
|
||||||
db_type = 'Int8'
|
db_type = "Int8"
|
||||||
|
|
||||||
|
|
||||||
class Int16Field(BaseIntField):
|
class Int16Field(BaseIntField):
|
||||||
|
|
||||||
min_value = -2**15
|
min_value = -(2 ** 15)
|
||||||
max_value = 2**15 - 1
|
max_value = 2 ** 15 - 1
|
||||||
db_type = 'Int16'
|
db_type = "Int16"
|
||||||
|
|
||||||
|
|
||||||
class Int32Field(BaseIntField):
|
class Int32Field(BaseIntField):
|
||||||
|
|
||||||
min_value = -2**31
|
min_value = -(2 ** 31)
|
||||||
max_value = 2**31 - 1
|
max_value = 2 ** 31 - 1
|
||||||
db_type = 'Int32'
|
db_type = "Int32"
|
||||||
|
|
||||||
|
|
||||||
class Int64Field(BaseIntField):
|
class Int64Field(BaseIntField):
|
||||||
|
|
||||||
min_value = -2**63
|
min_value = -(2 ** 63)
|
||||||
max_value = 2**63 - 1
|
max_value = 2 ** 63 - 1
|
||||||
db_type = 'Int64'
|
db_type = "Int64"
|
||||||
|
|
||||||
|
|
||||||
class BaseFloatField(Field):
|
class BaseFloatField(Field):
|
||||||
'''
|
"""
|
||||||
Abstract base class for all float-type fields.
|
Abstract base class for all float-type fields.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
try:
|
try:
|
||||||
return float(value)
|
return float(value)
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
# There's no need to call escape since numbers do not contain
|
# There's no need to call escape since numbers do not contain
|
||||||
|
@ -383,28 +390,28 @@ class BaseFloatField(Field):
|
||||||
|
|
||||||
class Float32Field(BaseFloatField):
|
class Float32Field(BaseFloatField):
|
||||||
|
|
||||||
db_type = 'Float32'
|
db_type = "Float32"
|
||||||
|
|
||||||
|
|
||||||
class Float64Field(BaseFloatField):
|
class Float64Field(BaseFloatField):
|
||||||
|
|
||||||
db_type = 'Float64'
|
db_type = "Float64"
|
||||||
|
|
||||||
|
|
||||||
class DecimalField(Field):
|
class DecimalField(Field):
|
||||||
'''
|
"""
|
||||||
Base class for all decimal fields. Can also be used directly.
|
Base class for all decimal fields. Can also be used directly.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
|
||||||
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
|
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
|
||||||
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
|
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale)
|
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
|
||||||
with localcontext() as ctx:
|
with localcontext() as ctx:
|
||||||
ctx.prec = 38
|
ctx.prec = 38
|
||||||
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
|
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
|
||||||
self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
|
self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
|
||||||
self.min_value = -self.max_value
|
self.min_value = -self.max_value
|
||||||
super(DecimalField, self).__init__(default, alias, materialized, readonly)
|
super(DecimalField, self).__init__(default, alias, materialized, readonly)
|
||||||
|
@ -413,10 +420,10 @@ class DecimalField(Field):
|
||||||
if not isinstance(value, Decimal):
|
if not isinstance(value, Decimal):
|
||||||
try:
|
try:
|
||||||
value = Decimal(value)
|
value = Decimal(value)
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||||
if not value.is_finite():
|
if not value.is_finite():
|
||||||
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
|
||||||
return self._round(value)
|
return self._round(value)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
|
@ -432,30 +439,27 @@ class DecimalField(Field):
|
||||||
|
|
||||||
|
|
||||||
class Decimal32Field(DecimalField):
|
class Decimal32Field(DecimalField):
|
||||||
|
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
||||||
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
|
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
|
||||||
self.db_type = 'Decimal32(%d)' % scale
|
self.db_type = "Decimal32(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
class Decimal64Field(DecimalField):
|
class Decimal64Field(DecimalField):
|
||||||
|
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
||||||
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
|
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
|
||||||
self.db_type = 'Decimal64(%d)' % scale
|
self.db_type = "Decimal64(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
class Decimal128Field(DecimalField):
|
class Decimal128Field(DecimalField):
|
||||||
|
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
||||||
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
|
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
|
||||||
self.db_type = 'Decimal128(%d)' % scale
|
self.db_type = "Decimal128(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
class BaseEnumField(Field):
|
class BaseEnumField(Field):
|
||||||
'''
|
"""
|
||||||
Abstract base class for all enum-type fields.
|
Abstract base class for all enum-type fields.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
||||||
self.enum_cls = enum_cls
|
self.enum_cls = enum_cls
|
||||||
|
@ -473,7 +477,7 @@ class BaseEnumField(Field):
|
||||||
except Exception:
|
except Exception:
|
||||||
return self.enum_cls(value)
|
return self.enum_cls(value)
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
decoded = value.decode('UTF-8')
|
decoded = value.decode("UTF-8")
|
||||||
try:
|
try:
|
||||||
return self.enum_cls[decoded]
|
return self.enum_cls[decoded]
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -482,38 +486,39 @@ class BaseEnumField(Field):
|
||||||
return self.enum_cls(value)
|
return self.enum_cls(value)
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
pass
|
pass
|
||||||
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
|
raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape(value.name, quote)
|
return escape(value.name, quote)
|
||||||
|
|
||||||
def get_db_type_args(self):
|
def get_db_type_args(self):
|
||||||
return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
|
return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_ad_hoc_field(cls, db_type):
|
def create_ad_hoc_field(cls, db_type):
|
||||||
'''
|
"""
|
||||||
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
|
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
|
||||||
this method returns a matching enum field.
|
this method returns a matching enum field.
|
||||||
'''
|
"""
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
members = {}
|
members = {}
|
||||||
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
|
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
|
||||||
members[match.group(1)] = int(match.group(2))
|
members[match.group(1)] = int(match.group(2))
|
||||||
enum_cls = Enum('AdHocEnum', members)
|
enum_cls = Enum("AdHocEnum", members)
|
||||||
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
|
field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
|
||||||
return field_class(enum_cls)
|
return field_class(enum_cls)
|
||||||
|
|
||||||
|
|
||||||
class Enum8Field(BaseEnumField):
|
class Enum8Field(BaseEnumField):
|
||||||
|
|
||||||
db_type = 'Enum8'
|
db_type = "Enum8"
|
||||||
|
|
||||||
|
|
||||||
class Enum16Field(BaseEnumField):
|
class Enum16Field(BaseEnumField):
|
||||||
|
|
||||||
db_type = 'Enum16'
|
db_type = "Enum16"
|
||||||
|
|
||||||
|
|
||||||
class ArrayField(Field):
|
class ArrayField(Field):
|
||||||
|
@ -530,9 +535,9 @@ class ArrayField(Field):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = parse_array(value)
|
value = parse_array(value)
|
||||||
elif isinstance(value, bytes):
|
elif isinstance(value, bytes):
|
||||||
value = parse_array(value.decode('UTF-8'))
|
value = parse_array(value.decode("UTF-8"))
|
||||||
elif not isinstance(value, (list, tuple)):
|
elif not isinstance(value, (list, tuple)):
|
||||||
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
|
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
|
||||||
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
@ -541,19 +546,19 @@ class ArrayField(Field):
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
|
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
|
||||||
return '[' + comma_join(array) + ']'
|
return "[" + comma_join(array) + "]"
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
|
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||||
if with_default_expression and self.codec and db and db.has_codec_support:
|
if with_default_expression and self.codec and db and db.has_codec_support:
|
||||||
sql+= ' CODEC(%s)' % self.codec
|
sql += " CODEC(%s)" % self.codec
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
|
||||||
class UUIDField(Field):
|
class UUIDField(Field):
|
||||||
|
|
||||||
class_default = UUID(int=0)
|
class_default = UUID(int=0)
|
||||||
db_type = 'UUID'
|
db_type = "UUID"
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, UUID):
|
if isinstance(value, UUID):
|
||||||
|
@ -567,7 +572,7 @@ class UUIDField(Field):
|
||||||
elif isinstance(value, tuple):
|
elif isinstance(value, tuple):
|
||||||
return UUID(fields=value)
|
return UUID(fields=value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid value for UUIDField: %r' % value)
|
raise ValueError("Invalid value for UUIDField: %r" % value)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape(str(value), quote)
|
return escape(str(value), quote)
|
||||||
|
@ -576,7 +581,7 @@ class UUIDField(Field):
|
||||||
class IPv4Field(Field):
|
class IPv4Field(Field):
|
||||||
|
|
||||||
class_default = 0
|
class_default = 0
|
||||||
db_type = 'IPv4'
|
db_type = "IPv4"
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, IPv4Address):
|
if isinstance(value, IPv4Address):
|
||||||
|
@ -584,7 +589,7 @@ class IPv4Field(Field):
|
||||||
elif isinstance(value, (bytes, str, int)):
|
elif isinstance(value, (bytes, str, int)):
|
||||||
return IPv4Address(value)
|
return IPv4Address(value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid value for IPv4Address: %r' % value)
|
raise ValueError("Invalid value for IPv4Address: %r" % value)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape(str(value), quote)
|
return escape(str(value), quote)
|
||||||
|
@ -593,7 +598,7 @@ class IPv4Field(Field):
|
||||||
class IPv6Field(Field):
|
class IPv6Field(Field):
|
||||||
|
|
||||||
class_default = 0
|
class_default = 0
|
||||||
db_type = 'IPv6'
|
db_type = "IPv6"
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, IPv6Address):
|
if isinstance(value, IPv6Address):
|
||||||
|
@ -601,7 +606,7 @@ class IPv6Field(Field):
|
||||||
elif isinstance(value, (bytes, str, int)):
|
elif isinstance(value, (bytes, str, int)):
|
||||||
return IPv6Address(value)
|
return IPv6Address(value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid value for IPv6Address: %r' % value)
|
raise ValueError("Invalid value for IPv6Address: %r" % value)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape(str(value), quote)
|
return escape(str(value), quote)
|
||||||
|
@ -611,9 +616,10 @@ class NullableField(Field):
|
||||||
|
|
||||||
class_default = None
|
class_default = None
|
||||||
|
|
||||||
def __init__(self, inner_field, default=None, alias=None, materialized=None,
|
def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
|
||||||
extra_null_values=None, codec=None):
|
assert isinstance(
|
||||||
assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
|
inner_field, Field
|
||||||
|
), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
self._null_values = [None]
|
self._null_values = [None]
|
||||||
if extra_null_values:
|
if extra_null_values:
|
||||||
|
@ -621,7 +627,7 @@ class NullableField(Field):
|
||||||
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
|
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if value == '\\N' or value in self._null_values:
|
if value == "\\N" or value in self._null_values:
|
||||||
return None
|
return None
|
||||||
return self.inner_field.to_python(value, timezone_in_use)
|
return self.inner_field.to_python(value, timezone_in_use)
|
||||||
|
|
||||||
|
@ -630,22 +636,27 @@ class NullableField(Field):
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
if value in self._null_values:
|
if value in self._null_values:
|
||||||
return '\\N'
|
return "\\N"
|
||||||
return self.inner_field.to_db_string(value, quote=quote)
|
return self.inner_field.to_db_string(value, quote=quote)
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
|
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||||
if with_default_expression:
|
if with_default_expression:
|
||||||
sql += self._extra_params(db)
|
sql += self._extra_params(db)
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
|
||||||
class LowCardinalityField(Field):
|
class LowCardinalityField(Field):
|
||||||
|
|
||||||
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
||||||
assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
|
assert isinstance(
|
||||||
assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM"
|
inner_field, Field
|
||||||
assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
|
), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
|
||||||
|
assert not isinstance(
|
||||||
|
inner_field, LowCardinalityField
|
||||||
|
), "LowCardinality inner fields are not supported by the ORM"
|
||||||
|
assert not isinstance(
|
||||||
|
inner_field, ArrayField
|
||||||
|
), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
self.class_default = self.inner_field.class_default
|
self.class_default = self.inner_field.class_default
|
||||||
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
|
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
|
||||||
|
@ -661,10 +672,14 @@ class LowCardinalityField(Field):
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
if db and db.has_low_cardinality_support:
|
if db and db.has_low_cardinality_support:
|
||||||
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
|
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
|
||||||
else:
|
else:
|
||||||
sql = self.inner_field.get_sql(with_default_expression=False)
|
sql = self.inner_field.get_sql(with_default_expression=False)
|
||||||
logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__))
|
logger.warning(
|
||||||
|
"LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback".format(
|
||||||
|
self.inner_field.__class__.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
if with_default_expression:
|
if with_default_expression:
|
||||||
sql += self._extra_params(db)
|
sql += self._extra_params(db)
|
||||||
return sql
|
return sql
|
||||||
|
|
Loading…
Reference in New Issue
Block a user