Merge remote-tracking branch 'upstream/develop' into merge-upstream

This commit is contained in:
olliemath 2025-07-01 20:28:43 +01:00
commit ee01e82372
11 changed files with 433 additions and 159 deletions

View File

@ -1,6 +1,18 @@
Change Log
==========
v2.3.0
------
Merges upstream changes:
- Fix pagination for models with alias columns
- Add `QuerySet.model` to support django-rest-framework 3
- Improve support of ClickHouse v21.9 (mangototango)
- Ignore non-numeric parts in ClickHouse version (mangototango)
- Fix precedence of ~ operator in Q objects (mangototango)
- Support for adding a column to the beginning of a table (meanmail)
- Add stddevPop and stddevSamp functions (k.peskov)
v2.2.2
------
- Unpinned requirements to enhance compatibility
@ -216,5 +228,3 @@ v0.7.0
v0.6.3
------
- Python 3 support

View File

@ -55,6 +55,14 @@ class ServerError(DatabaseException):
""",
re.VERBOSE | re.DOTALL,
),
# ClickHouse v21+
re.compile(
r"""
Code:\ (?P<code>\d+).
\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
""",
re.VERBOSE | re.DOTALL,
),
)
@classmethod
@ -124,14 +132,20 @@ class Database(object):
self.db_exists = self._is_existing_database()
if readonly:
if not self.db_exists:
raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
raise DatabaseException(
"Database does not exist, and cannot be created under readonly connection"
)
self.connection_readonly = self._is_connection_readonly()
self.readonly = True
elif autocreate and not self.db_exists:
self.create_database()
self.server_version = self._get_server_version()
# Versions 1.1.53981 and below don't have timezone function
self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc
self.server_timezone = (
self._get_server_timezone()
if self.server_version > (1, 1, 53981)
else pytz.utc
)
# Versions 19.1.16 and above support codec compression
self.has_codec_support = self.server_version >= (19, 1, 16)
# Version 19.0 and above support LowCardinality
@ -158,7 +172,9 @@ class Database(object):
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if model_class.engine is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
raise DatabaseException(
"%s class must define an engine" % model_class.__name__
)
self._send(model_class.create_table_sql(self))
def drop_table(self, model_class):
@ -229,7 +245,9 @@ class Database(object):
if first_instance.is_read_only() or first_instance.is_system_model():
raise DatabaseException("You can't insert into read only and system tables")
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
fields_list = ",".join(
["`%s`" % name for name in first_instance.fields(writable=True)]
)
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
@ -289,11 +307,15 @@ class Database(object):
lines = r.iter_lines()
field_names = parse_tsv(next(lines))
field_types = parse_tsv(next(lines))
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
model_class = model_class or ModelBase.create_ad_hoc_model(
zip(field_names, field_types)
)
for line in lines:
# skip blank line left by WITH TOTALS modifier
if line:
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
yield model_class.from_tsv(
line, field_names, self.server_timezone, self
)
def raw(self, query, settings=None, stream=False):
"""
@ -306,7 +328,15 @@ class Database(object):
query = self._substitute(query, None)
return self._send(query, settings=settings, stream=stream).text
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
def paginate(
self,
model_class,
order_by,
page_num=1,
page_size=100,
conditions=None,
settings=None,
):
"""
Selects records and returns a single page of model instances.
@ -330,7 +360,8 @@ class Database(object):
elif page_num < 1:
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
query = "SELECT * FROM $table"
query = "SELECT {} FROM $table".format(", ".join(model_class.fields().keys()))
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
@ -367,7 +398,9 @@ class Database(object):
self.insert(
[
MigrationHistory(
package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today(),
)
]
)
@ -378,7 +411,10 @@ class Database(object):
from .migrations import MigrationHistory
self.create_table(MigrationHistory)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = (
"SELECT module_name from $table WHERE package_name = '%s'"
% migrations_package_name
)
query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query))
@ -388,7 +424,9 @@ class Database(object):
if self.log_statements:
logger.info(data)
params = self._build_params(settings)
r = self.request_session.post(self.db_url, params=params, data=data, stream=stream, timeout=self.timeout)
r = self.request_session.post(
self.db_url, params=params, data=data, stream=stream, timeout=self.timeout
)
if r.status_code != 200:
raise ServerError(r.text)
return r
@ -413,7 +451,10 @@ class Database(object):
if model_class.is_system_model():
mapping["table"] = "`system`.`%s`" % model_class.table_name()
else:
mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
mapping["table"] = "`%s`.`%s`" % (
self.db_name,
model_class.table_name(),
)
query = Template(query).safe_substitute(mapping)
return query
@ -432,10 +473,12 @@ class Database(object):
except ServerError as e:
logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
ver = "1.1.0"
return tuple(int(n) for n in ver.split(".")) if as_tuple else ver
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
def _is_existing_database(self):
r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
r = self._send(
"SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
)
return r.text.strip() == "1"
def _is_connection_readonly(self):

View File

@ -26,18 +26,30 @@ class Field(FunctionOperatorsMixin):
class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert [default, alias, materialized].count(
None
) >= 2, "Only one of default, alias and materialized parameters can be given"
def __init__(
self, default=None, alias=None, materialized=None, readonly=None, codec=None
):
assert [default, alias, materialized].count(None) >= 2, (
"Only one of default, alias and materialized parameters can be given"
)
assert (
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
alias is None
or isinstance(alias, F)
or isinstance(alias, str)
and alias != ""
), "Alias parameter must be a string or function object, if given"
assert (
materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
materialized is None
or isinstance(materialized, F)
or isinstance(materialized, str)
and materialized != ""
), "Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
assert readonly is None or type(readonly) is bool, (
"readonly parameter must be bool if given"
)
assert codec is None or isinstance(codec, str) and codec != "", (
"Codec field must be string, if given"
)
if alias:
assert codec is None, "Codec cannot be used for alias fields"
@ -76,7 +88,8 @@ class Field(FunctionOperatorsMixin):
"""
if value < min_value or value > max_value:
raise ValueError(
"%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
"%s out of range - %s is not between %s and %s"
% (self.__class__.__name__, value, min_value, max_value)
)
def to_db_string(self, value, quote=True):
@ -117,7 +130,7 @@ class Field(FunctionOperatorsMixin):
elif self.default:
default = self.to_db_string(self.default)
sql += " DEFAULT %s" % default
if self.codec and db and db.has_codec_support:
if self.codec and db and db.has_codec_support and not self.alias:
sql += " CODEC(%s)" % self.codec
return sql
@ -141,7 +154,6 @@ class Field(FunctionOperatorsMixin):
class StringField(Field):
class_default = ""
db_type = "String"
@ -154,7 +166,9 @@ class StringField(Field):
class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
def __init__(
self, length, default=None, alias=None, materialized=None, readonly=None
):
self._length = length
self.db_type = "FixedString(%d)" % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
@ -167,11 +181,13 @@ class FixedStringField(StringField):
if isinstance(value, str):
value = value.encode("utf-8")
if len(value) > self._length:
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
raise ValueError(
"Value of %d bytes is too long for FixedStringField(%d)"
% (len(value), self._length)
)
class DateField(Field):
min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31)
class_default = min_value
@ -198,15 +214,26 @@ class DateField(Field):
class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = "DateTime"
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
def __init__(
self,
default=None,
alias=None,
materialized=None,
readonly=None,
codec=None,
timezone=None,
):
super().__init__(default, alias, materialized, readonly, codec)
# assert not timezone, 'Temporarily field timezone is not supported'
if timezone:
timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone)
timezone = (
timezone
if isinstance(timezone, BaseTzInfo)
else pytz.timezone(timezone)
)
self.timezone = timezone
def get_db_type_args(self):
@ -219,7 +246,9 @@ class DateTimeField(Field):
if isinstance(value, datetime.datetime):
return value if value.tzinfo else value.replace(tzinfo=pytz.utc)
if isinstance(value, datetime.date):
return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
return datetime.datetime(
value.year, value.month, value.day, tzinfo=pytz.utc
)
if isinstance(value, int):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, str):
@ -228,7 +257,9 @@ class DateTimeField(Field):
if len(value) == 10:
try:
value = int(value)
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
return datetime.datetime.utcfromtimestamp(value).replace(
tzinfo=pytz.utc
)
except ValueError:
pass
try:
@ -251,10 +282,19 @@ class DateTime64Field(DateTimeField):
db_type = "DateTime64"
def __init__(
self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
self,
default=None,
alias=None,
materialized=None,
readonly=None,
codec=None,
timezone=None,
precision=6,
):
super().__init__(default, alias, materialized, readonly, codec, timezone)
assert precision is None or isinstance(precision, int), "Precision must be int type"
assert precision is None or isinstance(precision, int), (
"Precision must be int type"
)
self.precision = precision
def get_db_type_args(self):
@ -271,7 +311,9 @@ class DateTime64Field(DateTimeField):
"""
return escape(
"{timestamp:0{width}.{precision}f}".format(
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
timestamp=value.timestamp(),
width=11 + self.precision,
precision=self.precision,
),
quote,
)
@ -281,7 +323,9 @@ class DateTime64Field(DateTimeField):
return super().to_python(value, timezone_in_use)
except ValueError:
if isinstance(value, (int, float)):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
return datetime.datetime.utcfromtimestamp(value).replace(
tzinfo=pytz.utc
)
if isinstance(value, str):
left_part = value.split(".")[0]
if left_part == "0000-00-00 00:00:00":
@ -289,7 +333,9 @@ class DateTime64Field(DateTimeField):
if len(left_part) == 10:
try:
value = float(value)
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
return datetime.datetime.utcfromtimestamp(value).replace(
tzinfo=pytz.utc
)
except ValueError:
pass
raise
@ -304,7 +350,9 @@ class BaseIntField(Field):
try:
return int(value)
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
raise ValueError(
"Invalid value for %s - %r" % (self.__class__.__name__, value)
)
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
@ -316,58 +364,50 @@ class BaseIntField(Field):
class UInt8Field(BaseIntField):
min_value = 0
max_value = 2 ** 8 - 1
max_value = 2**8 - 1
db_type = "UInt8"
class UInt16Field(BaseIntField):
min_value = 0
max_value = 2 ** 16 - 1
max_value = 2**16 - 1
db_type = "UInt16"
class UInt32Field(BaseIntField):
min_value = 0
max_value = 2 ** 32 - 1
max_value = 2**32 - 1
db_type = "UInt32"
class UInt64Field(BaseIntField):
min_value = 0
max_value = 2 ** 64 - 1
max_value = 2**64 - 1
db_type = "UInt64"
class Int8Field(BaseIntField):
min_value = -(2 ** 7)
max_value = 2 ** 7 - 1
min_value = -(2**7)
max_value = 2**7 - 1
db_type = "Int8"
class Int16Field(BaseIntField):
min_value = -(2 ** 15)
max_value = 2 ** 15 - 1
min_value = -(2**15)
max_value = 2**15 - 1
db_type = "Int16"
class Int32Field(BaseIntField):
min_value = -(2 ** 31)
max_value = 2 ** 31 - 1
min_value = -(2**31)
max_value = 2**31 - 1
db_type = "Int32"
class Int64Field(BaseIntField):
min_value = -(2 ** 63)
max_value = 2 ** 63 - 1
min_value = -(2**63)
max_value = 2**63 - 1
db_type = "Int64"
@ -380,7 +420,9 @@ class BaseFloatField(Field):
try:
return float(value)
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
raise ValueError(
"Invalid value for %s - %r" % (self.__class__.__name__, value)
)
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
@ -389,12 +431,10 @@ class BaseFloatField(Field):
class Float32Field(BaseFloatField):
db_type = "Float32"
class Float64Field(BaseFloatField):
db_type = "Float64"
@ -403,9 +443,19 @@ class DecimalField(Field):
Base class for all decimal fields. Can also be used directly.
"""
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
def __init__(
self,
precision,
scale,
default=None,
alias=None,
materialized=None,
readonly=None,
):
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
assert 0 <= scale <= precision, (
"Scale must be between 0 and the given precision"
)
self.precision = precision
self.scale = scale
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
@ -421,9 +471,13 @@ class DecimalField(Field):
try:
value = Decimal(value)
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
raise ValueError(
"Invalid value for %s - %r" % (self.__class__.__name__, value)
)
if not value.is_finite():
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
raise ValueError(
"Non-finite value for %s - %r" % (self.__class__.__name__, value)
)
return self._round(value)
def to_db_string(self, value, quote=True):
@ -439,20 +493,32 @@ class DecimalField(Field):
class Decimal32Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
def __init__(
self, scale, default=None, alias=None, materialized=None, readonly=None
):
super(Decimal32Field, self).__init__(
9, scale, default, alias, materialized, readonly
)
self.db_type = "Decimal32(%d)" % scale
class Decimal64Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
def __init__(
self, scale, default=None, alias=None, materialized=None, readonly=None
):
super(Decimal64Field, self).__init__(
18, scale, default, alias, materialized, readonly
)
self.db_type = "Decimal64(%d)" % scale
class Decimal128Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
def __init__(
self, scale, default=None, alias=None, materialized=None, readonly=None
):
super(Decimal128Field, self).__init__(
38, scale, default, alias, materialized, readonly
)
self.db_type = "Decimal128(%d)" % scale
@ -461,11 +527,21 @@ class BaseEnumField(Field):
Abstract base class for all enum-type fields.
"""
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
def __init__(
self,
enum_cls,
default=None,
alias=None,
materialized=None,
readonly=None,
codec=None,
):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec)
super(BaseEnumField, self).__init__(
default, alias, materialized, readonly, codec
)
def to_python(self, value, timezone_in_use):
if isinstance(value, self.enum_cls):
@ -512,22 +588,31 @@ class BaseEnumField(Field):
class Enum8Field(BaseEnumField):
db_type = "Enum8"
class Enum16Field(BaseEnumField):
db_type = "Enum16"
class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM"
def __init__(
self,
inner_field,
default=None,
alias=None,
materialized=None,
readonly=None,
codec=None,
):
assert isinstance(inner_field, Field), (
"The first argument of ArrayField must be a Field instance"
)
assert not isinstance(inner_field, ArrayField), (
"Multidimensional array fields are not supported by the ORM"
)
self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
@ -549,14 +634,15 @@ class ArrayField(Field):
return "[" + comma_join(array) + "]"
def get_sql(self, with_default_expression=True, db=None):
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Array(%s)" % self.inner_field.get_sql(
with_default_expression=False, db=db
)
if with_default_expression and self.codec and db and db.has_codec_support:
sql += " CODEC(%s)" % self.codec
return sql
class UUIDField(Field):
class_default = UUID(int=0)
db_type = "UUID"
@ -579,7 +665,6 @@ class UUIDField(Field):
class IPv4Field(Field):
class_default = 0
db_type = "IPv4"
@ -596,7 +681,6 @@ class IPv4Field(Field):
class IPv6Field(Field):
class_default = 0
db_type = "IPv6"
@ -613,18 +697,29 @@ class IPv6Field(Field):
class NullableField(Field):
class_default = None
def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
assert isinstance(
inner_field, Field
), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
def __init__(
self,
inner_field,
default=None,
alias=None,
materialized=None,
extra_null_values=None,
codec=None,
):
assert isinstance(inner_field, Field), (
"The first argument of NullableField must be a Field instance. Not: {}".format(
inner_field
)
)
self.inner_field = inner_field
self._null_values = [None]
if extra_null_values:
self._null_values.extend(extra_null_values)
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
super(NullableField, self).__init__(
default, alias, materialized, readonly=None, codec=codec
)
def to_python(self, value, timezone_in_use):
if value == "\\N" or value in self._null_values:
@ -640,26 +735,40 @@ class NullableField(Field):
return self.inner_field.to_db_string(value, quote=quote)
def get_sql(self, with_default_expression=True, db=None):
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Nullable(%s)" % self.inner_field.get_sql(
with_default_expression=False, db=db
)
if with_default_expression:
sql += self._extra_params(db)
return sql
class LowCardinalityField(Field):
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert isinstance(
inner_field, Field
), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
assert not isinstance(
inner_field, LowCardinalityField
), "LowCardinality inner fields are not supported by the ORM"
assert not isinstance(
inner_field, ArrayField
), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
def __init__(
self,
inner_field,
default=None,
alias=None,
materialized=None,
readonly=None,
codec=None,
):
assert isinstance(inner_field, Field), (
"The first argument of LowCardinalityField must be a Field instance. Not: {}".format(
inner_field
)
)
assert not isinstance(inner_field, LowCardinalityField), (
"LowCardinality inner fields are not supported by the ORM"
)
assert not isinstance(inner_field, ArrayField), (
"Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
)
self.inner_field = inner_field
self.class_default = self.inner_field.class_default
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
super(LowCardinalityField, self).__init__(
default, alias, materialized, readonly, codec
)
def to_python(self, value, timezone_in_use):
return self.inner_field.to_python(value, timezone_in_use)
@ -672,7 +781,9 @@ class LowCardinalityField(Field):
def get_sql(self, with_default_expression=True, db=None):
if db and db.has_low_cardinality_support:
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
sql = "LowCardinality(%s)" % self.inner_field.get_sql(
with_default_expression=False
)
else:
sql = self.inner_field.get_sql(with_default_expression=False)
logger.warning(

View File

@ -186,7 +186,6 @@ class FunctionOperatorsMixin(object):
class FMeta(type):
FUNCTION_COMBINATORS = {
"type_conversion": [
{"suffix": "OrZero"},
@ -230,12 +229,23 @@ class FMeta(type):
args = comma_join(extra_args)
new_sig = comma_join(extra_args)
# Get default values for args
argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty)
argdefs = tuple(
p.default for p in sig.parameters.values() if p.default != Parameter.empty
)
# Build the new function
new_code = compile(
'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(**locals()), __file__, "exec"
'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(
**locals()
),
__file__,
"exec",
)
new_func = FunctionType(
code=new_code.co_consts[0],
globals=globals(),
name=new_name,
argdefs=argdefs,
)
new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs)
# If base_func was parametric, new_func should be too
if getattr(base_func, "f_parametric", False):
new_func = parametric(new_func)
@ -409,7 +419,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
@staticmethod
def toQuarter(d, timezone=NO_VALUE):
return F("toQuarter", d, timezone)
return F("toQuarter", d, timezone) if timezone else F("toQuarter", d)
@staticmethod
def toMonth(d, timezone=NO_VALUE):
@ -421,7 +431,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
@staticmethod
def toISOWeek(d, timezone=NO_VALUE):
return F("toISOWeek", d, timezone)
return F("toISOWeek", d, timezone) if timezone else F("toISOWeek", d)
@staticmethod
def toDayOfYear(d, timezone=NO_VALUE):
@ -509,15 +519,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
@staticmethod
def toYYYYMM(dt, timezone=NO_VALUE):
return F("toYYYYMM", dt, timezone)
return F("toYYYYMM", dt, timezone) if timezone else F("toYYYYMM", dt)
@staticmethod
def toYYYYMMDD(dt, timezone=NO_VALUE):
return F("toYYYYMMDD", dt, timezone)
return F("toYYYYMMDD", dt, timezone) if timezone else F("toYYYYMMDD", dt)
@staticmethod
def toYYYYMMDDhhmmss(dt, timezone=NO_VALUE):
return F("toYYYYMMDDhhmmss", dt, timezone)
return (
F("toYYYYMMDDhhmmss", dt, timezone)
if timezone
else F("toYYYYMMDDhhmmss", dt)
)
@staticmethod
def toRelativeYearNum(d, timezone=NO_VALUE):
@ -1195,11 +1209,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
@staticmethod
def arrayResize(array, size, extender=None):
return F("arrayResize", array, size, extender) if extender is not None else F("arrayResize", array, size)
return (
F("arrayResize", array, size, extender)
if extender is not None
else F("arrayResize", array, size)
)
@staticmethod
def arraySlice(array, offset, length=None):
return F("arraySlice", array, offset, length) if length is not None else F("arraySlice", array, offset)
return (
F("arraySlice", array, offset, length)
if length is not None
else F("arraySlice", array, offset)
)
@staticmethod
def arrayUniq(*args):
@ -1649,6 +1671,16 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def varSamp(x):
return F("varSamp", x)
@staticmethod
@aggregate
def stddevPop(expr):
return F("stddevPop", expr)
@staticmethod
@aggregate
def stddevSamp(expr):
return F("stddevSamp", expr)
@staticmethod
@aggregate
@parametric

View File

@ -84,10 +84,12 @@ class AlterTable(ModelOperation):
is_regular_field = not (field.materialized or field.alias)
if name not in table_fields:
logger.info(" Add column %s", name)
assert prev_name, "Cannot add a column to the beginning of the table"
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
if is_regular_field:
cmd += " AFTER %s" % prev_name
if prev_name:
cmd += " AFTER %s" % prev_name
else:
cmd += " FIRST"
self._alter_table(database, cmd)
if is_regular_field:
@ -105,13 +107,21 @@ class AlterTable(ModelOperation):
}
for field_name, field_sql in self._get_table_fields(database):
# All fields must have been created and dropped by this moment
assert field_name in model_fields, "Model fields and table columns in disagreement"
assert field_name in model_fields, (
"Model fields and table columns in disagreement"
)
if field_sql != model_fields[field_name]:
logger.info(
" Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
" Change type of column %s from %s to %s",
field_name,
field_sql,
model_fields[field_name],
)
self._alter_table(
database,
"MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]),
)
self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))
class AlterTableWithBuffer(ModelOperation):

View File

@ -217,7 +217,6 @@ class Q(object):
if mode == l_child._mode and not l_child._negate:
q = deepcopy(l_child)
q._children.append(deepcopy(r_child))
else:
q = cls()
q._children = [l_child, r_child]
@ -300,6 +299,7 @@ class QuerySet(object):
Initializer. It is possible to create a queryset like this, but the standard
way is to use `MyModel.objects_in(database)`.
"""
self.model = model_cls
self._model_cls = model_cls
self._database = database
self._order_by = []

