mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-02 19:20:14 +03:00
Merge remote-tracking branch 'upstream/develop' into merge-upstream
This commit is contained in:
commit
ee01e82372
14
CHANGELOG.md
14
CHANGELOG.md
|
@ -1,6 +1,18 @@
|
||||||
Change Log
|
Change Log
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
v2.3.0
|
||||||
|
------
|
||||||
|
Merges upstream changes:
|
||||||
|
|
||||||
|
- Fix pagination for models with alias columns
|
||||||
|
- Add `QuerySet.model` to support django-rest-framework 3
|
||||||
|
- Improve support of ClickHouse v21.9 (mangototango)
|
||||||
|
- Ignore non-numeric parts in ClickHouse version (mangototango)
|
||||||
|
- Fix precedence of ~ operator in Q objects (mangototango)
|
||||||
|
- Support for adding a column to the beginning of a table (meanmail)
|
||||||
|
- Add stddevPop and stddevSamp functions (k.peskov)
|
||||||
|
|
||||||
v2.2.2
|
v2.2.2
|
||||||
------
|
------
|
||||||
- Unpined requirements to enhance compatability
|
- Unpined requirements to enhance compatability
|
||||||
|
@ -216,5 +228,3 @@ v0.7.0
|
||||||
v0.6.3
|
v0.6.3
|
||||||
------
|
------
|
||||||
- Python 3 support
|
- Python 3 support
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,14 @@ class ServerError(DatabaseException):
|
||||||
""",
|
""",
|
||||||
re.VERBOSE | re.DOTALL,
|
re.VERBOSE | re.DOTALL,
|
||||||
),
|
),
|
||||||
|
# ClickHouse v21+
|
||||||
|
re.compile(
|
||||||
|
r"""
|
||||||
|
Code:\ (?P<code>\d+).
|
||||||
|
\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
|
||||||
|
""",
|
||||||
|
re.VERBOSE | re.DOTALL,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -124,14 +132,20 @@ class Database(object):
|
||||||
self.db_exists = self._is_existing_database()
|
self.db_exists = self._is_existing_database()
|
||||||
if readonly:
|
if readonly:
|
||||||
if not self.db_exists:
|
if not self.db_exists:
|
||||||
raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
|
raise DatabaseException(
|
||||||
|
"Database does not exist, and cannot be created under readonly connection"
|
||||||
|
)
|
||||||
self.connection_readonly = self._is_connection_readonly()
|
self.connection_readonly = self._is_connection_readonly()
|
||||||
self.readonly = True
|
self.readonly = True
|
||||||
elif autocreate and not self.db_exists:
|
elif autocreate and not self.db_exists:
|
||||||
self.create_database()
|
self.create_database()
|
||||||
self.server_version = self._get_server_version()
|
self.server_version = self._get_server_version()
|
||||||
# Versions 1.1.53981 and below don't have timezone function
|
# Versions 1.1.53981 and below don't have timezone function
|
||||||
self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc
|
self.server_timezone = (
|
||||||
|
self._get_server_timezone()
|
||||||
|
if self.server_version > (1, 1, 53981)
|
||||||
|
else pytz.utc
|
||||||
|
)
|
||||||
# Versions 19.1.16 and above support codec compression
|
# Versions 19.1.16 and above support codec compression
|
||||||
self.has_codec_support = self.server_version >= (19, 1, 16)
|
self.has_codec_support = self.server_version >= (19, 1, 16)
|
||||||
# Version 19.0 and above support LowCardinality
|
# Version 19.0 and above support LowCardinality
|
||||||
|
@ -158,7 +172,9 @@ class Database(object):
|
||||||
if model_class.is_system_model():
|
if model_class.is_system_model():
|
||||||
raise DatabaseException("You can't create system table")
|
raise DatabaseException("You can't create system table")
|
||||||
if model_class.engine is None:
|
if model_class.engine is None:
|
||||||
raise DatabaseException("%s class must define an engine" % model_class.__name__)
|
raise DatabaseException(
|
||||||
|
"%s class must define an engine" % model_class.__name__
|
||||||
|
)
|
||||||
self._send(model_class.create_table_sql(self))
|
self._send(model_class.create_table_sql(self))
|
||||||
|
|
||||||
def drop_table(self, model_class):
|
def drop_table(self, model_class):
|
||||||
|
@ -229,7 +245,9 @@ class Database(object):
|
||||||
if first_instance.is_read_only() or first_instance.is_system_model():
|
if first_instance.is_read_only() or first_instance.is_system_model():
|
||||||
raise DatabaseException("You can't insert into read only and system tables")
|
raise DatabaseException("You can't insert into read only and system tables")
|
||||||
|
|
||||||
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
|
fields_list = ",".join(
|
||||||
|
["`%s`" % name for name in first_instance.fields(writable=True)]
|
||||||
|
)
|
||||||
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
|
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
|
||||||
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
|
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
|
||||||
|
|
||||||
|
@ -289,11 +307,15 @@ class Database(object):
|
||||||
lines = r.iter_lines()
|
lines = r.iter_lines()
|
||||||
field_names = parse_tsv(next(lines))
|
field_names = parse_tsv(next(lines))
|
||||||
field_types = parse_tsv(next(lines))
|
field_types = parse_tsv(next(lines))
|
||||||
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
|
model_class = model_class or ModelBase.create_ad_hoc_model(
|
||||||
|
zip(field_names, field_types)
|
||||||
|
)
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# skip blank line left by WITH TOTALS modifier
|
# skip blank line left by WITH TOTALS modifier
|
||||||
if line:
|
if line:
|
||||||
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
|
yield model_class.from_tsv(
|
||||||
|
line, field_names, self.server_timezone, self
|
||||||
|
)
|
||||||
|
|
||||||
def raw(self, query, settings=None, stream=False):
|
def raw(self, query, settings=None, stream=False):
|
||||||
"""
|
"""
|
||||||
|
@ -306,7 +328,15 @@ class Database(object):
|
||||||
query = self._substitute(query, None)
|
query = self._substitute(query, None)
|
||||||
return self._send(query, settings=settings, stream=stream).text
|
return self._send(query, settings=settings, stream=stream).text
|
||||||
|
|
||||||
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
|
def paginate(
|
||||||
|
self,
|
||||||
|
model_class,
|
||||||
|
order_by,
|
||||||
|
page_num=1,
|
||||||
|
page_size=100,
|
||||||
|
conditions=None,
|
||||||
|
settings=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Selects records and returns a single page of model instances.
|
Selects records and returns a single page of model instances.
|
||||||
|
|
||||||
|
@ -330,7 +360,8 @@ class Database(object):
|
||||||
elif page_num < 1:
|
elif page_num < 1:
|
||||||
raise ValueError("Invalid page number: %d" % page_num)
|
raise ValueError("Invalid page number: %d" % page_num)
|
||||||
offset = (page_num - 1) * page_size
|
offset = (page_num - 1) * page_size
|
||||||
query = "SELECT * FROM $table"
|
query = "SELECT {} FROM $table".format(", ".join(model_class.fields().keys()))
|
||||||
|
|
||||||
if conditions:
|
if conditions:
|
||||||
if isinstance(conditions, Q):
|
if isinstance(conditions, Q):
|
||||||
conditions = conditions.to_sql(model_class)
|
conditions = conditions.to_sql(model_class)
|
||||||
|
@ -367,7 +398,9 @@ class Database(object):
|
||||||
self.insert(
|
self.insert(
|
||||||
[
|
[
|
||||||
MigrationHistory(
|
MigrationHistory(
|
||||||
package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
|
package_name=migrations_package_name,
|
||||||
|
module_name=name,
|
||||||
|
applied=datetime.date.today(),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -378,7 +411,10 @@ class Database(object):
|
||||||
from .migrations import MigrationHistory
|
from .migrations import MigrationHistory
|
||||||
|
|
||||||
self.create_table(MigrationHistory)
|
self.create_table(MigrationHistory)
|
||||||
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
|
query = (
|
||||||
|
"SELECT module_name from $table WHERE package_name = '%s'"
|
||||||
|
% migrations_package_name
|
||||||
|
)
|
||||||
query = self._substitute(query, MigrationHistory)
|
query = self._substitute(query, MigrationHistory)
|
||||||
return set(obj.module_name for obj in self.select(query))
|
return set(obj.module_name for obj in self.select(query))
|
||||||
|
|
||||||
|
@ -388,7 +424,9 @@ class Database(object):
|
||||||
if self.log_statements:
|
if self.log_statements:
|
||||||
logger.info(data)
|
logger.info(data)
|
||||||
params = self._build_params(settings)
|
params = self._build_params(settings)
|
||||||
r = self.request_session.post(self.db_url, params=params, data=data, stream=stream, timeout=self.timeout)
|
r = self.request_session.post(
|
||||||
|
self.db_url, params=params, data=data, stream=stream, timeout=self.timeout
|
||||||
|
)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
raise ServerError(r.text)
|
raise ServerError(r.text)
|
||||||
return r
|
return r
|
||||||
|
@ -413,7 +451,10 @@ class Database(object):
|
||||||
if model_class.is_system_model():
|
if model_class.is_system_model():
|
||||||
mapping["table"] = "`system`.`%s`" % model_class.table_name()
|
mapping["table"] = "`system`.`%s`" % model_class.table_name()
|
||||||
else:
|
else:
|
||||||
mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
|
mapping["table"] = "`%s`.`%s`" % (
|
||||||
|
self.db_name,
|
||||||
|
model_class.table_name(),
|
||||||
|
)
|
||||||
query = Template(query).safe_substitute(mapping)
|
query = Template(query).safe_substitute(mapping)
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
@ -432,10 +473,12 @@ class Database(object):
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
|
logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
|
||||||
ver = "1.1.0"
|
ver = "1.1.0"
|
||||||
return tuple(int(n) for n in ver.split(".")) if as_tuple else ver
|
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
|
||||||
|
|
||||||
def _is_existing_database(self):
|
def _is_existing_database(self):
|
||||||
r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
|
r = self._send(
|
||||||
|
"SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
|
||||||
|
)
|
||||||
return r.text.strip() == "1"
|
return r.text.strip() == "1"
|
||||||
|
|
||||||
def _is_connection_readonly(self):
|
def _is_connection_readonly(self):
|
||||||
|
|
|
@ -26,18 +26,30 @@ class Field(FunctionOperatorsMixin):
|
||||||
class_default = 0 # should be overridden by concrete subclasses
|
class_default = 0 # should be overridden by concrete subclasses
|
||||||
db_type = None # should be overridden by concrete subclasses
|
db_type = None # should be overridden by concrete subclasses
|
||||||
|
|
||||||
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(
|
||||||
assert [default, alias, materialized].count(
|
self, default=None, alias=None, materialized=None, readonly=None, codec=None
|
||||||
None
|
):
|
||||||
) >= 2, "Only one of default, alias and materialized parameters can be given"
|
assert [default, alias, materialized].count(None) >= 2, (
|
||||||
|
"Only one of default, alias and materialized parameters can be given"
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
|
alias is None
|
||||||
|
or isinstance(alias, F)
|
||||||
|
or isinstance(alias, str)
|
||||||
|
and alias != ""
|
||||||
), "Alias parameter must be a string or function object, if given"
|
), "Alias parameter must be a string or function object, if given"
|
||||||
assert (
|
assert (
|
||||||
materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
|
materialized is None
|
||||||
|
or isinstance(materialized, F)
|
||||||
|
or isinstance(materialized, str)
|
||||||
|
and materialized != ""
|
||||||
), "Materialized parameter must be a string or function object, if given"
|
), "Materialized parameter must be a string or function object, if given"
|
||||||
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
|
assert readonly is None or type(readonly) is bool, (
|
||||||
assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
|
"readonly parameter must be bool if given"
|
||||||
|
)
|
||||||
|
assert codec is None or isinstance(codec, str) and codec != "", (
|
||||||
|
"Codec field must be string, if given"
|
||||||
|
)
|
||||||
if alias:
|
if alias:
|
||||||
assert codec is None, "Codec cannot be used for alias fields"
|
assert codec is None, "Codec cannot be used for alias fields"
|
||||||
|
|
||||||
|
@ -76,7 +88,8 @@ class Field(FunctionOperatorsMixin):
|
||||||
"""
|
"""
|
||||||
if value < min_value or value > max_value:
|
if value < min_value or value > max_value:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
|
"%s out of range - %s is not between %s and %s"
|
||||||
|
% (self.__class__.__name__, value, min_value, max_value)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
|
@ -117,7 +130,7 @@ class Field(FunctionOperatorsMixin):
|
||||||
elif self.default:
|
elif self.default:
|
||||||
default = self.to_db_string(self.default)
|
default = self.to_db_string(self.default)
|
||||||
sql += " DEFAULT %s" % default
|
sql += " DEFAULT %s" % default
|
||||||
if self.codec and db and db.has_codec_support:
|
if self.codec and db and db.has_codec_support and not self.alias:
|
||||||
sql += " CODEC(%s)" % self.codec
|
sql += " CODEC(%s)" % self.codec
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
@ -141,7 +154,6 @@ class Field(FunctionOperatorsMixin):
|
||||||
|
|
||||||
|
|
||||||
class StringField(Field):
|
class StringField(Field):
|
||||||
|
|
||||||
class_default = ""
|
class_default = ""
|
||||||
db_type = "String"
|
db_type = "String"
|
||||||
|
|
||||||
|
@ -154,7 +166,9 @@ class StringField(Field):
|
||||||
|
|
||||||
|
|
||||||
class FixedStringField(StringField):
|
class FixedStringField(StringField):
|
||||||
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(
|
||||||
|
self, length, default=None, alias=None, materialized=None, readonly=None
|
||||||
|
):
|
||||||
self._length = length
|
self._length = length
|
||||||
self.db_type = "FixedString(%d)" % length
|
self.db_type = "FixedString(%d)" % length
|
||||||
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
|
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
|
||||||
|
@ -167,11 +181,13 @@ class FixedStringField(StringField):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = value.encode("utf-8")
|
value = value.encode("utf-8")
|
||||||
if len(value) > self._length:
|
if len(value) > self._length:
|
||||||
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
|
raise ValueError(
|
||||||
|
"Value of %d bytes is too long for FixedStringField(%d)"
|
||||||
|
% (len(value), self._length)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DateField(Field):
|
class DateField(Field):
|
||||||
|
|
||||||
min_value = datetime.date(1970, 1, 1)
|
min_value = datetime.date(1970, 1, 1)
|
||||||
max_value = datetime.date(2105, 12, 31)
|
max_value = datetime.date(2105, 12, 31)
|
||||||
class_default = min_value
|
class_default = min_value
|
||||||
|
@ -198,15 +214,26 @@ class DateField(Field):
|
||||||
|
|
||||||
|
|
||||||
class DateTimeField(Field):
|
class DateTimeField(Field):
|
||||||
|
|
||||||
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
||||||
db_type = "DateTime"
|
db_type = "DateTime"
|
||||||
|
|
||||||
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
readonly=None,
|
||||||
|
codec=None,
|
||||||
|
timezone=None,
|
||||||
|
):
|
||||||
super().__init__(default, alias, materialized, readonly, codec)
|
super().__init__(default, alias, materialized, readonly, codec)
|
||||||
# assert not timezone, 'Temporarily field timezone is not supported'
|
# assert not timezone, 'Temporarily field timezone is not supported'
|
||||||
if timezone:
|
if timezone:
|
||||||
timezone = timezone if isinstance(timezone, BaseTzInfo) else pytz.timezone(timezone)
|
timezone = (
|
||||||
|
timezone
|
||||||
|
if isinstance(timezone, BaseTzInfo)
|
||||||
|
else pytz.timezone(timezone)
|
||||||
|
)
|
||||||
self.timezone = timezone
|
self.timezone = timezone
|
||||||
|
|
||||||
def get_db_type_args(self):
|
def get_db_type_args(self):
|
||||||
|
@ -219,7 +246,9 @@ class DateTimeField(Field):
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
return value if value.tzinfo else value.replace(tzinfo=pytz.utc)
|
return value if value.tzinfo else value.replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, datetime.date):
|
if isinstance(value, datetime.date):
|
||||||
return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
|
return datetime.datetime(
|
||||||
|
value.year, value.month, value.day, tzinfo=pytz.utc
|
||||||
|
)
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
@ -228,7 +257,9 @@ class DateTimeField(Field):
|
||||||
if len(value) == 10:
|
if len(value) == 10:
|
||||||
try:
|
try:
|
||||||
value = int(value)
|
value = int(value)
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(
|
||||||
|
tzinfo=pytz.utc
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
|
@ -251,10 +282,19 @@ class DateTime64Field(DateTimeField):
|
||||||
db_type = "DateTime64"
|
db_type = "DateTime64"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
|
self,
|
||||||
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
readonly=None,
|
||||||
|
codec=None,
|
||||||
|
timezone=None,
|
||||||
|
precision=6,
|
||||||
):
|
):
|
||||||
super().__init__(default, alias, materialized, readonly, codec, timezone)
|
super().__init__(default, alias, materialized, readonly, codec, timezone)
|
||||||
assert precision is None or isinstance(precision, int), "Precision must be int type"
|
assert precision is None or isinstance(precision, int), (
|
||||||
|
"Precision must be int type"
|
||||||
|
)
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
|
|
||||||
def get_db_type_args(self):
|
def get_db_type_args(self):
|
||||||
|
@ -271,7 +311,9 @@ class DateTime64Field(DateTimeField):
|
||||||
"""
|
"""
|
||||||
return escape(
|
return escape(
|
||||||
"{timestamp:0{width}.{precision}f}".format(
|
"{timestamp:0{width}.{precision}f}".format(
|
||||||
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
|
timestamp=value.timestamp(),
|
||||||
|
width=11 + self.precision,
|
||||||
|
precision=self.precision,
|
||||||
),
|
),
|
||||||
quote,
|
quote,
|
||||||
)
|
)
|
||||||
|
@ -281,7 +323,9 @@ class DateTime64Field(DateTimeField):
|
||||||
return super().to_python(value, timezone_in_use)
|
return super().to_python(value, timezone_in_use)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if isinstance(value, (int, float)):
|
if isinstance(value, (int, float)):
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(
|
||||||
|
tzinfo=pytz.utc
|
||||||
|
)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
left_part = value.split(".")[0]
|
left_part = value.split(".")[0]
|
||||||
if left_part == "0000-00-00 00:00:00":
|
if left_part == "0000-00-00 00:00:00":
|
||||||
|
@ -289,7 +333,9 @@ class DateTime64Field(DateTimeField):
|
||||||
if len(left_part) == 10:
|
if len(left_part) == 10:
|
||||||
try:
|
try:
|
||||||
value = float(value)
|
value = float(value)
|
||||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(
|
||||||
|
tzinfo=pytz.utc
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
raise
|
raise
|
||||||
|
@ -304,7 +350,9 @@ class BaseIntField(Field):
|
||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
raise ValueError(
|
||||||
|
"Invalid value for %s - %r" % (self.__class__.__name__, value)
|
||||||
|
)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
# There's no need to call escape since numbers do not contain
|
# There's no need to call escape since numbers do not contain
|
||||||
|
@ -316,58 +364,50 @@ class BaseIntField(Field):
|
||||||
|
|
||||||
|
|
||||||
class UInt8Field(BaseIntField):
|
class UInt8Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2 ** 8 - 1
|
max_value = 2**8 - 1
|
||||||
db_type = "UInt8"
|
db_type = "UInt8"
|
||||||
|
|
||||||
|
|
||||||
class UInt16Field(BaseIntField):
|
class UInt16Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2 ** 16 - 1
|
max_value = 2**16 - 1
|
||||||
db_type = "UInt16"
|
db_type = "UInt16"
|
||||||
|
|
||||||
|
|
||||||
class UInt32Field(BaseIntField):
|
class UInt32Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2 ** 32 - 1
|
max_value = 2**32 - 1
|
||||||
db_type = "UInt32"
|
db_type = "UInt32"
|
||||||
|
|
||||||
|
|
||||||
class UInt64Field(BaseIntField):
|
class UInt64Field(BaseIntField):
|
||||||
|
|
||||||
min_value = 0
|
min_value = 0
|
||||||
max_value = 2 ** 64 - 1
|
max_value = 2**64 - 1
|
||||||
db_type = "UInt64"
|
db_type = "UInt64"
|
||||||
|
|
||||||
|
|
||||||
class Int8Field(BaseIntField):
|
class Int8Field(BaseIntField):
|
||||||
|
min_value = -(2**7)
|
||||||
min_value = -(2 ** 7)
|
max_value = 2**7 - 1
|
||||||
max_value = 2 ** 7 - 1
|
|
||||||
db_type = "Int8"
|
db_type = "Int8"
|
||||||
|
|
||||||
|
|
||||||
class Int16Field(BaseIntField):
|
class Int16Field(BaseIntField):
|
||||||
|
min_value = -(2**15)
|
||||||
min_value = -(2 ** 15)
|
max_value = 2**15 - 1
|
||||||
max_value = 2 ** 15 - 1
|
|
||||||
db_type = "Int16"
|
db_type = "Int16"
|
||||||
|
|
||||||
|
|
||||||
class Int32Field(BaseIntField):
|
class Int32Field(BaseIntField):
|
||||||
|
min_value = -(2**31)
|
||||||
min_value = -(2 ** 31)
|
max_value = 2**31 - 1
|
||||||
max_value = 2 ** 31 - 1
|
|
||||||
db_type = "Int32"
|
db_type = "Int32"
|
||||||
|
|
||||||
|
|
||||||
class Int64Field(BaseIntField):
|
class Int64Field(BaseIntField):
|
||||||
|
min_value = -(2**63)
|
||||||
min_value = -(2 ** 63)
|
max_value = 2**63 - 1
|
||||||
max_value = 2 ** 63 - 1
|
|
||||||
db_type = "Int64"
|
db_type = "Int64"
|
||||||
|
|
||||||
|
|
||||||
|
@ -380,7 +420,9 @@ class BaseFloatField(Field):
|
||||||
try:
|
try:
|
||||||
return float(value)
|
return float(value)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
raise ValueError(
|
||||||
|
"Invalid value for %s - %r" % (self.__class__.__name__, value)
|
||||||
|
)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
# There's no need to call escape since numbers do not contain
|
# There's no need to call escape since numbers do not contain
|
||||||
|
@ -389,12 +431,10 @@ class BaseFloatField(Field):
|
||||||
|
|
||||||
|
|
||||||
class Float32Field(BaseFloatField):
|
class Float32Field(BaseFloatField):
|
||||||
|
|
||||||
db_type = "Float32"
|
db_type = "Float32"
|
||||||
|
|
||||||
|
|
||||||
class Float64Field(BaseFloatField):
|
class Float64Field(BaseFloatField):
|
||||||
|
|
||||||
db_type = "Float64"
|
db_type = "Float64"
|
||||||
|
|
||||||
|
|
||||||
|
@ -403,9 +443,19 @@ class DecimalField(Field):
|
||||||
Base class for all decimal fields. Can also be used directly.
|
Base class for all decimal fields. Can also be used directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
precision,
|
||||||
|
scale,
|
||||||
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
readonly=None,
|
||||||
|
):
|
||||||
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
|
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
|
||||||
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
|
assert 0 <= scale <= precision, (
|
||||||
|
"Scale must be between 0 and the given precision"
|
||||||
|
)
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
|
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
|
||||||
|
@ -421,9 +471,13 @@ class DecimalField(Field):
|
||||||
try:
|
try:
|
||||||
value = Decimal(value)
|
value = Decimal(value)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
raise ValueError(
|
||||||
|
"Invalid value for %s - %r" % (self.__class__.__name__, value)
|
||||||
|
)
|
||||||
if not value.is_finite():
|
if not value.is_finite():
|
||||||
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
|
raise ValueError(
|
||||||
|
"Non-finite value for %s - %r" % (self.__class__.__name__, value)
|
||||||
|
)
|
||||||
return self._round(value)
|
return self._round(value)
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
|
@ -439,20 +493,32 @@ class DecimalField(Field):
|
||||||
|
|
||||||
|
|
||||||
class Decimal32Field(DecimalField):
|
class Decimal32Field(DecimalField):
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(
|
||||||
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
|
self, scale, default=None, alias=None, materialized=None, readonly=None
|
||||||
|
):
|
||||||
|
super(Decimal32Field, self).__init__(
|
||||||
|
9, scale, default, alias, materialized, readonly
|
||||||
|
)
|
||||||
self.db_type = "Decimal32(%d)" % scale
|
self.db_type = "Decimal32(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
class Decimal64Field(DecimalField):
|
class Decimal64Field(DecimalField):
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(
|
||||||
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
|
self, scale, default=None, alias=None, materialized=None, readonly=None
|
||||||
|
):
|
||||||
|
super(Decimal64Field, self).__init__(
|
||||||
|
18, scale, default, alias, materialized, readonly
|
||||||
|
)
|
||||||
self.db_type = "Decimal64(%d)" % scale
|
self.db_type = "Decimal64(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
class Decimal128Field(DecimalField):
|
class Decimal128Field(DecimalField):
|
||||||
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
|
def __init__(
|
||||||
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
|
self, scale, default=None, alias=None, materialized=None, readonly=None
|
||||||
|
):
|
||||||
|
super(Decimal128Field, self).__init__(
|
||||||
|
38, scale, default, alias, materialized, readonly
|
||||||
|
)
|
||||||
self.db_type = "Decimal128(%d)" % scale
|
self.db_type = "Decimal128(%d)" % scale
|
||||||
|
|
||||||
|
|
||||||
|
@ -461,11 +527,21 @@ class BaseEnumField(Field):
|
||||||
Abstract base class for all enum-type fields.
|
Abstract base class for all enum-type fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
enum_cls,
|
||||||
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
readonly=None,
|
||||||
|
codec=None,
|
||||||
|
):
|
||||||
self.enum_cls = enum_cls
|
self.enum_cls = enum_cls
|
||||||
if default is None:
|
if default is None:
|
||||||
default = list(enum_cls)[0]
|
default = list(enum_cls)[0]
|
||||||
super(BaseEnumField, self).__init__(default, alias, materialized, readonly, codec)
|
super(BaseEnumField, self).__init__(
|
||||||
|
default, alias, materialized, readonly, codec
|
||||||
|
)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, self.enum_cls):
|
if isinstance(value, self.enum_cls):
|
||||||
|
@ -512,22 +588,31 @@ class BaseEnumField(Field):
|
||||||
|
|
||||||
|
|
||||||
class Enum8Field(BaseEnumField):
|
class Enum8Field(BaseEnumField):
|
||||||
|
|
||||||
db_type = "Enum8"
|
db_type = "Enum8"
|
||||||
|
|
||||||
|
|
||||||
class Enum16Field(BaseEnumField):
|
class Enum16Field(BaseEnumField):
|
||||||
|
|
||||||
db_type = "Enum16"
|
db_type = "Enum16"
|
||||||
|
|
||||||
|
|
||||||
class ArrayField(Field):
|
class ArrayField(Field):
|
||||||
|
|
||||||
class_default = []
|
class_default = []
|
||||||
|
|
||||||
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(
|
||||||
assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance"
|
self,
|
||||||
assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM"
|
inner_field,
|
||||||
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
readonly=None,
|
||||||
|
codec=None,
|
||||||
|
):
|
||||||
|
assert isinstance(inner_field, Field), (
|
||||||
|
"The first argument of ArrayField must be a Field instance"
|
||||||
|
)
|
||||||
|
assert not isinstance(inner_field, ArrayField), (
|
||||||
|
"Multidimensional array fields are not supported by the ORM"
|
||||||
|
)
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
|
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
|
||||||
|
|
||||||
|
@ -549,14 +634,15 @@ class ArrayField(Field):
|
||||||
return "[" + comma_join(array) + "]"
|
return "[" + comma_join(array) + "]"
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
sql = "Array(%s)" % self.inner_field.get_sql(
|
||||||
|
with_default_expression=False, db=db
|
||||||
|
)
|
||||||
if with_default_expression and self.codec and db and db.has_codec_support:
|
if with_default_expression and self.codec and db and db.has_codec_support:
|
||||||
sql += " CODEC(%s)" % self.codec
|
sql += " CODEC(%s)" % self.codec
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
|
||||||
class UUIDField(Field):
|
class UUIDField(Field):
|
||||||
|
|
||||||
class_default = UUID(int=0)
|
class_default = UUID(int=0)
|
||||||
db_type = "UUID"
|
db_type = "UUID"
|
||||||
|
|
||||||
|
@ -579,7 +665,6 @@ class UUIDField(Field):
|
||||||
|
|
||||||
|
|
||||||
class IPv4Field(Field):
|
class IPv4Field(Field):
|
||||||
|
|
||||||
class_default = 0
|
class_default = 0
|
||||||
db_type = "IPv4"
|
db_type = "IPv4"
|
||||||
|
|
||||||
|
@ -596,7 +681,6 @@ class IPv4Field(Field):
|
||||||
|
|
||||||
|
|
||||||
class IPv6Field(Field):
|
class IPv6Field(Field):
|
||||||
|
|
||||||
class_default = 0
|
class_default = 0
|
||||||
db_type = "IPv6"
|
db_type = "IPv6"
|
||||||
|
|
||||||
|
@ -613,18 +697,29 @@ class IPv6Field(Field):
|
||||||
|
|
||||||
|
|
||||||
class NullableField(Field):
|
class NullableField(Field):
|
||||||
|
|
||||||
class_default = None
|
class_default = None
|
||||||
|
|
||||||
def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
|
def __init__(
|
||||||
assert isinstance(
|
self,
|
||||||
inner_field, Field
|
inner_field,
|
||||||
), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
|
default=None,
|
||||||
|
alias=None,
|
||||||
|
materialized=None,
|
||||||
|
extra_null_values=None,
|
||||||
|
codec=None,
|
||||||
|
):
|
||||||
|
assert isinstance(inner_field, Field), (
|
||||||
|
"The first argument of NullableField must be a Field instance. Not: {}".format(
|
||||||
|
inner_field
|
||||||
|
)
|
||||||
|
)
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
self._null_values = [None]
|
self._null_values = [None]
|
||||||
if extra_null_values:
|
if extra_null_values:
|
||||||
self._null_values.extend(extra_null_values)
|
self._null_values.extend(extra_null_values)
|
||||||
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
|
super(NullableField, self).__init__(
|
||||||
|
default, alias, materialized, readonly=None, codec=codec
|
||||||
|
)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
if value == "\\N" or value in self._null_values:
|
if value == "\\N" or value in self._null_values:
|
||||||
|
@ -640,26 +735,40 @@ class NullableField(Field):
|
||||||
return self.inner_field.to_db_string(value, quote=quote)
|
return self.inner_field.to_db_string(value, quote=quote)
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
sql = "Nullable(%s)" % self.inner_field.get_sql(
|
||||||
|
with_default_expression=False, db=db
|
||||||
|
)
|
||||||
if with_default_expression:
|
if with_default_expression:
|
||||||
sql += self._extra_params(db)
|
sql += self._extra_params(db)
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
|
||||||
class LowCardinalityField(Field):
|
class LowCardinalityField(Field):
|
||||||
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
|
def __init__(
|
||||||
assert isinstance(
|
self,
|
||||||
inner_field, Field
|
inner_field,
|
||||||
), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
|
default=None,
|
||||||
assert not isinstance(
|
alias=None,
|
||||||
inner_field, LowCardinalityField
|
materialized=None,
|
||||||
), "LowCardinality inner fields are not supported by the ORM"
|
readonly=None,
|
||||||
assert not isinstance(
|
codec=None,
|
||||||
inner_field, ArrayField
|
):
|
||||||
), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
|
assert isinstance(inner_field, Field), (
|
||||||
|
"The first argument of LowCardinalityField must be a Field instance. Not: {}".format(
|
||||||
|
inner_field
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert not isinstance(inner_field, LowCardinalityField), (
|
||||||
|
"LowCardinality inner fields are not supported by the ORM"
|
||||||
|
)
|
||||||
|
assert not isinstance(inner_field, ArrayField), (
|
||||||
|
"Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
|
||||||
|
)
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
self.class_default = self.inner_field.class_default
|
self.class_default = self.inner_field.class_default
|
||||||
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
|
super(LowCardinalityField, self).__init__(
|
||||||
|
default, alias, materialized, readonly, codec
|
||||||
|
)
|
||||||
|
|
||||||
def to_python(self, value, timezone_in_use):
|
def to_python(self, value, timezone_in_use):
|
||||||
return self.inner_field.to_python(value, timezone_in_use)
|
return self.inner_field.to_python(value, timezone_in_use)
|
||||||
|
@ -672,7 +781,9 @@ class LowCardinalityField(Field):
|
||||||
|
|
||||||
def get_sql(self, with_default_expression=True, db=None):
|
def get_sql(self, with_default_expression=True, db=None):
|
||||||
if db and db.has_low_cardinality_support:
|
if db and db.has_low_cardinality_support:
|
||||||
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
|
sql = "LowCardinality(%s)" % self.inner_field.get_sql(
|
||||||
|
with_default_expression=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
sql = self.inner_field.get_sql(with_default_expression=False)
|
sql = self.inner_field.get_sql(with_default_expression=False)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -186,7 +186,6 @@ class FunctionOperatorsMixin(object):
|
||||||
|
|
||||||
|
|
||||||
class FMeta(type):
|
class FMeta(type):
|
||||||
|
|
||||||
FUNCTION_COMBINATORS = {
|
FUNCTION_COMBINATORS = {
|
||||||
"type_conversion": [
|
"type_conversion": [
|
||||||
{"suffix": "OrZero"},
|
{"suffix": "OrZero"},
|
||||||
|
@ -230,12 +229,23 @@ class FMeta(type):
|
||||||
args = comma_join(extra_args)
|
args = comma_join(extra_args)
|
||||||
new_sig = comma_join(extra_args)
|
new_sig = comma_join(extra_args)
|
||||||
# Get default values for args
|
# Get default values for args
|
||||||
argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty)
|
argdefs = tuple(
|
||||||
|
p.default for p in sig.parameters.values() if p.default != Parameter.empty
|
||||||
|
)
|
||||||
# Build the new function
|
# Build the new function
|
||||||
new_code = compile(
|
new_code = compile(
|
||||||
'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(**locals()), __file__, "exec"
|
'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(
|
||||||
|
**locals()
|
||||||
|
),
|
||||||
|
__file__,
|
||||||
|
"exec",
|
||||||
|
)
|
||||||
|
new_func = FunctionType(
|
||||||
|
code=new_code.co_consts[0],
|
||||||
|
globals=globals(),
|
||||||
|
name=new_name,
|
||||||
|
argdefs=argdefs,
|
||||||
)
|
)
|
||||||
new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs)
|
|
||||||
# If base_func was parametric, new_func should be too
|
# If base_func was parametric, new_func should be too
|
||||||
if getattr(base_func, "f_parametric", False):
|
if getattr(base_func, "f_parametric", False):
|
||||||
new_func = parametric(new_func)
|
new_func = parametric(new_func)
|
||||||
|
@ -409,7 +419,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toQuarter(d, timezone=NO_VALUE):
|
def toQuarter(d, timezone=NO_VALUE):
|
||||||
return F("toQuarter", d, timezone)
|
return F("toQuarter", d, timezone) if timezone else F("toQuarter", d)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toMonth(d, timezone=NO_VALUE):
|
def toMonth(d, timezone=NO_VALUE):
|
||||||
|
@ -421,7 +431,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toISOWeek(d, timezone=NO_VALUE):
|
def toISOWeek(d, timezone=NO_VALUE):
|
||||||
return F("toISOWeek", d, timezone)
|
return F("toISOWeek", d, timezone) if timezone else F("toISOWeek", d)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toDayOfYear(d, timezone=NO_VALUE):
|
def toDayOfYear(d, timezone=NO_VALUE):
|
||||||
|
@ -509,15 +519,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toYYYYMM(dt, timezone=NO_VALUE):
|
def toYYYYMM(dt, timezone=NO_VALUE):
|
||||||
return F("toYYYYMM", dt, timezone)
|
return F("toYYYYMM", dt, timezone) if timezone else F("toYYYYMM", dt)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toYYYYMMDD(dt, timezone=NO_VALUE):
|
def toYYYYMMDD(dt, timezone=NO_VALUE):
|
||||||
return F("toYYYYMMDD", dt, timezone)
|
return F("toYYYYMMDD", dt, timezone) if timezone else F("toYYYYMMDD", dt)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toYYYYMMDDhhmmss(dt, timezone=NO_VALUE):
|
def toYYYYMMDDhhmmss(dt, timezone=NO_VALUE):
|
||||||
return F("toYYYYMMDDhhmmss", dt, timezone)
|
return (
|
||||||
|
F("toYYYYMMDDhhmmss", dt, timezone)
|
||||||
|
if timezone
|
||||||
|
else F("toYYYYMMDDhhmmss", dt)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def toRelativeYearNum(d, timezone=NO_VALUE):
|
def toRelativeYearNum(d, timezone=NO_VALUE):
|
||||||
|
@ -1195,11 +1209,19 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def arrayResize(array, size, extender=None):
|
def arrayResize(array, size, extender=None):
|
||||||
return F("arrayResize", array, size, extender) if extender is not None else F("arrayResize", array, size)
|
return (
|
||||||
|
F("arrayResize", array, size, extender)
|
||||||
|
if extender is not None
|
||||||
|
else F("arrayResize", array, size)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def arraySlice(array, offset, length=None):
|
def arraySlice(array, offset, length=None):
|
||||||
return F("arraySlice", array, offset, length) if length is not None else F("arraySlice", array, offset)
|
return (
|
||||||
|
F("arraySlice", array, offset, length)
|
||||||
|
if length is not None
|
||||||
|
else F("arraySlice", array, offset)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def arrayUniq(*args):
|
def arrayUniq(*args):
|
||||||
|
@ -1649,6 +1671,16 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
|
||||||
def varSamp(x):
|
def varSamp(x):
|
||||||
return F("varSamp", x)
|
return F("varSamp", x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@aggregate
|
||||||
|
def stddevPop(expr):
|
||||||
|
return F("stddevPop", expr)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@aggregate
|
||||||
|
def stddevSamp(expr):
|
||||||
|
return F("stddevSamp", expr)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@aggregate
|
@aggregate
|
||||||
@parametric
|
@parametric
|
||||||
|
|
|
@ -84,10 +84,12 @@ class AlterTable(ModelOperation):
|
||||||
is_regular_field = not (field.materialized or field.alias)
|
is_regular_field = not (field.materialized or field.alias)
|
||||||
if name not in table_fields:
|
if name not in table_fields:
|
||||||
logger.info(" Add column %s", name)
|
logger.info(" Add column %s", name)
|
||||||
assert prev_name, "Cannot add a column to the beginning of the table"
|
|
||||||
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
|
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
|
||||||
if is_regular_field:
|
if is_regular_field:
|
||||||
cmd += " AFTER %s" % prev_name
|
if prev_name:
|
||||||
|
cmd += " AFTER %s" % prev_name
|
||||||
|
else:
|
||||||
|
cmd += " FIRST"
|
||||||
self._alter_table(database, cmd)
|
self._alter_table(database, cmd)
|
||||||
|
|
||||||
if is_regular_field:
|
if is_regular_field:
|
||||||
|
@ -105,13 +107,21 @@ class AlterTable(ModelOperation):
|
||||||
}
|
}
|
||||||
for field_name, field_sql in self._get_table_fields(database):
|
for field_name, field_sql in self._get_table_fields(database):
|
||||||
# All fields must have been created and dropped by this moment
|
# All fields must have been created and dropped by this moment
|
||||||
assert field_name in model_fields, "Model fields and table columns in disagreement"
|
assert field_name in model_fields, (
|
||||||
|
"Model fields and table columns in disagreement"
|
||||||
|
)
|
||||||
|
|
||||||
if field_sql != model_fields[field_name]:
|
if field_sql != model_fields[field_name]:
|
||||||
logger.info(
|
logger.info(
|
||||||
" Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
|
" Change type of column %s from %s to %s",
|
||||||
|
field_name,
|
||||||
|
field_sql,
|
||||||
|
model_fields[field_name],
|
||||||
|
)
|
||||||
|
self._alter_table(
|
||||||
|
database,
|
||||||
|
"MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]),
|
||||||
)
|
)
|
||||||
self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))
|
|
||||||
|
|
||||||
|
|
||||||
class AlterTableWithBuffer(ModelOperation):
|
class AlterTableWithBuffer(ModelOperation):
|
||||||
|
|
|
@ -217,7 +217,6 @@ class Q(object):
|
||||||
if mode == l_child._mode and not l_child._negate:
|
if mode == l_child._mode and not l_child._negate:
|
||||||
q = deepcopy(l_child)
|
q = deepcopy(l_child)
|
||||||
q._children.append(deepcopy(r_child))
|
q._children.append(deepcopy(r_child))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
q = cls()
|
q = cls()
|
||||||
q._children = [l_child, r_child]
|
q._children = [l_child, r_child]
|
||||||
|
@ -300,6 +299,7 @@ class QuerySet(object):
|
||||||
Initializer. It is possible to create a queryset like this, but the standard
|
Initializer. It is possible to create a queryset like this, but the standard
|
||||||
way is to use `MyModel.objects_in(database)`.
|
way is to use `MyModel.objects_in(database)`.
|
||||||
"""
|
"""
|
||||||
|
self.model = model_cls
|
||||||
self._model_cls = model_cls
|
self._model_cls = model_cls
|
||||||
self._database = database
|
self._database = database
|
||||||
self._order_by = []
|
self._order_by = []
|
||||||
|
|
|
@ -6,10 +6,10 @@ idna==2.9
|
||||||
clickhouse-orm==2.0.1
|
clickhouse-orm==2.0.1
|
||||||
iso8601==0.1.12
|
iso8601==0.1.12
|
||||||
itsdangerous==1.1.0
|
itsdangerous==1.1.0
|
||||||
Jinja2==2.11.2
|
Jinja2==2.11.3
|
||||||
MarkupSafe==1.1.1
|
MarkupSafe==1.1.1
|
||||||
pygal==2.4.0
|
pygal==2.4.0
|
||||||
pytz==2020.1
|
pytz==2020.1
|
||||||
requests==2.23.0
|
requests==2.23.0
|
||||||
urllib3==1.25.9
|
urllib3==1.26.5
|
||||||
Werkzeug==1.0.1
|
Werkzeug==1.0.1
|
||||||
|
|
|
@ -4,7 +4,13 @@ import unittest
|
||||||
|
|
||||||
from clickhouse_orm.database import Database, DatabaseException, ServerError
|
from clickhouse_orm.database import Database, DatabaseException, ServerError
|
||||||
from clickhouse_orm.engines import Memory
|
from clickhouse_orm.engines import Memory
|
||||||
from clickhouse_orm.fields import DateField, DateTimeField, Float32Field, Int32Field, StringField
|
from clickhouse_orm.fields import (
|
||||||
|
DateField,
|
||||||
|
DateTimeField,
|
||||||
|
Float32Field,
|
||||||
|
Int32Field,
|
||||||
|
StringField,
|
||||||
|
)
|
||||||
from clickhouse_orm.funcs import F
|
from clickhouse_orm.funcs import F
|
||||||
from clickhouse_orm.models import Model
|
from clickhouse_orm.models import Model
|
||||||
from clickhouse_orm.query import Q
|
from clickhouse_orm.query import Q
|
||||||
|
@ -56,9 +62,13 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
|
self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
|
||||||
self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)
|
self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)
|
||||||
# Conditions as expression
|
# Conditions as expression
|
||||||
self.assertEqual(self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22)
|
self.assertEqual(
|
||||||
|
self.database.count(Person, Person.birthday > datetime.date(2000, 1, 1)), 22
|
||||||
|
)
|
||||||
# Conditions as Q object
|
# Conditions as Q object
|
||||||
self.assertEqual(self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22)
|
self.assertEqual(
|
||||||
|
self.database.count(Person, Q(birthday__gt=datetime.date(2000, 1, 1))), 22
|
||||||
|
)
|
||||||
|
|
||||||
def test_select(self):
|
def test_select(self):
|
||||||
self._insert_and_check(self._sample_data(), len(data))
|
self._insert_and_check(self._sample_data(), len(data))
|
||||||
|
@ -118,7 +128,9 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
page_num = 1
|
page_num = 1
|
||||||
instances = set()
|
instances = set()
|
||||||
while True:
|
while True:
|
||||||
page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
|
page = self.database.paginate(
|
||||||
|
Person, "first_name, last_name", page_num, page_size
|
||||||
|
)
|
||||||
self.assertEqual(page.number_of_objects, len(data))
|
self.assertEqual(page.number_of_objects, len(data))
|
||||||
self.assertGreater(page.pages_total, 0)
|
self.assertGreater(page.pages_total, 0)
|
||||||
[instances.add(obj.to_tsv()) for obj in page.objects]
|
[instances.add(obj.to_tsv()) for obj in page.objects]
|
||||||
|
@ -133,8 +145,12 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
# Try different page sizes
|
# Try different page sizes
|
||||||
for page_size in (1, 2, 7, 10, 30, 100, 150):
|
for page_size in (1, 2, 7, 10, 30, 100, 150):
|
||||||
# Ask for the last page in two different ways and verify equality
|
# Ask for the last page in two different ways and verify equality
|
||||||
page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
|
page_a = self.database.paginate(
|
||||||
page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
|
Person, "first_name, last_name", -1, page_size
|
||||||
|
)
|
||||||
|
page_b = self.database.paginate(
|
||||||
|
Person, "first_name, last_name", page_a.pages_total, page_size
|
||||||
|
)
|
||||||
self.assertEqual(page_a[1:], page_b[1:])
|
self.assertEqual(page_a[1:], page_b[1:])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[obj.to_tsv() for obj in page_a.objects],
|
[obj.to_tsv() for obj in page_a.objects],
|
||||||
|
@ -164,7 +180,9 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
def test_pagination_with_conditions(self):
|
def test_pagination_with_conditions(self):
|
||||||
self._insert_and_check(self._sample_data(), len(data))
|
self._insert_and_check(self._sample_data(), len(data))
|
||||||
# Conditions as string
|
# Conditions as string
|
||||||
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
|
page = self.database.paginate(
|
||||||
|
Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'"
|
||||||
|
)
|
||||||
self.assertEqual(page.number_of_objects, 10)
|
self.assertEqual(page.number_of_objects, 10)
|
||||||
# Conditions as expression
|
# Conditions as expression
|
||||||
page = self.database.paginate(
|
page = self.database.paginate(
|
||||||
|
@ -176,11 +194,13 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
)
|
)
|
||||||
self.assertEqual(page.number_of_objects, 10)
|
self.assertEqual(page.number_of_objects, 10)
|
||||||
# Conditions as Q object
|
# Conditions as Q object
|
||||||
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
|
page = self.database.paginate(
|
||||||
|
Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava")
|
||||||
|
)
|
||||||
self.assertEqual(page.number_of_objects, 10)
|
self.assertEqual(page.number_of_objects, 10)
|
||||||
|
|
||||||
def test_special_chars(self):
|
def test_special_chars(self):
|
||||||
s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
|
s = "אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
|
||||||
p = Person(first_name=s)
|
p = Person(first_name=s)
|
||||||
self.database.insert([p])
|
self.database.insert([p])
|
||||||
p = list(self.database.select("SELECT * from $table", Person))[0]
|
p = list(self.database.select("SELECT * from $table", Person))[0]
|
||||||
|
@ -200,12 +220,13 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
Database(self.database.db_name, username="default", password="wrong")
|
Database(self.database.db_name, username="default", password="wrong")
|
||||||
|
|
||||||
exc = cm.exception
|
exc = cm.exception
|
||||||
|
print(exc.code, exc.message)
|
||||||
if exc.code == 193: # ClickHouse version < 20.3
|
if exc.code == 193: # ClickHouse version < 20.3
|
||||||
self.assertTrue(exc.message.startswith("Wrong password for user default"))
|
self.assertTrue(exc.message.startswith("Wrong password for user default"))
|
||||||
elif exc.code == 516: # ClickHouse version >= 20.3
|
elif exc.code == 516: # ClickHouse version >= 20.3
|
||||||
self.assertTrue(exc.message.startswith("default: Authentication failed"))
|
self.assertTrue(exc.message.startswith("default: Authentication failed"))
|
||||||
else:
|
else:
|
||||||
raise Exception("Unexpected error code - %s" % exc.code)
|
raise Exception("Unexpected error code - %s %s" % (exc.code, exc.message))
|
||||||
|
|
||||||
def test_nonexisting_db(self):
|
def test_nonexisting_db(self):
|
||||||
db = Database("db_not_here", autocreate=False)
|
db = Database("db_not_here", autocreate=False)
|
||||||
|
@ -234,7 +255,9 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
|
|
||||||
with self.assertRaises(DatabaseException) as cm:
|
with self.assertRaises(DatabaseException) as cm:
|
||||||
self.database.create_table(EnginelessModel)
|
self.database.create_table(EnginelessModel)
|
||||||
self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")
|
self.assertEqual(
|
||||||
|
str(cm.exception), "EnginelessModel class must define an engine"
|
||||||
|
)
|
||||||
|
|
||||||
def test_potentially_problematic_field_names(self):
|
def test_potentially_problematic_field_names(self):
|
||||||
class Model1(Model):
|
class Model1(Model):
|
||||||
|
@ -274,6 +297,8 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
|
|
||||||
query = "SELECT DISTINCT type FROM system.columns"
|
query = "SELECT DISTINCT type FROM system.columns"
|
||||||
for row in self.database.select(query):
|
for row in self.database.select(query):
|
||||||
|
if row.type.startswith("Map"):
|
||||||
|
continue # Not supported yet
|
||||||
ModelBase.create_ad_hoc_field(row.type)
|
ModelBase.create_ad_hoc_field(row.type)
|
||||||
|
|
||||||
def test_get_model_for_table(self):
|
def test_get_model_for_table(self):
|
||||||
|
@ -292,7 +317,12 @@ class DatabaseTestCase(TestCaseWithData):
|
||||||
query = "SELECT name FROM system.tables WHERE database='system'"
|
query = "SELECT name FROM system.tables WHERE database='system'"
|
||||||
for row in self.database.select(query):
|
for row in self.database.select(query):
|
||||||
print(row.name)
|
print(row.name)
|
||||||
model = self.database.get_model_for_table(row.name, system_table=True)
|
if row.name in ("distributed_ddl_queue",):
|
||||||
|
continue # Not supported
|
||||||
|
try:
|
||||||
|
model = self.database.get_model_for_table(row.name, system_table=True)
|
||||||
|
except NotImplementedError:
|
||||||
|
continue # Table contains an unsupported field type
|
||||||
self.assertTrue(model.is_system_model())
|
self.assertTrue(model.is_system_model())
|
||||||
self.assertTrue(model.is_read_only())
|
self.assertTrue(model.is_read_only())
|
||||||
self.assertEqual(model.table_name(), row.name)
|
self.assertEqual(model.table_name(), row.name)
|
||||||
|
|
|
@ -21,10 +21,10 @@ class DictionaryTestMixin:
|
||||||
logging.info("\t==> %s", result[0].value if result else "<empty>")
|
logging.info("\t==> %s", result[0].value if result else "<empty>")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _test_func(self, func, expected_value):
|
def _test_func(self, func, expected_value, *alternatives):
|
||||||
result = self._call_func(func)
|
result = self._call_func(func)
|
||||||
print("Comparing %s to %s" % (result[0].value, expected_value))
|
print("Comparing %s to %s" % (result[0].value, expected_value))
|
||||||
assert result[0].value == expected_value
|
assert result[0].value in (expected_value,) + alternatives
|
||||||
|
|
||||||
|
|
||||||
class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
|
class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
|
||||||
|
@ -117,10 +117,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
|
||||||
|
|
||||||
def test_dictgethierarchy(self):
|
def test_dictgethierarchy(self):
|
||||||
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
|
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
|
||||||
# Default behaviour changed in CH, but we're not really testing that
|
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [], [99])
|
||||||
default = self._call_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)))
|
|
||||||
assert isinstance(default, list)
|
|
||||||
assert len(default) <= 1 # either [] or [99]
|
|
||||||
|
|
||||||
def test_dictisin(self):
|
def test_dictisin(self):
|
||||||
self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)
|
self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)
|
||||||
|
|
|
@ -13,7 +13,13 @@ from clickhouse_orm.engines import (
|
||||||
SummingMergeTree,
|
SummingMergeTree,
|
||||||
TinyLog,
|
TinyLog,
|
||||||
)
|
)
|
||||||
from clickhouse_orm.fields import DateField, Int8Field, UInt8Field, UInt16Field, UInt32Field
|
from clickhouse_orm.fields import (
|
||||||
|
DateField,
|
||||||
|
Int8Field,
|
||||||
|
UInt8Field,
|
||||||
|
UInt16Field,
|
||||||
|
UInt32Field,
|
||||||
|
)
|
||||||
from clickhouse_orm.funcs import F
|
from clickhouse_orm.funcs import F
|
||||||
from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
|
from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
|
||||||
from clickhouse_orm.system_models import SystemPart
|
from clickhouse_orm.system_models import SystemPart
|
||||||
|
@ -30,7 +36,7 @@ class _EnginesHelperTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class EnginesTestCase(_EnginesHelperTestCase):
|
class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
def _create_and_insert(self, model_class):
|
def _create_and_insert(self, model_class, **kwargs):
|
||||||
self.database.create_table(model_class)
|
self.database.create_table(model_class)
|
||||||
self.database.insert(
|
self.database.insert(
|
||||||
[
|
[
|
||||||
|
@ -40,6 +46,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
event_group=13,
|
event_group=13,
|
||||||
event_count=7,
|
event_count=7,
|
||||||
event_version=1,
|
event_version=1,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -72,7 +79,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
|
|
||||||
def test_merge_tree_with_granularity(self):
|
def test_merge_tree_with_granularity(self):
|
||||||
class TestModel(SampleModel):
|
class TestModel(SampleModel):
|
||||||
engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)
|
engine = MergeTree(
|
||||||
|
"date", ("date", "event_id", "event_group"), index_granularity=4096
|
||||||
|
)
|
||||||
|
|
||||||
self._create_and_insert(TestModel)
|
self._create_and_insert(TestModel)
|
||||||
|
|
||||||
|
@ -98,11 +107,15 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
|
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
|
||||||
)
|
)
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")
|
MergeTree(
|
||||||
|
"date", ("date", "event_id", "event_group"), replica_name="{replica}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_collapsing_merge_tree(self):
|
def test_collapsing_merge_tree(self):
|
||||||
class TestModel(SampleModel):
|
class TestModel(SampleModel):
|
||||||
engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")
|
engine = CollapsingMergeTree(
|
||||||
|
"date", ("date", "event_id", "event_group"), "event_version"
|
||||||
|
)
|
||||||
|
|
||||||
self._create_and_insert(TestModel)
|
self._create_and_insert(TestModel)
|
||||||
|
|
||||||
|
@ -114,7 +127,9 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
|
|
||||||
def test_replacing_merge_tree(self):
|
def test_replacing_merge_tree(self):
|
||||||
class TestModel(SampleModel):
|
class TestModel(SampleModel):
|
||||||
engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")
|
engine = ReplacingMergeTree(
|
||||||
|
"date", ("date", "event_id", "event_group"), "event_uversion"
|
||||||
|
)
|
||||||
|
|
||||||
self._create_and_insert(TestModel)
|
self._create_and_insert(TestModel)
|
||||||
|
|
||||||
|
@ -236,16 +251,20 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._create_and_insert(TestModel)
|
self._create_and_insert(TestModel)
|
||||||
self._create_and_insert(TestCollapseModel)
|
self._create_and_insert(TestCollapseModel, sign=1)
|
||||||
|
|
||||||
# Result order may be different, lets sort manually
|
# Result order may be different, lets sort manually
|
||||||
parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
|
parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
|
||||||
|
|
||||||
self.assertEqual(2, len(parts))
|
self.assertEqual(2, len(parts))
|
||||||
self.assertEqual("testcollapsemodel", parts[0].table)
|
self.assertEqual("testcollapsemodel", parts[0].table)
|
||||||
self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
|
self.assertEqual(
|
||||||
|
"(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", "")
|
||||||
|
)
|
||||||
self.assertEqual("testmodel", parts[1].table)
|
self.assertEqual("testmodel", parts[1].table)
|
||||||
self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))
|
self.assertEqual(
|
||||||
|
"(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", "")
|
||||||
|
)
|
||||||
|
|
||||||
def test_custom_primary_key(self):
|
def test_custom_primary_key(self):
|
||||||
if self.database.server_version < (18, 1):
|
if self.database.server_version < (18, 1):
|
||||||
|
@ -269,13 +288,12 @@ class EnginesTestCase(_EnginesHelperTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._create_and_insert(TestModel)
|
self._create_and_insert(TestModel)
|
||||||
self._create_and_insert(TestCollapseModel)
|
self._create_and_insert(TestCollapseModel, sign=1)
|
||||||
|
|
||||||
self.assertEqual(2, len(list(SystemPart.get(self.database))))
|
self.assertEqual(2, len(list(SystemPart.get(self.database))))
|
||||||
|
|
||||||
|
|
||||||
class SampleModel(Model):
|
class SampleModel(Model):
|
||||||
|
|
||||||
date = DateField()
|
date = DateField()
|
||||||
event_id = UInt32Field()
|
event_id = UInt32Field()
|
||||||
event_group = UInt32Field()
|
event_group = UInt32Field()
|
||||||
|
@ -292,7 +310,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
|
||||||
engine.create_table_sql(self.database)
|
engine.create_table_sql(self.database)
|
||||||
|
|
||||||
exc = cm.exception
|
exc = cm.exception
|
||||||
self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")
|
self.assertEqual(
|
||||||
|
str(exc), "Cannot create Distributed engine: specify an underlying table"
|
||||||
|
)
|
||||||
|
|
||||||
def test_with_table_name(self):
|
def test_with_table_name(self):
|
||||||
engine = Distributed("my_cluster", "foo")
|
engine = Distributed("my_cluster", "foo")
|
||||||
|
@ -317,7 +337,9 @@ class DistributedTestCase(_EnginesHelperTestCase):
|
||||||
|
|
||||||
exc = cm.exception
|
exc = cm.exception
|
||||||
self.assertEqual(exc.code, 170)
|
self.assertEqual(exc.code, 170)
|
||||||
self.assertTrue(exc.message.startswith("Requested cluster 'cluster_name' not found"))
|
self.assertTrue(
|
||||||
|
exc.message.startswith("Requested cluster 'cluster_name' not found")
|
||||||
|
)
|
||||||
|
|
||||||
def test_verbose_engine_two_superclasses(self):
|
def test_verbose_engine_two_superclasses(self):
|
||||||
class TestModel2(SampleModel):
|
class TestModel2(SampleModel):
|
||||||
|
@ -368,11 +390,16 @@ class DistributedTestCase(_EnginesHelperTestCase):
|
||||||
exc = cm.exception
|
exc = cm.exception
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(exc),
|
str(exc),
|
||||||
"When defining Distributed engine without the table_name ensure " "that your model has a parent model",
|
"When defining Distributed engine without the table_name ensure "
|
||||||
|
"that your model has a parent model",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
|
def _test_insert_select(
|
||||||
d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
|
self, local_to_distributed, test_model=TestModel, include_readonly=True
|
||||||
|
):
|
||||||
|
d_model = self._create_distributed(
|
||||||
|
"test_shard_localhost", underlying=test_model
|
||||||
|
)
|
||||||
|
|
||||||
if local_to_distributed:
|
if local_to_distributed:
|
||||||
to_insert, to_select = test_model, d_model
|
to_insert, to_select = test_model, d_model
|
||||||
|
@ -437,4 +464,6 @@ class DistributedTestCase(_EnginesHelperTestCase):
|
||||||
class TestModel2(self.TestModel):
|
class TestModel2(self.TestModel):
|
||||||
event_uversion = UInt8Field(readonly=True)
|
event_uversion = UInt8Field(readonly=True)
|
||||||
|
|
||||||
return self._test_insert_select(local_to_distributed=False, test_model=TestModel2, include_readonly=False)
|
return self._test_insert_select(
|
||||||
|
local_to_distributed=False, test_model=TestModel2, include_readonly=False
|
||||||
|
)
|
||||||
|
|
|
@ -333,6 +333,18 @@ class QuerySetTestCase(TestCaseWithData):
|
||||||
"(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')",
|
"(first_name = 'a') AND (greater(`height`, 1.7)) AND (last_name = 'b')",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_precedence_of_negation(self):
|
||||||
|
p = ~Q(first_name='a')
|
||||||
|
q = Q(last_name='b')
|
||||||
|
r = p & q
|
||||||
|
self.assertEqual(r.to_sql(Person), "(NOT (first_name = 'a')) AND (last_name = 'b')")
|
||||||
|
r = q & p
|
||||||
|
self.assertEqual(r.to_sql(Person), "(last_name = 'b') AND (NOT (first_name = 'a'))")
|
||||||
|
r = q | p
|
||||||
|
self.assertEqual(r.to_sql(Person), "(last_name = 'b') OR (NOT (first_name = 'a'))")
|
||||||
|
r = ~q & p
|
||||||
|
self.assertEqual(r.to_sql(Person), "(NOT (last_name = 'b')) AND (NOT (first_name = 'a'))")
|
||||||
|
|
||||||
def test_invalid_filter(self):
|
def test_invalid_filter(self):
|
||||||
qs = Person.objects_in(self.database)
|
qs = Person.objects_in(self.database)
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user