From aab92d88aa1fdad5959edd77e7fc2754c583249d Mon Sep 17 00:00:00 2001 From: olliemath Date: Tue, 27 Jul 2021 23:02:01 +0100 Subject: [PATCH] Chore: fix linting on fields.py --- clickhouse_orm/fields.py | 323 ++++++++++++++++++++------------------- 1 file changed, 169 insertions(+), 154 deletions(-) diff --git a/clickhouse_orm/fields.py b/clickhouse_orm/fields.py index 4f631cc..a890385 100644 --- a/clickhouse_orm/fields.py +++ b/clickhouse_orm/fields.py @@ -1,39 +1,45 @@ from __future__ import unicode_literals + import datetime -import iso8601 -import pytz from calendar import timegm from decimal import Decimal, localcontext -from uuid import UUID -from logging import getLogger -from pytz import BaseTzInfo -from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names -from .funcs import F, FunctionOperatorsMixin from ipaddress import IPv4Address, IPv6Address +from logging import getLogger +from uuid import UUID -logger = getLogger('clickhouse_orm') +import iso8601 +import pytz +from pytz import BaseTzInfo + +from .funcs import F, FunctionOperatorsMixin +from .utils import comma_join, escape, get_subclass_names, parse_array, string_or_func + +logger = getLogger("clickhouse_orm") class Field(FunctionOperatorsMixin): - ''' + """ Abstract base class for all field types. - ''' - name = None # this is set by the parent model - parent = None # this is set by the parent model - creation_counter = 0 # used for keeping the model fields ordered - class_default = 0 # should be overridden by concrete subclasses - db_type = None # should be overridden by concrete subclasses + """ + + name = None # this is set by the parent model + parent = None # this is set by the parent model + creation_counter = 0 # used for keeping the model fields ordered + class_default = 0 # should be overridden by concrete subclasses + db_type = None # should be overridden by concrete subclasses def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): - assert [default, alias, materialized].count(None) >= 2, \ - "Only one of default, alias and materialized parameters can be given" - assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\ - "Alias parameter must be a string or function object, if given" - assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\ - "Materialized parameter must be a string or function object, if given" + assert [default, alias, materialized].count( + None + ) >= 2, "Only one of default, alias and materialized parameters can be given" + assert ( + alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "" + ), "Alias parameter must be a string or function object, if given" + assert ( + materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "" + ), "Materialized parameter must be a string or function object, if given" assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given" - assert codec is None or isinstance(codec, str) and codec != "", \ - "Codec field must be string, if given" + assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given" self.creation_counter = Field.creation_counter Field.creation_counter += 1 @@ -47,49 +53,51 @@ class Field(FunctionOperatorsMixin): return self.name def __repr__(self): - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def to_python(self, value, timezone_in_use): - ''' + """ Converts the input value into the expected Python data type, raising ValueError if the data can't be converted. Returns the converted value. Subclasses should override this. The timezone_in_use parameter should be consulted when parsing datetime fields. - ''' - return value # pragma: no cover + """ + return value # pragma: no cover def validate(self, value): - ''' + """ Called after to_python to validate that the value is suitable for the field's database type. Subclasses should override this. - ''' + """ pass def _range_check(self, value, min_value, max_value): - ''' + """ Utility method to check that the given value is between min_value and max_value. - ''' + """ if value < min_value or value > max_value: - raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) + raise ValueError( + "%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value) + ) def to_db_string(self, value, quote=True): - ''' + """ Returns the field's value prepared for writing to the database. When quote is true, strings are surrounded by single quotes. - ''' + """ return escape(value, quote) def get_sql(self, with_default_expression=True, db=None): - ''' + """ Returns an SQL expression describing the field (e.g. for CREATE TABLE). - `with_default_expression`: If True, adds default value to sql. It doesn't affect fields with alias and materialized values. - `db`: Database, used for checking supported features. - ''' + """ sql = self.db_type args = self.get_db_type_args() if args: - sql += '(%s)' % comma_join(args) + sql += "(%s)" % comma_join(args) if with_default_expression: sql += self._extra_params(db) return sql @@ -99,18 +107,18 @@ class Field(FunctionOperatorsMixin): return [] def _extra_params(self, db): - sql = '' + sql = "" if self.alias: - sql += ' ALIAS %s' % string_or_func(self.alias) + sql += " ALIAS %s" % string_or_func(self.alias) elif self.materialized: - sql += ' MATERIALIZED %s' % string_or_func(self.materialized) + sql += " MATERIALIZED %s" % string_or_func(self.materialized) elif isinstance(self.default, F): - sql += ' DEFAULT %s' % self.default.to_sql() + sql += " DEFAULT %s" % self.default.to_sql() elif self.default: default = self.to_db_string(self.default) - sql += ' DEFAULT %s' % default + sql += " DEFAULT %s" % default if self.codec and db and db.has_codec_support: - sql += ' CODEC(%s)' % self.codec + sql += " CODEC(%s)" % self.codec return sql def isinstance(self, types): @@ -124,43 +132,42 @@ class Field(FunctionOperatorsMixin): """ if isinstance(self, types): return True - inner_field = getattr(self, 'inner_field', None) + inner_field = getattr(self, "inner_field", None) while inner_field: if isinstance(inner_field, types): return True - inner_field = getattr(inner_field, 'inner_field', None) + inner_field = getattr(inner_field, "inner_field", None) return False class StringField(Field): - class_default = '' - db_type = 'String' + class_default = "" + db_type = "String" def to_python(self, value, timezone_in_use): if isinstance(value, str): return value if isinstance(value, bytes): - return value.decode('UTF-8') - raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value)) + return value.decode("UTF-8") + raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value)) class FixedStringField(StringField): - def __init__(self, length, default=None, alias=None, materialized=None, readonly=None): self._length = length - self.db_type = 'FixedString(%d)' % length + self.db_type = "FixedString(%d)" % length super(FixedStringField, self).__init__(default, alias, materialized, readonly) def to_python(self, value, timezone_in_use): value = super(FixedStringField, self).to_python(value, timezone_in_use) - return value.rstrip('\0') + return value.rstrip("\0") def validate(self, value): if isinstance(value, str): - value = value.encode('UTF-8') + value = value.encode("UTF-8") if len(value) > self._length: - raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length)) + raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length)) class DateField(Field): @@ -168,7 +175,7 @@ class DateField(Field): min_value = datetime.date(1970, 1, 1) max_value = datetime.date(2105, 12, 31) class_default = min_value - db_type = 'Date' + db_type = "Date" def to_python(self, value, timezone_in_use): if isinstance(value, datetime.datetime): @@ -178,10 +185,10 @@ class DateField(Field): if isinstance(value, int): return DateField.class_default + datetime.timedelta(days=value) if isinstance(value, str): - if value == '0000-00-00': + if value == "0000-00-00": return DateField.min_value - return datetime.datetime.strptime(value, '%Y-%m-%d').date() - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + return datetime.datetime.strptime(value, "%Y-%m-%d").date() + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def validate(self, value): self._range_check(value, DateField.min_value, DateField.max_value) @@ -193,10 +200,9 @@ class DateField(Field): class DateTimeField(Field): class_default = datetime.datetime.fromtimestamp(0, pytz.utc) - db_type = 'DateTime' + db_type = "DateTime" - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - timezone=None): + def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None): super().__init__(default, alias, materialized, readonly, codec) # assert not timezone, 'Temporarily field timezone is not supported' if timezone: @@ -217,7 +223,7 @@ class DateTimeField(Field): if isinstance(value, int): return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc) if isinstance(value, str): - if value == '0000-00-00 00:00:00': + if value == "0000-00-00 00:00:00": return self.class_default if len(value) == 10: try: @@ -235,19 +241,20 @@ class DateTimeField(Field): if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: dt = timezone_in_use.localize(dt) return dt - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True): - return escape('%010d' % timegm(value.utctimetuple()), quote) + return escape("%010d" % timegm(value.utctimetuple()), quote) class DateTime64Field(DateTimeField): - db_type = 'DateTime64' + db_type = "DateTime64" - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - timezone=None, precision=6): + def __init__( + self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6 + ): super().__init__(default, alias, materialized, readonly, codec, timezone) - assert precision is None or isinstance(precision, int), 'Precision must be int type' + assert precision is None or isinstance(precision, int), "Precision must be int type" self.precision = precision def get_db_type_args(self): @@ -263,11 +270,10 @@ class DateTime64Field(DateTimeField): Returns string in 0000000000.000000 format, where remainder digits count is equal to precision """ return escape( - '{timestamp:0{width}.{precision}f}'.format( - timestamp=value.timestamp(), - width=11 + self.precision, - precision=self.precision), - quote + "{timestamp:0{width}.{precision}f}".format( + timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision + ), + quote, ) def to_python(self, value, timezone_in_use): @@ -277,8 +283,8 @@ class DateTime64Field(DateTimeField): if isinstance(value, (int, float)): return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc) if isinstance(value, str): - left_part = value.split('.')[0] - if left_part == '0000-00-00 00:00:00': + left_part = value.split(".")[0] + if left_part == "0000-00-00 00:00:00": return self.class_default if len(left_part) == 10: try: @@ -290,14 +296,15 @@ class DateTime64Field(DateTimeField): class BaseIntField(Field): - ''' + """ Abstract base class for all integer-type fields. - ''' + """ + def to_python(self, value, timezone_in_use): try: return int(value) - except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + except Exception: + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True): # There's no need to call escape since numbers do not contain @@ -311,69 +318,69 @@ class BaseIntField(Field): class UInt8Field(BaseIntField): min_value = 0 - max_value = 2**8 - 1 - db_type = 'UInt8' + max_value = 2 ** 8 - 1 + db_type = "UInt8" class UInt16Field(BaseIntField): min_value = 0 - max_value = 2**16 - 1 - db_type = 'UInt16' + max_value = 2 ** 16 - 1 + db_type = "UInt16" class UInt32Field(BaseIntField): min_value = 0 - max_value = 2**32 - 1 - db_type = 'UInt32' + max_value = 2 ** 32 - 1 + db_type = "UInt32" class UInt64Field(BaseIntField): min_value = 0 - max_value = 2**64 - 1 - db_type = 'UInt64' + max_value = 2 ** 64 - 1 + db_type = "UInt64" class Int8Field(BaseIntField): - min_value = -2**7 - max_value = 2**7 - 1 - db_type = 'Int8' + min_value = -(2 ** 7) + max_value = 2 ** 7 - 1 + db_type = "Int8" class Int16Field(BaseIntField): - min_value = -2**15 - max_value = 2**15 - 1 - db_type = 'Int16' + min_value = -(2 ** 15) + max_value = 2 ** 15 - 1 + db_type = "Int16" class Int32Field(BaseIntField): - min_value = -2**31 - max_value = 2**31 - 1 - db_type = 'Int32' + min_value = -(2 ** 31) + max_value = 2 ** 31 - 1 + db_type = "Int32" class Int64Field(BaseIntField): - min_value = -2**63 - max_value = 2**63 - 1 - db_type = 'Int64' + min_value = -(2 ** 63) + max_value = 2 ** 63 - 1 + db_type = "Int64" class BaseFloatField(Field): - ''' + """ Abstract base class for all float-type fields. - ''' + """ def to_python(self, value, timezone_in_use): try: return float(value) - except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + except Exception: + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True): # There's no need to call escape since numbers do not contain @@ -383,28 +390,28 @@ class BaseFloatField(Field): class Float32Field(BaseFloatField): - db_type = 'Float32' + db_type = "Float32" class Float64Field(BaseFloatField): - db_type = 'Float64' + db_type = "Float64" class DecimalField(Field): - ''' + """ Base class for all decimal fields. Can also be used directly. - ''' + """ def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None): - assert 1 <= precision <= 38, 'Precision must be between 1 and 38' - assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' + assert 1 <= precision <= 38, "Precision must be between 1 and 38" + assert 0 <= scale <= precision, "Scale must be between 0 and the given precision" self.precision = precision self.scale = scale - self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale) + self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale) with localcontext() as ctx: ctx.prec = 38 - self.exp = Decimal(10) ** -self.scale # for rounding to the required scale + self.exp = Decimal(10) ** -self.scale # for rounding to the required scale self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp self.min_value = -self.max_value super(DecimalField, self).__init__(default, alias, materialized, readonly) @@ -413,10 +420,10 @@ class DecimalField(Field): if not isinstance(value, Decimal): try: value = Decimal(value) - except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + except Exception: + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) if not value.is_finite(): - raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value)) return self._round(value) def to_db_string(self, value, quote=True): @@ -432,30 +439,27 @@ class DecimalField(Field): class Decimal32Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly) - self.db_type = 'Decimal32(%d)' % scale + self.db_type = "Decimal32(%d)" % scale class Decimal64Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly) - self.db_type = 'Decimal64(%d)' % scale + self.db_type = "Decimal64(%d)" % scale class Decimal128Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None): super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly) - self.db_type = 'Decimal128(%d)' % scale + self.db_type = "Decimal128(%d)" % scale class BaseEnumField(Field): - ''' + """ Abstract base class for all enum-type fields. - ''' + """ def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None): self.enum_cls = enum_cls @@ -473,7 +477,7 @@ class BaseEnumField(Field): except Exception: return self.enum_cls(value) if isinstance(value, bytes): - decoded = value.decode('UTF-8') + decoded = value.decode("UTF-8") try: return self.enum_cls[decoded] except Exception: @@ -482,38 +486,39 @@ class BaseEnumField(Field): return self.enum_cls(value) except (KeyError, ValueError): pass - raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value)) + raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value)) def to_db_string(self, value, quote=True): return escape(value.name, quote) def get_db_type_args(self): - return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls] + return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls] @classmethod def create_ad_hoc_field(cls, db_type): - ''' + """ Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)" this method returns a matching enum field. - ''' + """ import re from enum import Enum + members = {} for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type): members[match.group(1)] = int(match.group(2)) - enum_cls = Enum('AdHocEnum', members) - field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field + enum_cls = Enum("AdHocEnum", members) + field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field return field_class(enum_cls) class Enum8Field(BaseEnumField): - db_type = 'Enum8' + db_type = "Enum8" class Enum16Field(BaseEnumField): - db_type = 'Enum16' + db_type = "Enum16" class ArrayField(Field): @@ -530,9 +535,9 @@ class ArrayField(Field): if isinstance(value, str): value = parse_array(value) elif isinstance(value, bytes): - value = parse_array(value.decode('UTF-8')) + value = parse_array(value.decode("UTF-8")) elif not isinstance(value, (list, tuple)): - raise ValueError('ArrayField expects list or tuple, not %s' % type(value)) + raise ValueError("ArrayField expects list or tuple, not %s" % type(value)) return [self.inner_field.to_python(v, timezone_in_use) for v in value] def validate(self, value): @@ -541,19 +546,19 @@ class ArrayField(Field): def to_db_string(self, value, quote=True): array = [self.inner_field.to_db_string(v, quote=True) for v in value] - return '[' + comma_join(array) + ']' + return "[" + comma_join(array) + "]" def get_sql(self, with_default_expression=True, db=None): - sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) + sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression and self.codec and db and db.has_codec_support: - sql+= ' CODEC(%s)' % self.codec + sql += " CODEC(%s)" % self.codec return sql class UUIDField(Field): class_default = UUID(int=0) - db_type = 'UUID' + db_type = "UUID" def to_python(self, value, timezone_in_use): if isinstance(value, UUID): @@ -567,7 +572,7 @@ class UUIDField(Field): elif isinstance(value, tuple): return UUID(fields=value) else: - raise ValueError('Invalid value for UUIDField: %r' % value) + raise ValueError("Invalid value for UUIDField: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -576,7 +581,7 @@ class UUIDField(Field): class IPv4Field(Field): class_default = 0 - db_type = 'IPv4' + db_type = "IPv4" def to_python(self, value, timezone_in_use): if isinstance(value, IPv4Address): @@ -584,7 +589,7 @@ class IPv4Field(Field): elif isinstance(value, (bytes, str, int)): return IPv4Address(value) else: - raise ValueError('Invalid value for IPv4Address: %r' % value) + raise ValueError("Invalid value for IPv4Address: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -593,7 +598,7 @@ class IPv4Field(Field): class IPv6Field(Field): class_default = 0 - db_type = 'IPv6' + db_type = "IPv6" def to_python(self, value, timezone_in_use): if isinstance(value, IPv6Address): @@ -601,7 +606,7 @@ class IPv6Field(Field): elif isinstance(value, (bytes, str, int)): return IPv6Address(value) else: - raise ValueError('Invalid value for IPv6Address: %r' % value) + raise ValueError("Invalid value for IPv6Address: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -611,9 +616,10 @@ class NullableField(Field): class_default = None - def __init__(self, inner_field, default=None, alias=None, materialized=None, - extra_null_values=None, codec=None): - assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field) + def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None): + assert isinstance( + inner_field, Field + ), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field) self.inner_field = inner_field self._null_values = [None] if extra_null_values: @@ -621,7 +627,7 @@ class NullableField(Field): super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec) def to_python(self, value, timezone_in_use): - if value == '\\N' or value in self._null_values: + if value == "\\N" or value in self._null_values: return None return self.inner_field.to_python(value, timezone_in_use) @@ -630,22 +636,27 @@ class NullableField(Field): def to_db_string(self, value, quote=True): if value in self._null_values: - return '\\N' + return "\\N" return self.inner_field.to_db_string(value, quote=quote) def get_sql(self, with_default_expression=True, db=None): - sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) + sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression: sql += self._extra_params(db) return sql class LowCardinalityField(Field): - def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None): - assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field) - assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM" - assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead" + assert isinstance( + inner_field, Field + ), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field) + assert not isinstance( + inner_field, LowCardinalityField + ), "LowCardinality inner fields are not supported by the ORM" + assert not isinstance( + inner_field, ArrayField + ), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead" self.inner_field = inner_field self.class_default = self.inner_field.class_default super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec) @@ -661,10 +672,14 @@ class LowCardinalityField(Field): def get_sql(self, with_default_expression=True, db=None): if db and db.has_low_cardinality_support: - sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False) + sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False) else: sql = self.inner_field.get_sql(with_default_expression=False) - logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__)) + logger.warning( + "LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback".format( + self.inner_field.__class__.__name__ + ) + ) if with_default_expression: sql += self._extra_params(db) return sql