View File

@ -6,10 +6,10 @@ idna==2.9
clickhouse-orm==2.0.1
iso8601==0.1.12
itsdangerous==1.1.0
Jinja2==2.11.2
Jinja2==2.11.3
MarkupSafe==1.1.1
pygal==2.4.0
pytz==2020.1
requests==2.23.0
urllib3==1.25.9
urllib3==1.26.5
Werkzeug==1.0.1

View File

@ -4,7 +4,13 @@ import unittest
from clickhouse_orm.database import Database, DatabaseException, ServerError
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import DateField, DateTimeField, Float32Field, Int32Field, StringField
from clickhouse_orm.fields import (
DateField,
DateTimeField,
Float32Field,
Int32Field,
StringField,
)
from clickhouse_orm.funcs import F
from clickhouse_orm.models import Model
from clickhouse_orm.query import Q
@ -56,9 +62,13 @@ class DatabaseTestCase(TestCaseWithData):
self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)
# Conditions as expression
self.assertEqual(self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22)
self.assertEqual(
self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22
)
# Conditions as Q object
self.assertEqual(self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22)
self.assertEqual(
self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22
)
def test_select(self):
self._insert_and_check(self._sample_data(), len(data))
@ -118,7 +128,9 @@ class DatabaseTestCase(TestCaseWithData):
page_num = 1
instances = set()
while True:
page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
page = self.database.paginate(
Person, "first_name, last_name", page_num, page_size
)
self.assertEqual(page.number_of_objects, len(data))
self.assertGreater(page.pages_total, 0)
[instances.add(obj.to_tsv()) for obj in page.objects]
@ -133,8 +145,12 @@ class DatabaseTestCase(TestCaseWithData):
# Try different page sizes
for page_size in (1, 2, 7, 10, 30, 100, 150):
# Ask for the last page in two different ways and verify equality
page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
page_a = self.database.paginate(
Person, "first_name, last_name", -1, page_size
)
page_b = self.database.paginate(
Person, "first_name, last_name", page_a.pages_total, page_size
)
self.assertEqual(page_a[1:], page_b[1:])
self.assertEqual(
[obj.to_tsv() for obj in page_a.objects],
@ -164,7 +180,9 @@ class DatabaseTestCase(TestCaseWithData):
def test_pagination_with_conditions(self):
self._insert_and_check(self._sample_data(), len(data))
# Conditions as string
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
page = self.database.paginate(
Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'"
)
self.assertEqual(page.number_of_objects, 10)
# Conditions as expression
page = self.database.paginate(
@ -176,11 +194,13 @@ class DatabaseTestCase(TestCaseWithData):
)
self.assertEqual(page.number_of_objects, 10)
# Conditions as Q object
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
page = self.database.paginate(
Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava")
)
self.assertEqual(page.number_of_objects, 10)
def test_special_chars(self):
s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
s = "אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
p = Person(first_name=s)
self.database.insert([p])
p = list(self.database.select("SELECT * from $table", Person))[0]
@ -200,12 +220,13 @@ class DatabaseTestCase(TestCaseWithData):
Database(self.database.db_name, username="default", password="wrong")
exc = cm.exception
print(exc.code, exc.message)
if exc.code == 193: # ClickHouse version < 20.3
self.assertTrue(exc.message.startswith("Wrong password for user default"))
elif exc.code == 516: # ClickHouse version >= 20.3
self.assertTrue(exc.message.startswith("default: Authentication failed"))
else:
raise Exception("Unexpected error code - %s" % exc.code)
raise Exception("Unexpected error code - %s %s" % (exc.code, exc.message))
def test_nonexisting_db(self):
db = Database("db_not_here", autocreate=False)
@ -234,7 +255,9 @@ class DatabaseTestCase(TestCaseWithData):
with self.assertRaises(DatabaseException) as cm:
self.database.create_table(EnginelessModel)
self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")
self.assertEqual(
str(cm.exception), "EnginelessModel class must define an engine"
)
def test_potentially_problematic_field_names(self):
class Model1(Model):
@ -274,6 +297,8 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT DISTINCT type FROM system.columns"
for row in self.database.select(query):
if row.type.startswith("Map"):
continue # Not supported yet
ModelBase.create_ad_hoc_field(row.type)
def test_get_model_for_table(self):
@ -292,7 +317,12 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT name FROM system.tables WHERE database='system'"
for row in self.database.select(query):
print(row.name)
model = self.database.get_model_for_table(row.name, system_table=True)
if row.name in ("distributed_ddl_queue",):
continue # Not supported
try:
model = self.database.get_model_for_table(row.name, system_table=True)
except NotImplementedError:
continue # Table contains an unsupported field type
self.assertTrue(model.is_system_model())
self.assertTrue(model.is_read_only())
self.assertEqual(model.table_name(), row.name)

View File

@ -21,10 +21,10 @@ class DictionaryTestMixin:
logging.info("\t==> %s", result[0].value if result else "<empty>")
return result
def _test_func(self, func, expected_value):
def _test_func(self, func, expected_value, *alternatives):
result = self._call_func(func)
print("Comparing %s to %s" % (result[0].value, expected_value))
assert result[0].value == expected_value
assert result[0].value in (expected_value,) + alternatives
class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
@ -117,10 +117,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def test_dictgethierarchy(self):
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
# Default behaviour changed in CH, but we're not really testing that
default = self._call_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)))
assert isinstance(default, list)
assert len(default) <= 1 # either [] or [99]
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [], [99])
def test_dictisin(self):
self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)

