Mirror of https://github.com/Infinidat/infi.clickhouse_orm.git, synced 2025-08-02 03:00:09 +03:00
Merge remote-tracking branch 'upstream/develop' into merge-upstream
This commit is contained in: commit ee01e82372

CHANGELOG.md: 14 changed lines
@@ -1,6 +1,18 @@
 Change Log
 ==========
 
+v2.3.0
+------
+Merges upstream changes:
+
+- Fix pagination for models with alias columns
+- Add `QuerySet.model` to support django-rest-framework 3
+- Improve support of ClickHouse v21.9 (mangototango)
+- Ignore non-numeric parts in ClickHouse version (mangototango)
+- Fix precedence of ~ operator in Q objects (mangototango)
+- Support for adding a column to the beginning of a table (meanmail)
+- Add stddevPop and stddevSamp functions (k.peskov)
+
 v2.2.2
 ------
 - Unpined requirements to enhance compatability
@@ -216,5 +228,3 @@ v0.7.0
 v0.6.3
 ------
 - Python 3 support
-
-
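To see the negation-precedence fix from the list above in user code, here is a minimal sketch (not part of the commit; Person is a stand-in model, and no database connection is needed because Q.to_sql only renders SQL):

from clickhouse_orm.fields import StringField
from clickhouse_orm.models import Model
from clickhouse_orm.query import Q

class Person(Model):
    first_name = StringField()
    last_name = StringField()

# ~ used to leak into the surrounding expression; it now stays attached
# to its own Q node when combined with & or |.
cond = ~Q(first_name="a") & Q(last_name="b")
print(cond.to_sql(Person))  # (NOT (first_name = 'a')) AND (last_name = 'b')

The expected output matches the test_precedence_of_negation test added at the bottom of this diff.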

@@ -55,6 +55,14 @@ class ServerError(DatabaseException):
             """,
             re.VERBOSE | re.DOTALL,
         ),
+        # ClickHouse v21+
+        re.compile(
+            r"""
+            Code:\ (?P<code>\d+).
+            \ (?P<type1>[^ \n]+):\ (?P<msg>.+)
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
     )
 
     @classmethod
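A standalone sketch of what the new v21+ pattern extracts; the sample error text is made up, but it follows the "Code: NNN. Type: message" shape that newer ClickHouse servers return:

import re

# Same pattern as the one added to ServerError above.
V21_PATTERN = re.compile(
    r"""
    Code:\ (?P<code>\d+).
    \ (?P<type1>[^ \n]+):\ (?P<msg>.+)
    """,
    re.VERBOSE | re.DOTALL,
)

m = V21_PATTERN.match("Code: 516. DB::Exception: default: Authentication failed")
print(m.group("code"))  # 516
print(m.group("msg"))   # default: Authentication failed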

@@ -124,14 +132,20 @@ class Database(object):
         self.db_exists = self._is_existing_database()
         if readonly:
             if not self.db_exists:
-                raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
+                raise DatabaseException(
+                    "Database does not exist, and cannot be created under readonly connection"
+                )
             self.connection_readonly = self._is_connection_readonly()
             self.readonly = True
         elif autocreate and not self.db_exists:
             self.create_database()
         self.server_version = self._get_server_version()
         # Versions 1.1.53981 and below don't have timezone function
-        self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc
+        self.server_timezone = (
+            self._get_server_timezone()
+            if self.server_version > (1, 1, 53981)
+            else pytz.utc
+        )
         # Versions 19.1.16 and above support codec compression
         self.has_codec_support = self.server_version >= (19, 1, 16)
         # Version 19.0 and above support LowCardinality
@@ -158,7 +172,9 @@ class Database(object):
         if model_class.is_system_model():
             raise DatabaseException("You can't create system table")
         if model_class.engine is None:
-            raise DatabaseException("%s class must define an engine" % model_class.__name__)
+            raise DatabaseException(
+                "%s class must define an engine" % model_class.__name__
+            )
         self._send(model_class.create_table_sql(self))
 
     def drop_table(self, model_class):
@@ -229,7 +245,9 @@ class Database(object):
         if first_instance.is_read_only() or first_instance.is_system_model():
             raise DatabaseException("You can't insert into read only and system tables")
 
-        fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
+        fields_list = ",".join(
+            ["`%s`" % name for name in first_instance.fields(writable=True)]
+        )
         fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
         query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
 
@@ -289,11 +307,15 @@ class Database(object):
         lines = r.iter_lines()
         field_names = parse_tsv(next(lines))
         field_types = parse_tsv(next(lines))
-        model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
+        model_class = model_class or ModelBase.create_ad_hoc_model(
+            zip(field_names, field_types)
+        )
         for line in lines:
             # skip blank line left by WITH TOTALS modifier
             if line:
-                yield model_class.from_tsv(line, field_names, self.server_timezone, self)
+                yield model_class.from_tsv(
+                    line, field_names, self.server_timezone, self
+                )
 
     def raw(self, query, settings=None, stream=False):
         """
@@ -306,7 +328,15 @@ class Database(object):
         query = self._substitute(query, None)
         return self._send(query, settings=settings, stream=stream).text
 
-    def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
+    def paginate(
+        self,
+        model_class,
+        order_by,
+        page_num=1,
+        page_size=100,
+        conditions=None,
+        settings=None,
+    ):
         """
         Selects records and returns a single page of model instances.
 
@@ -330,7 +360,8 @@ class Database(object):
         elif page_num < 1:
             raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
-        query = "SELECT * FROM $table"
+        query = "SELECT {} FROM $table".format(", ".join(model_class.fields().keys()))
+
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
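The SELECT change above is the pagination fix for alias columns: in ClickHouse, SELECT * does not return ALIAS or MATERIALIZED columns, so pages built from it were missing those fields. Listing the model's declared field names explicitly includes them. A hedged sketch with a made-up model:

from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import StringField, UInt32Field
from clickhouse_orm.models import Model

class HypotheticalEvent(Model):
    id = UInt32Field()
    name = StringField()
    upper_name = StringField(alias="upper(name)")  # invisible to SELECT *
    engine = Memory()

# Mirrors the new paginate() logic: every declared field is listed explicitly.
print("SELECT {} FROM $table".format(", ".join(HypotheticalEvent.fields().keys())))
# SELECT id, name, upper_name FROM $table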

@@ -367,7 +398,9 @@ class Database(object):
         self.insert(
             [
                 MigrationHistory(
-                    package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
+                    package_name=migrations_package_name,
+                    module_name=name,
+                    applied=datetime.date.today(),
                 )
             ]
         )
@@ -378,7 +411,10 @@ class Database(object):
         from .migrations import MigrationHistory
 
         self.create_table(MigrationHistory)
-        query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
+        query = (
+            "SELECT module_name from $table WHERE package_name = '%s'"
+            % migrations_package_name
+        )
         query = self._substitute(query, MigrationHistory)
         return set(obj.module_name for obj in self.select(query))
 
@@ -388,7 +424,9 @@ class Database(object):
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
-        r = self.request_session.post(self.db_url, params=params, data=data, stream=stream, timeout=self.timeout)
+        r = self.request_session.post(
+            self.db_url, params=params, data=data, stream=stream, timeout=self.timeout
+        )
         if r.status_code != 200:
             raise ServerError(r.text)
         return r
@@ -413,7 +451,10 @@ class Database(object):
         if model_class.is_system_model():
             mapping["table"] = "`system`.`%s`" % model_class.table_name()
         else:
-            mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
+            mapping["table"] = "`%s`.`%s`" % (
+                self.db_name,
+                model_class.table_name(),
+            )
         query = Template(query).safe_substitute(mapping)
         return query
 
@@ -432,10 +473,12 @@ class Database(object):
         except ServerError as e:
             logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
             ver = "1.1.0"
-        return tuple(int(n) for n in ver.split(".")) if as_tuple else ver
+        return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
 
     def _is_existing_database(self):
-        r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
+        r = self._send(
+            "SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
+        )
         return r.text.strip() == "1"
 
     def _is_connection_readonly(self):
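The isdigit() filter above is the "Ignore non-numeric parts in ClickHouse version" change: version strings with non-numeric components no longer crash the tuple conversion. A standalone sketch (the version strings are illustrative):

# Non-numeric components are dropped before int() runs.
for ver in ("21.9.4.35", "21.9.4.35-lts", "1.1.54380"):
    print(tuple(int(n) for n in ver.split(".") if n.isdigit()))
# (21, 9, 4, 35)
# (21, 9, 4)
# (1, 1, 54380)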

@@ -26,18 +26,30 @@ class Field(FunctionOperatorsMixin):
     class_default = 0  # should be overridden by concrete subclasses
     db_type = None  # should be overridden by concrete subclasses
 
-    def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
-        assert [default, alias, materialized].count(
-            None
-        ) >= 2, "Only one of default, alias and materialized parameters can be given"
+    def __init__(
+        self, default=None, alias=None, materialized=None, readonly=None, codec=None
+    ):
+        assert [default, alias, materialized].count(None) >= 2, (
+            "Only one of default, alias and materialized parameters can be given"
+        )
         assert (
-            alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
+            alias is None
+            or isinstance(alias, F)
+            or isinstance(alias, str)
+            and alias != ""
         ), "Alias parameter must be a string or function object, if given"
         assert (
-            materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
+            materialized is None
+            or isinstance(materialized, F)
+            or isinstance(materialized, str)
+            and materialized != ""
         ), "Materialized parameter must be a string or function object, if given"
-        assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
-        assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
+        assert readonly is None or type(readonly) is bool, (
+            "readonly parameter must be bool if given"
+        )
+        assert codec is None or isinstance(codec, str) and codec != "", (
+            "Codec field must be string, if given"
+        )
         if alias:
             assert codec is None, "Codec cannot be used for alias fields"
 
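The reformatted assertions keep the long-standing rule that default, alias and materialized are mutually exclusive. A quick hedged demonstration against the package itself:

from clickhouse_orm.fields import StringField

StringField(default="x")  # fine: only one of the three parameters is set
try:
    StringField(default="x", alias="other_column")
except AssertionError as e:
    print(e)  # Only one of default, alias and materialized parameters can be given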

@@ -76,7 +88,8 @@ class Field(FunctionOperatorsMixin):
         """
         if value < min_value or value > max_value:
             raise ValueError(
-                "%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
+                "%s out of range - %s is not between %s and %s"
+                % (self.__class__.__name__, value, min_value, max_value)
             )
 
     def to_db_string(self, value, quote=True):
@@ -117,7 +130,7 @@ class Field(FunctionOperatorsMixin):
         elif self.default:
             default = self.to_db_string(self.default)
             sql += " DEFAULT %s" % default
-        if self.codec and db and db.has_codec_support:
+        if self.codec and db and db.has_codec_support and not self.alias:
             sql += " CODEC(%s)" % self.codec
         return sql
 
@@ -141,7 +154,6 @@ class Field(FunctionOperatorsMixin):
 
 
 class StringField(Field):
-
     class_default = ""
     db_type = "String"
 
@@ -154,7 +166,9 @@ class StringField(Field):
 
 
 class FixedStringField(StringField):
-    def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
+    def __init__(
+        self, length, default=None, alias=None, materialized=None, readonly=None
+    ):
         self._length = length
         self.db_type = "FixedString(%d)" % length
         super(FixedStringField, self).__init__(default, alias, materialized, readonly)
@@ -167,11 +181,13 @@ class FixedStringField(StringField):
         if isinstance(value, str):
             value = value.encode("utf-8")
         if len(value) > self._length:
-            raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
+            raise ValueError(
+                "Value of %d bytes is too long for FixedStringField(%d)"
+                % (len(value), self._length)
+            )
 
 
 class DateField(Field):
-
     min_value = datetime.date(1970, 1, 1)
     max_value = datetime.date(2105, 12, 31)
     class_default = min_value
@@ -198,15 +214,26 @@ class DateField(Field):
 
 
 class DateTimeField(Field):
-
     class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
     db_type = "DateTime"
 
-    def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
+    def __init__(
+        self,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+        codec=None,
+        timezone=None,
+    ):
         super().__init__(default, alias, materialized, readonly, codec)
         # assert not timezone, 'Temporarily field timezone is not supported'
         if timezone:
-            timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone)
+            timezone = (
+                timezone
+                if isinstance(timezone, BaseTzInfo)
+                else pytz.timezone(timezone)
+            )
         self.timezone = timezone
 
     def get_db_type_args(self):
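As the reflowed constructor shows, DateTimeField accepts the timezone either as a pytz tzinfo object or as a zone name; strings are normalized with pytz.timezone. A hedged sketch:

import pytz
from clickhouse_orm.fields import DateTimeField

f1 = DateTimeField(timezone="Europe/Madrid")
f2 = DateTimeField(timezone=pytz.timezone("Europe/Madrid"))
print(str(f1.timezone))            # Europe/Madrid (stored as a BaseTzInfo instance)
print(f1.timezone == f2.timezone)  # True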
@@ -219,7 +246,9 @@ class DateTimeField(Field):
         if isinstance(value, datetime.datetime):
             return value if value.tzinfo else value.replace(tzinfo=pytz.utc)
         if isinstance(value, datetime.date):
-            return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
+            return datetime.datetime(
+                value.year, value.month, value.day, tzinfo=pytz.utc
+            )
         if isinstance(value, int):
             return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
         if isinstance(value, str):
@@ -228,7 +257,9 @@ class DateTimeField(Field):
         if len(value) == 10:
             try:
                 value = int(value)
-                return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
+                return datetime.datetime.utcfromtimestamp(value).replace(
+                    tzinfo=pytz.utc
+                )
             except ValueError:
                 pass
         try:
@@ -251,10 +282,19 @@ class DateTime64Field(DateTimeField):
     db_type = "DateTime64"
 
     def __init__(
-        self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
+        self,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+        codec=None,
+        timezone=None,
+        precision=6,
     ):
         super().__init__(default, alias, materialized, readonly, codec, timezone)
-        assert precision is None or isinstance(precision, int), "Precision must be int type"
+        assert precision is None or isinstance(precision, int), (
+            "Precision must be int type"
+        )
         self.precision = precision
 
     def get_db_type_args(self):
@@ -271,7 +311,9 @@ class DateTime64Field(DateTimeField):
         """
         return escape(
             "{timestamp:0{width}.{precision}f}".format(
-                timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
+                timestamp=value.timestamp(),
+                width=11 + self.precision,
+                precision=self.precision,
             ),
             quote,
         )
@@ -281,7 +323,9 @@ class DateTime64Field(DateTimeField):
             return super().to_python(value, timezone_in_use)
         except ValueError:
             if isinstance(value, (int, float)):
-                return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
+                return datetime.datetime.utcfromtimestamp(value).replace(
+                    tzinfo=pytz.utc
+                )
             if isinstance(value, str):
                 left_part = value.split(".")[0]
                 if left_part == "0000-00-00 00:00:00":
@@ -289,7 +333,9 @@ class DateTime64Field(DateTimeField):
                 if len(left_part) == 10:
                     try:
                         value = float(value)
-                        return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
+                        return datetime.datetime.utcfromtimestamp(value).replace(
+                            tzinfo=pytz.utc
+                        )
                     except ValueError:
                         pass
             raise
@@ -304,7 +350,9 @@ class BaseIntField(Field):
         try:
             return int(value)
         except Exception:
-            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
+            raise ValueError(
+                "Invalid value for %s - %r" % (self.__class__.__name__, value)
+            )
 
     def to_db_string(self, value, quote=True):
         # There's no need to call escape since numbers do not contain
@@ -316,58 +364,50 @@ class BaseIntField(Field):
 
 
 class UInt8Field(BaseIntField):
-
     min_value = 0
-    max_value = 2 ** 8 - 1
+    max_value = 2**8 - 1
     db_type = "UInt8"
 
 
 class UInt16Field(BaseIntField):
-
     min_value = 0
-    max_value = 2 ** 16 - 1
+    max_value = 2**16 - 1
     db_type = "UInt16"
 
 
 class UInt32Field(BaseIntField):
-
     min_value = 0
-    max_value = 2 ** 32 - 1
+    max_value = 2**32 - 1
     db_type = "UInt32"
 
 
 class UInt64Field(BaseIntField):
-
     min_value = 0
-    max_value = 2 ** 64 - 1
+    max_value = 2**64 - 1
     db_type = "UInt64"
 
 
 class Int8Field(BaseIntField):
-
-    min_value = -(2 ** 7)
-    max_value = 2 ** 7 - 1
+    min_value = -(2**7)
+    max_value = 2**7 - 1
     db_type = "Int8"
 
 
 class Int16Field(BaseIntField):
-
-    min_value = -(2 ** 15)
-    max_value = 2 ** 15 - 1
+    min_value = -(2**15)
+    max_value = 2**15 - 1
     db_type = "Int16"
 
 
 class Int32Field(BaseIntField):
-
-    min_value = -(2 ** 31)
-    max_value = 2 ** 31 - 1
+    min_value = -(2**31)
+    max_value = 2**31 - 1
    db_type = "Int32"
 
 
 class Int64Field(BaseIntField):
-
-    min_value = -(2 ** 63)
-    max_value = 2 ** 63 - 1
+    min_value = -(2**63)
+    max_value = 2**63 - 1
     db_type = "Int64"
 
 
@@ -380,7 +420,9 @@ class BaseFloatField(Field):
         try:
             return float(value)
         except Exception:
-            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
+            raise ValueError(
+                "Invalid value for %s - %r" % (self.__class__.__name__, value)
+            )
 
     def to_db_string(self, value, quote=True):
         # There's no need to call escape since numbers do not contain
@@ -389,12 +431,10 @@ class BaseFloatField(Field):
 
 
 class Float32Field(BaseFloatField):
-
     db_type = "Float32"
 
 
 class Float64Field(BaseFloatField):
-
     db_type = "Float64"
 
 
@@ -403,9 +443,19 @@ class DecimalField(Field):
     Base class for all decimal fields. Can also be used directly.
     """
 
-    def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
+    def __init__(
+        self,
+        precision,
+        scale,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+    ):
         assert 1 <= precision <= 38, "Precision must be between 1 and 38"
-        assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
+        assert 0 <= scale <= precision, (
+            "Scale must be between 0 and the given precision"
+        )
         self.precision = precision
         self.scale = scale
         self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
@@ -421,9 +471,13 @@ class DecimalField(Field):
         try:
             value = Decimal(value)
         except Exception:
-            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
+            raise ValueError(
+                "Invalid value for %s - %r" % (self.__class__.__name__, value)
+            )
         if not value.is_finite():
-            raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
+            raise ValueError(
+                "Non-finite value for %s - %r" % (self.__class__.__name__, value)
+            )
         return self._round(value)
 
     def to_db_string(self, value, quote=True):
@@ -439,20 +493,32 @@ class DecimalField(Field):
 
 
 class Decimal32Field(DecimalField):
-    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
-        super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
+    def __init__(
+        self, scale, default=None, alias=None, materialized=None, readonly=None
+    ):
+        super(Decimal32Field, self).__init__(
+            9, scale, default, alias, materialized, readonly
+        )
         self.db_type = "Decimal32(%d)" % scale
 
 
 class Decimal64Field(DecimalField):
-    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
-        super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
+    def __init__(
+        self, scale, default=None, alias=None, materialized=None, readonly=None
+    ):
+        super(Decimal64Field, self).__init__(
+            18, scale, default, alias, materialized, readonly
+        )
         self.db_type = "Decimal64(%d)" % scale
 
 
 class Decimal128Field(DecimalField):
-    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
-        super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
+    def __init__(
+        self, scale, default=None, alias=None, materialized=None, readonly=None
+    ):
+        super(Decimal128Field, self).__init__(
+            38, scale, default, alias, materialized, readonly
+        )
         self.db_type = "Decimal128(%d)" % scale
 
 
@@ -461,11 +527,21 @@ class BaseEnumField(Field):
     Abstract base class for all enum-type fields.
     """
 
-    def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
+    def __init__(
+        self,
+        enum_cls,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+        codec=None,
+    ):
         self.enum_cls = enum_cls
         if default is None:
             default = list(enum_cls)[0]
-        super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec)
+        super(BaseEnumField, self).__init__(
+            default, alias, materialized, readonly, codec
+        )
 
     def to_python(self, value, timezone_in_use):
         if isinstance(value, self.enum_cls):
@@ -512,22 +588,31 @@ class BaseEnumField(Field):
 
 
 class Enum8Field(BaseEnumField):
-
     db_type = "Enum8"
 
 
 class Enum16Field(BaseEnumField):
-
     db_type = "Enum16"
 
 
 class ArrayField(Field):
-
     class_default = []
 
-    def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
-        assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance"
-        assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM"
+    def __init__(
+        self,
+        inner_field,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+        codec=None,
+    ):
+        assert isinstance(inner_field, Field), (
+            "The first argument of ArrayField must be a Field instance"
+        )
+        assert not isinstance(inner_field, ArrayField), (
+            "Multidimensional array fields are not supported by the ORM"
+        )
         self.inner_field = inner_field
         super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
 
@@ -549,14 +634,15 @@ class ArrayField(Field):
         return "[" + comma_join(array) + "]"
 
     def get_sql(self, with_default_expression=True, db=None):
-        sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
+        sql = "Array(%s)" % self.inner_field.get_sql(
+            with_default_expression=False, db=db
+        )
         if with_default_expression and self.codec and db and db.has_codec_support:
             sql += " CODEC(%s)" % self.codec
         return sql
 
 
 class UUIDField(Field):
-
     class_default = UUID(int=0)
     db_type = "UUID"
 
@@ -579,7 +665,6 @@ class UUIDField(Field):
 
 
 class IPv4Field(Field):
-
     class_default = 0
     db_type = "IPv4"
 
@@ -596,7 +681,6 @@ class IPv4Field(Field):
 
 
 class IPv6Field(Field):
-
     class_default = 0
     db_type = "IPv6"
 
@@ -613,18 +697,29 @@ class IPv6Field(Field):
 
 
 class NullableField(Field):
-
     class_default = None
 
-    def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
-        assert isinstance(
-            inner_field, Field
-        ), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
+    def __init__(
+        self,
+        inner_field,
+        default=None,
+        alias=None,
+        materialized=None,
+        extra_null_values=None,
+        codec=None,
+    ):
+        assert isinstance(inner_field, Field), (
+            "The first argument of NullableField must be a Field instance. Not: {}".format(
+                inner_field
+            )
+        )
         self.inner_field = inner_field
         self._null_values = [None]
         if extra_null_values:
             self._null_values.extend(extra_null_values)
-        super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
+        super(NullableField, self).__init__(
+            default, alias, materialized, readonly=None, codec=codec
+        )
 
     def to_python(self, value, timezone_in_use):
         if value == "\\N" or value in self._null_values:
@@ -640,26 +735,40 @@ class NullableField(Field):
         return self.inner_field.to_db_string(value, quote=quote)
 
     def get_sql(self, with_default_expression=True, db=None):
-        sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
+        sql = "Nullable(%s)" % self.inner_field.get_sql(
+            with_default_expression=False, db=db
+        )
         if with_default_expression:
             sql += self._extra_params(db)
         return sql
 
 
 class LowCardinalityField(Field):
-    def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
-        assert isinstance(
-            inner_field, Field
-        ), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
-        assert not isinstance(
-            inner_field, LowCardinalityField
-        ), "LowCardinality inner fields are not supported by the ORM"
-        assert not isinstance(
-            inner_field, ArrayField
-        ), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
+    def __init__(
+        self,
+        inner_field,
+        default=None,
+        alias=None,
+        materialized=None,
+        readonly=None,
+        codec=None,
+    ):
+        assert isinstance(inner_field, Field), (
+            "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(
+                inner_field
+            )
+        )
+        assert not isinstance(inner_field, LowCardinalityField), (
+            "LowCardinality inner fields are not supported by the ORM"
+        )
+        assert not isinstance(inner_field, ArrayField), (
+            "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
+        )
        self.inner_field = inner_field
         self.class_default = self.inner_field.class_default
-        super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
+        super(LowCardinalityField, self).__init__(
+            default, alias, materialized, readonly, codec
+        )
 
     def to_python(self, value, timezone_in_use):
         return self.inner_field.to_python(value, timezone_in_use)
@@ -672,7 +781,9 @@ class LowCardinalityField(Field):
 
     def get_sql(self, with_default_expression=True, db=None):
         if db and db.has_low_cardinality_support:
-            sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
+            sql = "LowCardinality(%s)" % self.inner_field.get_sql(
+                with_default_expression=False
+            )
         else:
             sql = self.inner_field.get_sql(with_default_expression=False)
             logger.warning(

@@ -186,7 +186,6 @@ class FunctionOperatorsMixin(object):
 
 
 class FMeta(type):
-
     FUNCTION_COMBINATORS = {
         "type_conversion": [
             {"suffix": "OrZero"},
@@ -230,12 +229,23 @@ class FMeta(type):
         args = comma_join(extra_args)
         new_sig = comma_join(extra_args)
         # Get default values for args
-        argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty)
+        argdefs = tuple(
+            p.default for p in sig.parameters.values() if p.default != Parameter.empty
+        )
         # Build the new function
         new_code = compile(
-            'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(**locals()), __file__, "exec"
+            'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(
+                **locals()
+            ),
+            __file__,
+            "exec",
         )
-        new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs)
+        new_func = FunctionType(
+            code=new_code.co_consts[0],
+            globals=globals(),
+            name=new_name,
+            argdefs=argdefs,
+        )
         # If base_func was parametric, new_func should be too
         if getattr(base_func, "f_parametric", False):
             new_func = parametric(new_func)
@@ -409,7 +419,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
 
     @staticmethod
     def toQuarter(d, timezone=NO_VALUE):
-        return F("toQuarter", d, timezone)
+        return F("toQuarter", d, timezone) if timezone else F("toQuarter", d)
 
     @staticmethod
     def toMonth(d, timezone=NO_VALUE):
@@ -421,7 +431,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
 
     @staticmethod
     def toISOWeek(d, timezone=NO_VALUE):
-        return F("toISOWeek", d, timezone)
+        return F("toISOWeek", d, timezone) if timezone else F("toISOWeek", d)
 
     @staticmethod
     def toDayOfYear(d, timezone=NO_VALUE):
@@ -509,15 +519,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
 
     @staticmethod
     def toYYYYMM(dt, timezone=NO_VALUE):
-        return F("toYYYYMM", dt, timezone)
+        return F("toYYYYMM", dt, timezone) if timezone else F("toYYYYMM", dt)
 
     @staticmethod
     def toYYYYMMDD(dt, timezone=NO_VALUE):
-        return F("toYYYYMMDD", dt, timezone)
+        return F("toYYYYMMDD", dt, timezone) if timezone else F("toYYYYMMDD", dt)
 
     @staticmethod
     def toYYYYMMDDhhmmss(dt, timezone=NO_VALUE):
-        return F("toYYYYMMDDhhmmss", dt, timezone)
+        return (
+            F("toYYYYMMDDhhmmss", dt, timezone)
+            if timezone
+            else F("toYYYYMMDDhhmmss", dt)
+        )
 
     @staticmethod
     def toRelativeYearNum(d, timezone=NO_VALUE):
@@ -1195,11 +1209,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
 
     @staticmethod
     def arrayResize(array, size, extender=None):
-        return F("arrayResize", array, size, extender) if extender is not None else F("arrayResize", array, size)
+        return (
+            F("arrayResize", array, size, extender)
+            if extender is not None
+            else F("arrayResize", array, size)
+        )
 
     @staticmethod
     def arraySlice(array, offset, length=None):
-        return F("arraySlice", array, offset, length) if length is not None else F("arraySlice", array, offset)
+        return (
+            F("arraySlice", array, offset, length)
+            if length is not None
+            else F("arraySlice", array, offset)
+        )
 
     @staticmethod
     def arrayUniq(*args):
@@ -1649,6 +1671,16 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
     def varSamp(x):
         return F("varSamp", x)
 
+    @staticmethod
+    @aggregate
+    def stddevPop(expr):
+        return F("stddevPop", expr)
+
+    @staticmethod
+    @aggregate
+    def stddevSamp(expr):
+        return F("stddevSamp", expr)
+
     @staticmethod
     @aggregate
     @parametric
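The new aggregate wrappers are ordinary F factories, so they compose like any other expression. A hedged sketch, assuming F.to_sql() and F.now() behave as they do elsewhere in this codebase:

from clickhouse_orm.funcs import F

print(F.stddevPop(F("plus", 1, 2)).to_sql())   # stddevPop(plus(1, 2))
print(F.stddevSamp(F("plus", 1, 2)).to_sql())  # stddevSamp(plus(1, 2))

# And the timezone fix above: calling a conversion helper without a timezone
# should now omit the argument entirely instead of rendering a stray value.
print(F.toYYYYMM(F.now()).to_sql())            # toYYYYMM(now())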

@@ -84,10 +84,12 @@ class AlterTable(ModelOperation):
             is_regular_field = not (field.materialized or field.alias)
             if name not in table_fields:
                 logger.info("    Add column %s", name)
-                assert prev_name, "Cannot add a column to the beginning of the table"
                 cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
                 if is_regular_field:
-                    cmd += " AFTER %s" % prev_name
+                    if prev_name:
+                        cmd += " AFTER %s" % prev_name
+                    else:
+                        cmd += " FIRST"
                 self._alter_table(database, cmd)
 
             if is_regular_field:
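With the assertion gone, a column that has no predecessor is now added to the front of the table with FIRST. A runnable sketch of just the command-building branch, using hypothetical values:

# Mirrors the branch in AlterTable.apply(): no previous column means FIRST.
name, sql_type, prev_name = "new_col", "String", None
cmd = "ADD COLUMN %s %s" % (name, sql_type)
if prev_name:
    cmd += " AFTER %s" % prev_name
else:
    cmd += " FIRST"
print(cmd)  # ADD COLUMN new_col String FIRST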
@@ -105,13 +107,21 @@ class AlterTable(ModelOperation):
         }
         for field_name, field_sql in self._get_table_fields(database):
             # All fields must have been created and dropped by this moment
-            assert field_name in model_fields, "Model fields and table columns in disagreement"
+            assert field_name in model_fields, (
+                "Model fields and table columns in disagreement"
+            )
 
             if field_sql != model_fields[field_name]:
                 logger.info(
-                    "    Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
+                    "    Change type of column %s from %s to %s",
+                    field_name,
+                    field_sql,
+                    model_fields[field_name],
                 )
-                self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))
+                self._alter_table(
+                    database,
+                    "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]),
+                )
 
 
 class AlterTableWithBuffer(ModelOperation):

@@ -217,7 +217,6 @@ class Q(object):
         if mode == l_child._mode and not l_child._negate:
             q = deepcopy(l_child)
             q._children.append(deepcopy(r_child))
-
         else:
             q = cls()
             q._children = [l_child, r_child]
@@ -300,6 +299,7 @@ class QuerySet(object):
         Initializer. It is possible to create a queryset like this, but the standard
         way is to use `MyModel.objects_in(database)`.
         """
+        self.model = model_cls
         self._model_cls = model_cls
         self._database = database
         self._order_by = []
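The new model attribute is what django-rest-framework introspects on querysets; it simply aliases the model class next to the pre-existing private attribute. A hedged sketch (passing None for the database, since nothing is queried):

from clickhouse_orm.fields import StringField
from clickhouse_orm.models import Model
from clickhouse_orm.query import QuerySet

class Thing(Model):
    name = StringField()

qs = QuerySet(Thing, None)
print(qs.model is Thing)       # True - the attribute DRF looks for
print(qs._model_cls is Thing)  # True - unchanged internal attribute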

@@ -6,10 +6,10 @@ idna==2.9
 clickhouse-orm==2.0.1
 iso8601==0.1.12
 itsdangerous==1.1.0
-Jinja2==2.11.2
+Jinja2==2.11.3
 MarkupSafe==1.1.1
 pygal==2.4.0
 pytz==2020.1
 requests==2.23.0
-urllib3==1.25.9
+urllib3==1.26.5
 Werkzeug==1.0.1

@@ -4,7 +4,13 @@ import unittest
 
 from clickhouse_orm.database import Database, DatabaseException, ServerError
 from clickhouse_orm.engines import Memory
-from clickhouse_orm.fields import DateField, DateTimeField, Float32Field, Int32Field, StringField
+from clickhouse_orm.fields import (
+    DateField,
+    DateTimeField,
+    Float32Field,
+    Int32Field,
+    StringField,
+)
 from clickhouse_orm.funcs import F
 from clickhouse_orm.models import Model
 from clickhouse_orm.query import Q
@@ -56,9 +62,13 @@ class DatabaseTestCase(TestCaseWithData):
         self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
         self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)
         # Conditions as expression
-        self.assertEqual(self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22)
+        self.assertEqual(
+            self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22
+        )
         # Conditions as Q object
-        self.assertEqual(self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22)
+        self.assertEqual(
+            self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22
+        )
 
     def test_select(self):
         self._insert_and_check(self._sample_data(), len(data))
@@ -118,7 +128,9 @@ class DatabaseTestCase(TestCaseWithData):
         page_num = 1
         instances = set()
         while True:
-            page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
+            page = self.database.paginate(
+                Person, "first_name, last_name", page_num, page_size
+            )
             self.assertEqual(page.number_of_objects, len(data))
             self.assertGreater(page.pages_total, 0)
             [instances.add(obj.to_tsv()) for obj in page.objects]
@@ -133,8 +145,12 @@ class DatabaseTestCase(TestCaseWithData):
         # Try different page sizes
         for page_size in (1, 2, 7, 10, 30, 100, 150):
             # Ask for the last page in two different ways and verify equality
-            page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
-            page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
+            page_a = self.database.paginate(
+                Person, "first_name, last_name", -1, page_size
+            )
+            page_b = self.database.paginate(
+                Person, "first_name, last_name", page_a.pages_total, page_size
+            )
             self.assertEqual(page_a[1:], page_b[1:])
             self.assertEqual(
                 [obj.to_tsv() for obj in page_a.objects],
@@ -164,7 +180,9 @@ class DatabaseTestCase(TestCaseWithData):
     def test_pagination_with_conditions(self):
         self._insert_and_check(self._sample_data(), len(data))
         # Conditions as string
-        page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
+        page = self.database.paginate(
+            Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'"
+        )
         self.assertEqual(page.number_of_objects, 10)
         # Conditions as expression
         page = self.database.paginate(
@@ -176,11 +194,13 @@ class DatabaseTestCase(TestCaseWithData):
         )
         self.assertEqual(page.number_of_objects, 10)
         # Conditions as Q object
-        page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
+        page = self.database.paginate(
+            Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava")
+        )
         self.assertEqual(page.number_of_objects, 10)
 
     def test_special_chars(self):
-        s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
+        s = "אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
         p = Person(first_name=s)
         self.database.insert([p])
         p = list(self.database.select("SELECT * from $table", Person))[0]
@@ -200,12 +220,13 @@ class DatabaseTestCase(TestCaseWithData):
             Database(self.database.db_name, username="default", password="wrong")
 
         exc = cm.exception
+        print(exc.code, exc.message)
         if exc.code == 193:  # ClickHouse version < 20.3
             self.assertTrue(exc.message.startswith("Wrong password for user default"))
         elif exc.code == 516:  # ClickHouse version >= 20.3
             self.assertTrue(exc.message.startswith("default: Authentication failed"))
         else:
-            raise Exception("Unexpected error code - %s" % exc.code)
+            raise Exception("Unexpected error code - %s %s" % (exc.code, exc.message))
 
     def test_nonexisting_db(self):
         db = Database("db_not_here", autocreate=False)
@@ -234,7 +255,9 @@ class DatabaseTestCase(TestCaseWithData):
 
         with self.assertRaises(DatabaseException) as cm:
             self.database.create_table(EnginelessModel)
-        self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")
+        self.assertEqual(
+            str(cm.exception), "EnginelessModel class must define an engine"
+        )
 
     def test_potentially_problematic_field_names(self):
         class Model1(Model):
@@ -274,6 +297,8 @@ class DatabaseTestCase(TestCaseWithData):
 
         query = "SELECT DISTINCT type FROM system.columns"
         for row in self.database.select(query):
+            if row.type.startswith("Map"):
+                continue  # Not supported yet
             ModelBase.create_ad_hoc_field(row.type)
 
     def test_get_model_for_table(self):
@@ -292,7 +317,12 @@ class DatabaseTestCase(TestCaseWithData):
         query = "SELECT name FROM system.tables WHERE database='system'"
         for row in self.database.select(query):
             print(row.name)
-            model = self.database.get_model_for_table(row.name, system_table=True)
+            if row.name in ("distributed_ddl_queue",):
+                continue  # Not supported
+            try:
+                model = self.database.get_model_for_table(row.name, system_table=True)
+            except NotImplementedError:
+                continue  # Table contains an unsupported field type
             self.assertTrue(model.is_system_model())
             self.assertTrue(model.is_read_only())
             self.assertEqual(model.table_name(), row.name)

@@ -21,10 +21,10 @@ class DictionaryTestMixin:
         logging.info("\t==> %s", result[0].value if result else "<empty>")
         return result
 
-    def _test_func(self, func, expected_value):
+    def _test_func(self, func, expected_value, *alternatives):
         result = self._call_func(func)
         print("Comparing %s to %s" % (result[0].value, expected_value))
-        assert result[0].value == expected_value
+        assert result[0].value in (expected_value,) + alternatives
 
 
 class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
@@ -117,10 +117,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
 
     def test_dictgethierarchy(self):
         self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
-        # Default behaviour changed in CH, but we're not really testing that
-        default = self._call_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)))
-        assert isinstance(default, list)
-        assert len(default) <= 1  # either [] or [99]
+        self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [], [99])
 
     def test_dictisin(self):
         self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)

@@ -13,7 +13,13 @@ from clickhouse_orm.engines import (
     SummingMergeTree,
     TinyLog,
 )
-from clickhouse_orm.fields import DateField, Int8Field, UInt8Field, UInt16Field, UInt32Field
+from clickhouse_orm.fields import (
+    DateField,
+    Int8Field,
+    UInt8Field,
+    UInt16Field,
+    UInt32Field,
+)
 from clickhouse_orm.funcs import F
 from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
 from clickhouse_orm.system_models import SystemPart
@@ -30,7 +36,7 @@ class _EnginesHelperTestCase(unittest.TestCase):
 
 
 class EnginesTestCase(_EnginesHelperTestCase):
-    def _create_and_insert(self, model_class):
+    def _create_and_insert(self, model_class, **kwargs):
         self.database.create_table(model_class)
         self.database.insert(
             [
@@ -40,6 +46,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
                     event_group=13,
                     event_count=7,
                     event_version=1,
+                    **kwargs,
                 )
             ]
         )
@@ -72,7 +79,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
 
     def test_merge_tree_with_granularity(self):
         class TestModel(SampleModel):
-            engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)
+            engine = MergeTree(
+                "date", ("date", "event_id", "event_group"), index_granularity=4096
+            )
 
         self._create_and_insert(TestModel)
 
@@ -98,11 +107,15 @@ class EnginesTestCase(_EnginesHelperTestCase):
                 replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
             )
         with self.assertRaises(AssertionError):
-            MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")
+            MergeTree(
+                "date", ("date", "event_id", "event_group"), replica_name="{replica}"
+            )
 
     def test_collapsing_merge_tree(self):
         class TestModel(SampleModel):
-            engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")
+            engine = CollapsingMergeTree(
+                "date", ("date", "event_id", "event_group"), "event_version"
+            )
 
         self._create_and_insert(TestModel)
 
@@ -114,7 +127,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
 
     def test_replacing_merge_tree(self):
         class TestModel(SampleModel):
-            engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")
+            engine = ReplacingMergeTree(
+                "date", ("date", "event_id", "event_group"), "event_uversion"
+            )
 
         self._create_and_insert(TestModel)
 
@@ -236,16 +251,20 @@ class EnginesTestCase(_EnginesHelperTestCase):
         )
 
         self._create_and_insert(TestModel)
-        self._create_and_insert(TestCollapseModel)
+        self._create_and_insert(TestCollapseModel, sign=1)
 
         # Result order may be different, lets sort manually
         parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
 
         self.assertEqual(2, len(parts))
         self.assertEqual("testcollapsemodel", parts[0].table)
-        self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
+        self.assertEqual(
+            "(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", "")
+        )
         self.assertEqual("testmodel", parts[1].table)
-        self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))
+        self.assertEqual(
+            "(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", "")
+        )
 
     def test_custom_primary_key(self):
         if self.database.server_version < (18, 1):
@@ -269,13 +288,12 @@ class EnginesTestCase(_EnginesHelperTestCase):
         )
 
         self._create_and_insert(TestModel)
-        self._create_and_insert(TestCollapseModel)
+        self._create_and_insert(TestCollapseModel, sign=1)
 
         self.assertEqual(2, len(list(SystemPart.get(self.database))))
 
 
 class SampleModel(Model):
-
     date = DateField()
     event_id = UInt32Field()
     event_group = UInt32Field()
@@ -292,7 +310,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
             engine.create_table_sql(self.database)
 
         exc = cm.exception
-        self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")
+        self.assertEqual(
+            str(exc), "Cannot create Distributed engine: specify an underlying table"
+        )
 
     def test_with_table_name(self):
         engine = Distributed("my_cluster", "foo")
@@ -317,7 +337,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
 
         exc = cm.exception
         self.assertEqual(exc.code, 170)
-        self.assertTrue(exc.message.startswith("Requested cluster 'cluster_name' not found"))
+        self.assertTrue(
+            exc.message.startswith("Requested cluster 'cluster_name' not found")
+        )
 
     def test_verbose_engine_two_superclasses(self):
         class TestModel2(SampleModel):
@@ -368,11 +390,16 @@ class DistributedTestCase(_EnginesHelperTestCase):
         exc = cm.exception
         self.assertEqual(
             str(exc),
-            "When defining Distributed engine without the table_name ensure " "that your model has a parent model",
+            "When defining Distributed engine without the table_name ensure "
+            "that your model has a parent model",
         )
 
-    def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
-        d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
+    def _test_insert_select(
+        self, local_to_distributed, test_model=TestModel, include_readonly=True
+    ):
+        d_model = self._create_distributed(
+            "test_shard_localhost", underlying=test_model
+        )
 
         if local_to_distributed:
             to_insert, to_select = test_model, d_model
@@ -437,4 +464,6 @@ class DistributedTestCase(_EnginesHelperTestCase):
         class TestModel2(self.TestModel):
             event_uversion = UInt8Field(readonly=True)
 
-        return self._test_insert_select(local_to_distributed=False, test_model=TestModel2, include_readonly=False)
+        return self._test_insert_select(
+            local_to_distributed=False, test_model=TestModel2, include_readonly=False
+        )

@@ -333,6 +333,18 @@ class QuerySetTestCase(TestCaseWithData):
             "(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')",
         )
 
+    def test_precedence_of_negation(self):
+        p = ~Q(first_name='a')
+        q = Q(last_name='b')
+        r = p & q
+        self.assertEqual(r.to_sql(Person), "(NOT (first_name = 'a')) AND (last_name = 'b')")
+        r = q & p
+        self.assertEqual(r.to_sql(Person), "(last_name = 'b') AND (NOT (first_name = 'a'))")
+        r = q | p
+        self.assertEqual(r.to_sql(Person), "(last_name = 'b') OR (NOT (first_name = 'a'))")
+        r = ~q & p
+        self.assertEqual(r.to_sql(Person), "(NOT (last_name = 'b')) AND (NOT (first_name = 'a'))")
+
     def test_invalid_filter(self):
         qs = Person.objects_in(self.database)
         with self.assertRaises(TypeError):