View File

@ -13,7 +13,13 @@ from clickhouse_orm.engines import (
SummingMergeTree,
TinyLog,
)
from clickhouse_orm.fields import DateField, Int8Field, UInt8Field, UInt16Field, UInt32Field
from clickhouse_orm.fields import (
DateField,
Int8Field,
UInt8Field,
UInt16Field,
UInt32Field,
)
from clickhouse_orm.funcs import F
from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
from clickhouse_orm.system_models import SystemPart
@ -30,7 +36,7 @@ class _EnginesHelperTestCase(unittest.TestCase):
class EnginesTestCase(_EnginesHelperTestCase):
def _create_and_insert(self, model_class):
def _create_and_insert(self, model_class, **kwargs):
self.database.create_table(model_class)
self.database.insert(
[
@ -40,6 +46,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
event_group=13,
event_count=7,
event_version=1,
**kwargs,
)
]
)
@ -72,7 +79,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
def test_merge_tree_with_granularity(self):
class TestModel(SampleModel):
engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)
engine = MergeTree(
"date", ("date", "event_id", "event_group"), index_granularity=4096
)
self._create_and_insert(TestModel)
@ -98,11 +107,15 @@ class EnginesTestCase(_EnginesHelperTestCase):
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
)
with self.assertRaises(AssertionError):
MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")
MergeTree(
"date", ("date", "event_id", "event_group"), replica_name="{replica}"
)
def test_collapsing_merge_tree(self):
class TestModel(SampleModel):
engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")
engine = CollapsingMergeTree(
"date", ("date", "event_id", "event_group"), "event_version"
)
self._create_and_insert(TestModel)
@ -114,7 +127,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
def test_replacing_merge_tree(self):
class TestModel(SampleModel):
engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")
engine = ReplacingMergeTree(
"date", ("date", "event_id", "event_group"), "event_uversion"
)
self._create_and_insert(TestModel)
@ -236,16 +251,20 @@ class EnginesTestCase(_EnginesHelperTestCase):
)
self._create_and_insert(TestModel)
self._create_and_insert(TestCollapseModel)
self._create_and_insert(TestCollapseModel, sign=1)
# Result order may be different, lets sort manually
parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
self.assertEqual(2, len(parts))
self.assertEqual("testcollapsemodel", parts[0].table)
self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
self.assertEqual(
"(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", "")
)
self.assertEqual("testmodel", parts[1].table)
self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))
self.assertEqual(
"(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", "")
)
def test_custom_primary_key(self):
if self.database.server_version < (18, 1):
@ -269,13 +288,12 @@ class EnginesTestCase(_EnginesHelperTestCase):
)
self._create_and_insert(TestModel)
self._create_and_insert(TestCollapseModel)
self._create_and_insert(TestCollapseModel, sign=1)
self.assertEqual(2, len(list(SystemPart.get(self.database))))
class SampleModel(Model):
date = DateField()
event_id = UInt32Field()
event_group = UInt32Field()
@ -292,7 +310,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
engine.create_table_sql(self.database)
exc = cm.exception
self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")
self.assertEqual(
str(exc), "Cannot create Distributed engine: specify an underlying table"
)
def test_with_table_name(self):
engine = Distributed("my_cluster", "foo")
@ -317,7 +337,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
exc = cm.exception
self.assertEqual(exc.code, 170)
self.assertTrue(exc.message.startswith("Requested cluster 'cluster_name' not found"))
self.assertTrue(
exc.message.startswith("Requested cluster 'cluster_name' not found")
)
def test_verbose_engine_two_superclasses(self):
class TestModel2(SampleModel):
@ -368,11 +390,16 @@ class DistributedTestCase(_EnginesHelperTestCase):
exc = cm.exception
self.assertEqual(
str(exc),
"When defining Distributed engine without the table_name ensure " "that your model has a parent model",
"When defining Distributed engine without the table_name ensure "
"that your model has a parent model",
)
def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
def _test_insert_select(
self, local_to_distributed, test_model=TestModel, include_readonly=True
):
d_model = self._create_distributed(
"test_shard_localhost", underlying=test_model
)
if local_to_distributed:
to_insert, to_select = test_model, d_model
@ -437,4 +464,6 @@ class DistributedTestCase(_EnginesHelperTestCase):
class TestModel2(self.TestModel):
event_uversion = UInt8Field(readonly=True)
return self._test_insert_select(local_to_distributed=False, test_model=TestModel2, include_readonly=False)
return self._test_insert_select(
local_to_distributed=False, test_model=TestModel2, include_readonly=False
)

View File

@ -333,6 +333,18 @@ class QuerySetTestCase(TestCaseWithData):
"(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')",
)
def test_precedence_of_negation(self):
    """Ensure unary ~ binds tighter than & and | when combining Q objects.

    A negated Q must render as NOT(...) wrapped in its own parentheses,
    regardless of which side of the conjunction/disjunction it sits on.
    """
    negated_first = ~Q(first_name='a')
    by_last = Q(last_name='b')
    # Negation on the left of AND
    self.assertEqual(
        (negated_first & by_last).to_sql(Person),
        "(NOT (first_name = 'a')) AND (last_name = 'b')",
    )
    # Negation on the right of AND
    self.assertEqual(
        (by_last & negated_first).to_sql(Person),
        "(last_name = 'b') AND (NOT (first_name = 'a'))",
    )
    # Negation on the right of OR
    self.assertEqual(
        (by_last | negated_first).to_sql(Person),
        "(last_name = 'b') OR (NOT (first_name = 'a'))",
    )
    # Both operands negated
    self.assertEqual(
        (~by_last & negated_first).to_sql(Person),
        "(NOT (last_name = 'b')) AND (NOT (first_name = 'a'))",
    )
def test_invalid_filter(self):
qs = Person.objects_in(self.database)
with self.assertRaises(TypeError):