Chore: blacken

commit e60350259f
parent 87e7858a04
Author: olliemath
Date: 2021-07-27 23:14:56 +01:00

6 changed files with 524 additions and 407 deletions

clickhouse_orm/database.py

@@ -11,16 +11,18 @@ from string import Template
 import pytz
 import logging

-logger = logging.getLogger('clickhouse_orm')
+logger = logging.getLogger("clickhouse_orm")

-Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
+Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")


 class DatabaseException(Exception):
-    '''
+    """
     Raised when a database operation fails.
-    '''
+    """
+
     pass
@@ -28,6 +30,7 @@ class ServerError(DatabaseException):
     """
     Raised when a server returns an error.
     """
+
     def __init__(self, message):
         self.code = None
         processed = self.get_error_code_msg(message)
@@ -41,16 +44,22 @@ class ServerError(DatabaseException):
     ERROR_PATTERNS = (
         # ClickHouse prior to v19.3.3
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
             \ e.what\(\)\ =\ (?P<type2>[^ \n]+)
-            ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
         # ClickHouse v19.3.3+
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
-            ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
     )

     @classmethod
@@ -65,7 +74,7 @@ class ServerError(DatabaseException):
         match = pattern.match(full_error_message)
         if match:
             # assert match.group('type1') == match.group('type2')
-            return int(match.group('code')), match.group('msg').strip()
+            return int(match.group("code")), match.group("msg").strip()
         return 0, full_error_message
@@ -75,15 +84,24 @@ class ServerError(DatabaseException):
 class Database(object):
-    '''
+    """
     Database instances connect to a specific ClickHouse database for running queries,
     inserting data and other operations.
-    '''
+    """

-    def __init__(self, db_name, db_url='http://localhost:8123/',
-                 username=None, password=None, readonly=False, autocreate=True,
-                 timeout=60, verify_ssl_cert=True, log_statements=False):
-        '''
+    def __init__(
+        self,
+        db_name,
+        db_url="http://localhost:8123/",
+        username=None,
+        password=None,
+        readonly=False,
+        autocreate=True,
+        timeout=60,
+        verify_ssl_cert=True,
+        log_statements=False,
+    ):
+        """
         Initializes a database instance. Unless it's readonly, the database will be
         created on the ClickHouse server if it does not already exist.
@@ -96,7 +114,7 @@ class Database(object):
         - `timeout`: the connection timeout in seconds.
         - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
         - `log_statements`: when True, all database statements are logged.
-        '''
+        """
         self.db_name = db_name
         self.db_url = db_url
         self.readonly = False
@@ -104,14 +122,14 @@ class Database(object):
         self.request_session = requests.Session()
         self.request_session.verify = verify_ssl_cert
         if username:
-            self.request_session.auth = (username, password or '')
+            self.request_session.auth = (username, password or "")
         self.log_statements = log_statements
         self.settings = {}
         self.db_exists = False  # this is required before running _is_existing_database
         self.db_exists = self._is_existing_database()
         if readonly:
             if not self.db_exists:
-                raise DatabaseException('Database does not exist, and cannot be created under readonly connection')
+                raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
             self.connection_readonly = self._is_connection_readonly()
             self.readonly = True
         elif autocreate and not self.db_exists:
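
The parameters documented above map straight onto a `requests` session plus a few flags. A minimal construction sketch (connection values are hypothetical):

    from clickhouse_orm.database import Database

    db = Database(
        "my_db",                          # created automatically unless autocreate=False
        db_url="http://localhost:8123/",
        username="default",
        password="",
        log_statements=True,              # log every statement sent to the server
    )
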
@@ -125,23 +143,23 @@ class Database(object):
         self.has_low_cardinality_support = self.server_version >= (19, 0)

     def create_database(self):
-        '''
+        """
         Creates the database on the ClickHouse server if it does not already exist.
-        '''
-        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
+        """
+        self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
         self.db_exists = True

     def drop_database(self):
-        '''
+        """
         Deletes the database on the ClickHouse server.
-        '''
-        self._send('DROP DATABASE `%s`' % self.db_name)
+        """
+        self._send("DROP DATABASE `%s`" % self.db_name)
         self.db_exists = False

     def create_table(self, model_class):
-        '''
+        """
         Creates a table for the given model class, if it does not exist already.
-        '''
+        """
         if model_class.is_system_model():
             raise DatabaseException("You can't create system table")
         if model_class.engine is None:
@@ -149,32 +167,32 @@ class Database(object):
         self._send(model_class.create_table_sql(self))

     def drop_table(self, model_class):
-        '''
+        """
         Drops the database table of the given model class, if it exists.
-        '''
+        """
         if model_class.is_system_model():
             raise DatabaseException("You can't drop system table")
         self._send(model_class.drop_table_sql(self))

     def does_table_exist(self, model_class):
-        '''
+        """
         Checks whether a table for the given model class already exists.
         Note that this only checks for existence of a table with the expected name.
-        '''
+        """
         sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
         r = self._send(sql % (self.db_name, model_class.table_name()))
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

     def get_model_for_table(self, table_name, system_table=False):
-        '''
+        """
         Generates a model class from an existing table in the database.
         This can be used for querying tables which don't have a corresponding model class,
         for example system tables.

         - `table_name`: the table to create a model for
         - `system_table`: whether the table is a system table, or belongs to the current database
-        '''
-        db_name = 'system' if system_table else self.db_name
+        """
+        db_name = "system" if system_table else self.db_name
         sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
         lines = self._send(sql).iter_lines()
         fields = [parse_tsv(line)[:2] for line in lines]
@@ -184,27 +202,28 @@ class Database(object):
         return model

     def add_setting(self, name, value):
-        '''
+        """
         Adds a database setting that will be sent with every request.
         For example, `db.add_setting("max_execution_time", 10)` will
         limit query execution time to 10 seconds.
         The name must be string, and the value is converted to string in case
         it isn't. To remove a setting, pass `None` as the value.
-        '''
-        assert isinstance(name, str), 'Setting name must be a string'
+        """
+        assert isinstance(name, str), "Setting name must be a string"
         if value is None:
             self.settings.pop(name, None)
         else:
             self.settings[name] = str(value)

     def insert(self, model_instances, batch_size=1000):
-        '''
+        """
         Insert records into the database.

         - `model_instances`: any iterable containing instances of a single model class.
         - `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
-        '''
+        """
         from io import BytesIO
+
         i = iter(model_instances)
         try:
             first_instance = next(i)
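
As the docstring notes, a setting registered with `add_setting` rides along as an HTTP GET parameter on every subsequent request (see `_build_params` further down):

    db.add_setting("max_execution_time", 10)    # value is converted to str and sent with each query
    db.add_setting("max_execution_time", None)  # passing None removes the setting again
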
@@ -215,14 +234,13 @@ class Database(object):
         if first_instance.is_read_only() or first_instance.is_system_model():
             raise DatabaseException("You can't insert into read only and system tables")

-        fields_list = ','.join(
-            ['`%s`' % name for name in first_instance.fields(writable=True)])
-        fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
-        query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
+        fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
+        fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
+        query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)

         def gen():
             buf = BytesIO()
-            buf.write(self._substitute(query, model_class).encode('utf-8'))
+            buf.write(self._substitute(query, model_class).encode("utf-8"))
             first_instance.set_database(self)
             buf.write(first_instance.to_db_string())
             # Collect lines in batches of batch_size
@@ -240,35 +258,37 @@ class Database(object):
             # Return any remaining lines in partial batch
             if lines:
                 yield buf.getvalue()
+
         self._send(gen())

     def count(self, model_class, conditions=None):
-        '''
+        """
         Counts the number of records in the model's table.

         - `model_class`: the model to count.
         - `conditions`: optional SQL conditions (contents of the WHERE clause).
-        '''
+        """
         from clickhouse_orm.query import Q
-        query = 'SELECT count() FROM $table'
+
+        query = "SELECT count() FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
+            query += " WHERE " + str(conditions)
         query = self._substitute(query, model_class)
         r = self._send(query)
         return int(r.text) if r.text else 0

     def select(self, query, model_class=None, settings=None):
-        '''
+        """
         Performs a query and returns a generator of model instances.

         - `query`: the SQL query to execute.
         - `model_class`: the model class matching the query's table,
           or `None` for getting back instances of an ad-hoc model.
         - `settings`: query settings to send as HTTP GET parameters
-        '''
-        query += ' FORMAT TabSeparatedWithNamesAndTypes'
+        """
+        query += " FORMAT TabSeparatedWithNamesAndTypes"
         query = self._substitute(query, model_class)
         r = self._send(query, settings, True)
         lines = r.iter_lines()
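
Taken together, `insert`, `count` and `select` cover the basic round trip. A sketch against a hypothetical `Event` model; `$table` is expanded by `_substitute` when a model class is supplied:

    import datetime

    from clickhouse_orm.engines import MergeTree
    from clickhouse_orm.fields import DateField, StringField
    from clickhouse_orm.models import Model

    class Event(Model):  # hypothetical model
        day = DateField()
        name = StringField()

        engine = MergeTree("day", ("day", "name"))

    db.create_table(Event)
    db.insert([Event(day=datetime.date.today(), name="signup")])
    total = db.count(Event, conditions="name = 'signup'")
    for event in db.select("SELECT * FROM $table WHERE name = 'signup'", model_class=Event):
        print(event.day, event.name)
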
@@ -281,18 +301,18 @@ class Database(object):
                 yield model_class.from_tsv(line, field_names, self.server_timezone, self)

     def raw(self, query, settings=None, stream=False):
-        '''
+        """
         Performs a query and returns its output as text.

         - `query`: the SQL query to execute.
         - `settings`: query settings to send as HTTP GET parameters
         - `stream`: if true, the HTTP response from ClickHouse will be streamed.
-        '''
+        """
         query = self._substitute(query, None)
         return self._send(query, settings=settings, stream=stream).text

     def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
-        '''
+        """
         Selects records and returns a single page of model instances.

         - `model_class`: the model class matching the query's table,
@@ -305,54 +325,63 @@ class Database(object):
         The result is a namedtuple containing `objects` (list), `number_of_objects`,
         `pages_total`, `number` (of the current page), and `page_size`.
-        '''
+        """
         from clickhouse_orm.query import Q
+
         count = self.count(model_class, conditions)
         pages_total = int(ceil(count / float(page_size)))
         if page_num == -1:
             page_num = max(pages_total, 1)
         elif page_num < 1:
-            raise ValueError('Invalid page number: %d' % page_num)
+            raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
-        query = 'SELECT * FROM $table'
+        query = "SELECT * FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
-        query += ' ORDER BY %s' % order_by
-        query += ' LIMIT %d, %d' % (offset, page_size)
+            query += " WHERE " + str(conditions)
+        query += " ORDER BY %s" % order_by
+        query += " LIMIT %d, %d" % (offset, page_size)
         query = self._substitute(query, model_class)
         return Page(
             objects=list(self.select(query, model_class, settings)) if count else [],
             number_of_objects=count,
             pages_total=pages_total,
             number=page_num,
-            page_size=page_size
+            page_size=page_size,
         )

     def migrate(self, migrations_package_name, up_to=9999):
-        '''
+        """
         Executes schema migrations.

         - `migrations_package_name` - fully qualified name of the Python package
           containing the migrations.
         - `up_to` - number of the last migration to apply.
-        '''
+        """
         from .migrations import MigrationHistory
-        logger = logging.getLogger('migrations')
+
+        logger = logging.getLogger("migrations")
         applied_migrations = self._get_applied_migrations(migrations_package_name)
         modules = import_submodules(migrations_package_name)
         unapplied_migrations = set(modules.keys()) - applied_migrations
         for name in sorted(unapplied_migrations):
-            logger.info('Applying migration %s...', name)
+            logger.info("Applying migration %s...", name)
             for operation in modules[name].operations:
                 operation.apply(self)
-            self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
+            self.insert(
+                [
+                    MigrationHistory(
+                        package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
+                    )
+                ]
+            )
             if int(name[:4]) >= up_to:
                 break

     def _get_applied_migrations(self, migrations_package_name):
         from .migrations import MigrationHistory
+
         self.create_table(MigrationHistory)
         query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
         query = self._substitute(query, MigrationHistory)
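
`paginate` translates the page number into a `LIMIT offset, size` clause, and `migrate` applies each unapplied migration module in name order. Continuing the hypothetical `Event` sketch (the migrations package name is also made up):

    page = db.paginate(Event, order_by="day", page_num=1, page_size=100)
    print(page.number_of_objects, page.pages_total)
    for event in page.objects:
        print(event.name)

    db.migrate("my_app.migrations")  # applies my_app/migrations/0001_*.py, 0002_*.py, ...
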
@@ -360,7 +389,7 @@ class Database(object):
     def _send(self, data, settings=None, stream=False):
         if isinstance(data, str):
-            data = data.encode('utf-8')
+            data = data.encode("utf-8")
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
@@ -373,50 +402,50 @@ class Database(object):
         params = dict(settings or {})
         params.update(self.settings)
         if self.db_exists:
-            params['database'] = self.db_name
+            params["database"] = self.db_name
         # Send the readonly flag, unless the connection is already readonly (to prevent db error)
         if self.readonly and not self.connection_readonly:
-            params['readonly'] = '1'
+            params["readonly"] = "1"
         return params

     def _substitute(self, query, model_class=None):
-        '''
+        """
         Replaces $db and $table placeholders in the query.
-        '''
-        if '$' in query:
+        """
+        if "$" in query:
             mapping = dict(db="`%s`" % self.db_name)
             if model_class:
                 if model_class.is_system_model():
-                    mapping['table'] = "`system`.`%s`" % model_class.table_name()
+                    mapping["table"] = "`system`.`%s`" % model_class.table_name()
                 else:
-                    mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
+                    mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
             query = Template(query).safe_substitute(mapping)
         return query

     def _get_server_timezone(self):
         try:
-            r = self._send('SELECT timezone()')
+            r = self._send("SELECT timezone()")
             return pytz.timezone(r.text.strip())
         except ServerError as e:
-            logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
+            logger.exception("Cannot determine server timezone (%s), assuming UTC", e)
             return pytz.utc

     def _get_server_version(self, as_tuple=True):
         try:
-            r = self._send('SELECT version();')
+            r = self._send("SELECT version();")
             ver = r.text
         except ServerError as e:
-            logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
-            ver = '1.1.0'
-        return tuple(int(n) for n in ver.split('.')) if as_tuple else ver
+            logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
+            ver = "1.1.0"
+        return tuple(int(n) for n in ver.split(".")) if as_tuple else ver

     def _is_existing_database(self):
         r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

     def _is_connection_readonly(self):
         r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
-        return r.text.strip() != '0'
+        return r.text.strip() != "0"


 # Expose only relevant classes in import *

clickhouse_orm/engines.py

@@ -4,51 +4,57 @@ import logging
 from .utils import comma_join, get_subclass_names

-logger = logging.getLogger('clickhouse_orm')
+logger = logging.getLogger("clickhouse_orm")


 class Engine(object):
     def create_table_sql(self, db):
         raise NotImplementedError()  # pragma: no cover


 class TinyLog(Engine):
     def create_table_sql(self, db):
-        return 'TinyLog'
+        return "TinyLog"


 class Log(Engine):
     def create_table_sql(self, db):
-        return 'Log'
+        return "Log"


 class Memory(Engine):
     def create_table_sql(self, db):
-        return 'Memory'
+        return "Memory"


 class MergeTree(Engine):
-
-    def __init__(self, date_col=None, order_by=(), sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
-        assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
-        assert primary_key is None or type(primary_key) in (list, tuple), 'primary_key must be a list or tuple'
-        assert partition_key is None or type(partition_key) in (list, tuple),\
-            'partition_key must be tuple or list if present'
-        assert (replica_table_path is None) == (replica_name is None), \
-            'both replica_table_path and replica_name must be specified'
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
+        assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
+        assert primary_key is None or type(primary_key) in (list, tuple), "primary_key must be a list or tuple"
+        assert partition_key is None or type(partition_key) in (
+            list,
+            tuple,
+        ), "partition_key must be tuple or list if present"
+        assert (replica_table_path is None) == (
+            replica_name is None
+        ), "both replica_table_path and replica_name must be specified"
         # These values conflict with each other (old and new syntax of table engines.
         # So let's control only one of them is given.
         assert date_col or partition_key, "You must set either date_col or partition_key"
         self.date_col = date_col
-        self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,)
+        self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
         self.primary_key = primary_key
         self.order_by = order_by
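
The assertions above force a choice between the old `date_col` syntax and the newer explicit `partition_key` syntax. Both declarations, with hypothetical column names:

    # Old syntax: partitioned monthly by a date column
    engine = MergeTree("day", order_by=("day", "name"))

    # New syntax: explicit partition key (requires ClickHouse 1.1.54310+)
    engine = MergeTree(order_by=("day", "name"), partition_key=("toYYYYMM(day)",))
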
@@ -60,26 +66,31 @@ class MergeTree(Engine):
     # I changed field name for new reality and syntax
     @property
     def key_cols(self):
-        logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
+        logger.warning(
+            "`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
+        )
         return self.order_by

     @key_cols.setter
     def key_cols(self, value):
-        logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
+        logger.warning(
+            "`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
+        )
         self.order_by = value

     def create_table_sql(self, db):
         name = self.__class__.__name__
         if self.replica_name:
-            name = 'Replicated' + name
+            name = "Replicated" + name

         # In ClickHouse 1.1.54310 custom partitioning key was introduced
         # https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
         # Let's check version and use new syntax if available
         if db.server_version >= (1, 1, 54310):
-            partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \
-                % (comma_join(self.partition_key, stringify=True),
-                   comma_join(self.order_by, stringify=True))
+            partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
+                comma_join(self.partition_key, stringify=True),
+                comma_join(self.order_by, stringify=True),
+            )
             if self.primary_key:
                 partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
@@ -92,14 +103,17 @@ class MergeTree(Engine):
         elif not self.date_col:
             # Can't import it globally due to circular import
             from clickhouse_orm.database import DatabaseException
-            raise DatabaseException("Custom partitioning is not supported before ClickHouse 1.1.54310. "
-                                    "Please update your server or use date_col syntax."
-                                    "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/")
+
+            raise DatabaseException(
+                "Custom partitioning is not supported before ClickHouse 1.1.54310. "
+                "Please update your server or use date_col syntax."
+                "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
+            )
         else:
-            partition_sql = ''
+            partition_sql = ""

         params = self._build_sql_params(db)
-        return '%s(%s) %s' % (name, comma_join(params), partition_sql)
+        return "%s(%s) %s" % (name, comma_join(params), partition_sql)

     def _build_sql_params(self, db):
         params = []
@@ -114,19 +128,35 @@ class MergeTree(Engine):
             params.append(self.date_col)
         if self.sampling_expr:
             params.append(self.sampling_expr)
-        params.append('(%s)' % comma_join(self.order_by, stringify=True))
+        params.append("(%s)" % comma_join(self.order_by, stringify=True))
         params.append(str(self.index_granularity))
         return params


 class CollapsingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(CollapsingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
-                                                  replica_table_path, replica_name, partition_key, primary_key)
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        sign_col="sign",
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        super(CollapsingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+        )
         self.sign_col = sign_col

     def _build_sql_params(self, db):
@@ -136,29 +166,61 @@ class CollapsingMergeTree(MergeTree):
 class SummingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(SummingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity, replica_table_path,
-                                               replica_name, partition_key, primary_key)
-        assert type is None or type(summing_cols) in (list, tuple), 'summing_cols must be a list or tuple'
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        summing_cols=None,
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        super(SummingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+        )
+        assert summing_cols is None or type(summing_cols) in (list, tuple), "summing_cols must be a list or tuple"
         self.summing_cols = summing_cols

     def _build_sql_params(self, db):
         params = super(SummingMergeTree, self)._build_sql_params(db)
         if self.summing_cols:
-            params.append('(%s)' % comma_join(self.summing_cols))
+            params.append("(%s)" % comma_join(self.summing_cols))
         return params


 class ReplacingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), ver_col=None, sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(ReplacingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
-                                                 replica_table_path, replica_name, partition_key, primary_key)
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        ver_col=None,
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        super(ReplacingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+        )
        self.ver_col = ver_col

     def _build_sql_params(self, db):
@@ -176,8 +238,17 @@ class Buffer(Engine):
     """

     # Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
-    def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000,
-                 min_bytes=10000000, max_bytes=100000000):
+    def __init__(
+        self,
+        main_model,
+        num_layers=16,
+        min_time=10,
+        max_time=100,
+        min_rows=10000,
+        max_rows=1000000,
+        min_bytes=10000000,
+        max_bytes=100000000,
+    ):
         self.main_model = main_model
         self.num_layers = num_layers
         self.min_time = min_time
@@ -190,10 +261,16 @@ class Buffer(Engine):
     def create_table_sql(self, db):
         # Overriden create_table_sql example:
         # sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
-        sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
-            db.db_name, self.main_model.table_name(), self.num_layers,
-            self.min_time, self.max_time, self.min_rows,
-            self.max_rows, self.min_bytes, self.max_bytes
+        sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
+            db.db_name,
+            self.main_model.table_name(),
+            self.num_layers,
+            self.min_time,
+            self.max_time,
+            self.min_rows,
+            self.max_rows,
+            self.min_bytes,
+            self.max_bytes,
         )
         return sql
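
A `Buffer` engine always wraps a main model, and `create_table_sql` renders the nine Buffer parameters in the order shown in the comment above. A sketch reusing the hypothetical `Event` model; `EventBuffer` is likewise made up:

    from clickhouse_orm.engines import Buffer
    from clickhouse_orm.models import BufferModel

    class EventBuffer(BufferModel, Event):  # inherits Event's fields
        engine = Buffer(Event)

    db.create_table(Event)        # the on-disk table must exist first
    db.create_table(EventBuffer)  # ENGINE = Buffer(`my_db`, `event`, 16, 10, 100, ...)
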
@@ -224,6 +301,7 @@ class Distributed(Engine):
     See full documentation here
     https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
     """
+
     def __init__(self, cluster, table=None, sharding_key=None):
         """
         - `cluster`: what cluster to access data from
@@ -252,12 +330,11 @@ class Distributed(Engine):
     def create_table_sql(self, db):
         name = self.__class__.__name__
         params = self._build_sql_params(db)
-        return '%s(%s)' % (name, ', '.join(params))
+        return "%s(%s)" % (name, ", ".join(params))

     def _build_sql_params(self, db):
         if self.table_name is None:
-            raise ValueError("Cannot create {} engine: specify an underlying table".format(
-                self.__class__.__name__))
+            raise ValueError("Cannot create {} engine: specify an underlying table".format(self.__class__.__name__))

         params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
         if self.sharding_key:

clickhouse_orm/migrations.py

@@ -5,67 +5,67 @@ from .fields import DateField, StringField
 from .models import BufferModel, Model
 from .utils import get_subclass_names

-logger = logging.getLogger('migrations')
+logger = logging.getLogger("migrations")


-class Operation():
-    '''
+class Operation:
+    """
     Base class for migration operations.
-    '''
+    """

     def apply(self, database):
         raise NotImplementedError()  # pragma: no cover


 class ModelOperation(Operation):
-    '''
+    """
     Base class for migration operations that work on a specific model.
-    '''
+    """

     def __init__(self, model_class):
-        '''
+        """
         Initializer.
-        '''
+        """
         self.model_class = model_class
         self.table_name = model_class.table_name()

     def _alter_table(self, database, cmd):
-        '''
+        """
         Utility for running ALTER TABLE commands.
-        '''
+        """
         cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
         logger.debug(cmd)
         database.raw(cmd)


 class CreateTable(ModelOperation):
-    '''
+    """
     A migration operation that creates a table for a given model class.
-    '''
+    """

     def apply(self, database):
-        logger.info('    Create table %s', self.table_name)
+        logger.info("    Create table %s", self.table_name)
         if issubclass(self.model_class, BufferModel):
             database.create_table(self.model_class.engine.main_model)
         database.create_table(self.model_class)


 class AlterTable(ModelOperation):
-    '''
+    """
     A migration operation that compares the table of a given model class to
     the model's fields, and alters the table to match the model. The operation can:

     - add new columns
     - drop obsolete columns
     - modify column types

     Default values are not altered by this operation.
-    '''
+    """

     def _get_table_fields(self, database):
         query = "DESC `%s`.`%s`" % (database.db_name, self.table_name)
         return [(row.name, row.type) for row in database.select(query)]

     def apply(self, database):
-        logger.info('    Alter table %s', self.table_name)
+        logger.info("    Alter table %s", self.table_name)

         # Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
         # ADD COLUMN ... AFTER doesn't affect it
@@ -74,8 +74,8 @@ class AlterTable(ModelOperation):
         # Identify fields that were deleted from the model
         deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
         for name in deleted_fields:
-            logger.info('    Drop column %s', name)
-            self._alter_table(database, 'DROP COLUMN %s' % name)
+            logger.info("    Drop column %s", name)
+            self._alter_table(database, "DROP COLUMN %s" % name)
             del table_fields[name]

         # Identify fields that were added to the model
@@ -83,11 +83,11 @@ class AlterTable(ModelOperation):
         for name, field in self.model_class.fields().items():
             is_regular_field = not (field.materialized or field.alias)
             if name not in table_fields:
-                logger.info('    Add column %s', name)
-                assert prev_name, 'Cannot add a column to the beginning of the table'
-                cmd = 'ADD COLUMN %s %s' % (name, field.get_sql(db=database))
+                logger.info("    Add column %s", name)
+                assert prev_name, "Cannot add a column to the beginning of the table"
+                cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
                 if is_regular_field:
-                    cmd += ' AFTER %s' % prev_name
+                    cmd += " AFTER %s" % prev_name
                 self._alter_table(database, cmd)

             if is_regular_field:
@@ -99,24 +99,27 @@ class AlterTable(ModelOperation):
         # The order of class attributes can be changed any time, so we can't count on it
         # Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save
         # attribute position. Watch https://github.com/Infinidat/clickhouse_orm/issues/47
-        model_fields = {name: field.get_sql(with_default_expression=False, db=database)
-                        for name, field in self.model_class.fields().items()}
+        model_fields = {
+            name: field.get_sql(with_default_expression=False, db=database)
+            for name, field in self.model_class.fields().items()
+        }
         for field_name, field_sql in self._get_table_fields(database):
             # All fields must have been created and dropped by this moment
-            assert field_name in model_fields, 'Model fields and table columns in disagreement'
+            assert field_name in model_fields, "Model fields and table columns in disagreement"

             if field_sql != model_fields[field_name]:
-                logger.info('    Change type of column %s from %s to %s', field_name, field_sql,
-                            model_fields[field_name])
-                self._alter_table(database, 'MODIFY COLUMN %s %s' % (field_name, model_fields[field_name]))
+                logger.info(
+                    "    Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
+                )
+                self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))


 class AlterTableWithBuffer(ModelOperation):
-    '''
+    """
     A migration operation for altering a buffer table and its underlying on-disk table.
     The buffer table is dropped, the on-disk table is altered, and then the buffer table
     is re-created.
-    '''
+    """

     def apply(self, database):
         if issubclass(self.model_class, BufferModel):
@@ -128,149 +131,152 @@ class AlterTableWithBuffer(ModelOperation):
 class DropTable(ModelOperation):
-    '''
+    """
     A migration operation that drops the table of a given model class.
-    '''
+    """

     def apply(self, database):
-        logger.info('    Drop table %s', self.table_name)
+        logger.info("    Drop table %s", self.table_name)
         database.drop_table(self.model_class)


 class AlterConstraints(ModelOperation):
-    '''
+    """
     A migration operation that adds new constraints from the model to the database
     table, and drops obsolete ones. Constraints are identified by their names, so
     a change in an existing constraint will not be detected unless its name was changed too.
     ClickHouse does not check that the constraints hold for existing data in the table.
-    '''
+    """

     def apply(self, database):
-        logger.info('    Alter constraints for %s', self.table_name)
+        logger.info("    Alter constraints for %s", self.table_name)
         existing = self._get_constraint_names(database)
         # Go over constraints in the model
         for constraint in self.model_class._constraints.values():
             # Check if it's a new constraint
             if constraint.name not in existing:
-                logger.info('    Add constraint %s', constraint.name)
-                self._alter_table(database, 'ADD %s' % constraint.create_table_sql())
+                logger.info("    Add constraint %s", constraint.name)
+                self._alter_table(database, "ADD %s" % constraint.create_table_sql())
             else:
                 existing.remove(constraint.name)
         # Remaining constraints in `existing` are obsolete
         for name in existing:
-            logger.info('    Drop constraint %s', name)
-            self._alter_table(database, 'DROP CONSTRAINT `%s`' % name)
+            logger.info("    Drop constraint %s", name)
+            self._alter_table(database, "DROP CONSTRAINT `%s`" % name)

     def _get_constraint_names(self, database):
-        '''
+        """
         Returns a set containing the names of existing constraints in the table.
-        '''
+        """
         import re
-        table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
-        matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def)
+
+        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
+        matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
         return set(matches)


 class AlterIndexes(ModelOperation):
-    '''
+    """
     A migration operation that adds new indexes from the model to the database
     table, and drops obsolete ones. Indexes are identified by their names, so
     a change in an existing index will not be detected unless its name was changed too.
-    '''
+    """

     def __init__(self, model_class, reindex=False):
-        '''
+        """
         Initializer.
         By default ClickHouse does not build indexes over existing data, only for
         new data. Passing `reindex=True` will run `OPTIMIZE TABLE` in order to build
         the indexes over the existing data.
-        '''
+        """
         super().__init__(model_class)
         self.reindex = reindex

     def apply(self, database):
-        logger.info('    Alter indexes for %s', self.table_name)
+        logger.info("    Alter indexes for %s", self.table_name)
         existing = self._get_index_names(database)
         logger.info(existing)
         # Go over indexes in the model
         for index in self.model_class._indexes.values():
             # Check if it's a new index
             if index.name not in existing:
-                logger.info('    Add index %s', index.name)
-                self._alter_table(database, 'ADD %s' % index.create_table_sql())
+                logger.info("    Add index %s", index.name)
+                self._alter_table(database, "ADD %s" % index.create_table_sql())
             else:
                 existing.remove(index.name)
         # Remaining indexes in `existing` are obsolete
         for name in existing:
-            logger.info('    Drop index %s', name)
-            self._alter_table(database, 'DROP INDEX `%s`' % name)
+            logger.info("    Drop index %s", name)
+            self._alter_table(database, "DROP INDEX `%s`" % name)
         # Reindex
         if self.reindex:
-            logger.info('    Build indexes on table')
-            database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name)
+            logger.info("    Build indexes on table")
+            database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)

     def _get_index_names(self, database):
-        '''
+        """
         Returns a set containing the names of existing indexes in the table.
-        '''
+        """
         import re
-        table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
-        matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def)
+
+        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
+        matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
         return set(matches)


 class RunPython(Operation):
-    '''
+    """
     A migration operation that executes a Python function.
-    '''
+    """

     def __init__(self, func):
-        '''
+        """
         Initializer. The given Python function will be called with a single
         argument - the Database instance to apply the migration to.
-        '''
+        """
         assert callable(func), "'func' argument must be function"
         self._func = func

     def apply(self, database):
-        logger.info('    Executing python operation %s', self._func.__name__)
+        logger.info("    Executing python operation %s", self._func.__name__)
         self._func(database)


 class RunSQL(Operation):
-    '''
+    """
     A migration operation that executes arbitrary SQL statements.
-    '''
+    """

     def __init__(self, sql):
-        '''
+        """
         Initializer. The given sql argument must be a valid SQL statement or
         list of statements.
-        '''
+        """
         if isinstance(sql, str):
             sql = [sql]

         assert isinstance(sql, list), "'sql' argument must be string or list of strings"
         self._sql = sql

     def apply(self, database):
-        logger.info('    Executing raw SQL operations')
+        logger.info("    Executing raw SQL operations")
         for item in self._sql:
             database.raw(item)


 class MigrationHistory(Model):
-    '''
+    """
     A model for storing which migrations were already applied to the containing database.
-    '''
+    """
+
     package_name = StringField()
     module_name = StringField()
     applied = DateField()

-    engine = MergeTree('applied', ('package_name', 'module_name'))
+    engine = MergeTree("applied", ("package_name", "module_name"))

     @classmethod
     def table_name(cls):
-        return 'infi_clickhouse_orm_migrations'
+        return "infi_clickhouse_orm_migrations"


 # Expose only relevant classes in import *
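
`Database.migrate` (first file above) imports the migrations package and looks for a module-level `operations` list in each migration module. A sketch of one such module, with hypothetical package and model names:

    # my_app/migrations/0001_initial.py
    from clickhouse_orm import migrations

    from my_app.models import Event

    operations = [
        migrations.CreateTable(Event),
        migrations.RunSQL("INSERT INTO $db.`event` (day, name) VALUES ('2021-07-27', 'seed')"),
    ]
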

clickhouse_orm/models.py

@@ -11,77 +11,77 @@ from .funcs import F
 from .query import QuerySet
 from .utils import NO_VALUE, arg_to_sql, get_subclass_names, parse_tsv

-logger = getLogger('clickhouse_orm')
+logger = getLogger("clickhouse_orm")


 class Constraint:
-    '''
+    """
     Defines a model constraint.
-    '''
+    """

     name = None  # this is set by the parent model
     parent = None  # this is set by the parent model

     def __init__(self, expr):
-        '''
+        """
         Initializer. Expects an expression that ClickHouse will verify when inserting data.
-        '''
+        """
         self.expr = expr

     def create_table_sql(self):
-        '''
+        """
         Returns the SQL statement for defining this constraint during table creation.
-        '''
-        return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
+        """
+        return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))


 class Index:
-    '''
+    """
     Defines a data-skipping index.
-    '''
+    """

     name = None  # this is set by the parent model
     parent = None  # this is set by the parent model

     def __init__(self, expr, type, granularity):
-        '''
+        """
         Initializer.

         - `expr` - a column, expression, or tuple of columns and expressions to index.
         - `type` - the index type. Use one of the following methods to specify the type:
           `Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`.
         - `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine).
-        '''
+        """
         self.expr = expr
         self.type = type
         self.granularity = granularity

     def create_table_sql(self):
-        '''
+        """
         Returns the SQL statement for defining this index during table creation.
-        '''
-        return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
+        """
+        return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (self.name, arg_to_sql(self.expr), self.type, self.granularity)

     @staticmethod
     def minmax():
-        '''
+        """
         An index that stores extremes of the specified expression (if the expression is tuple, then it stores
         extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key.
-        '''
-        return 'minmax'
+        """
+        return "minmax"

     @staticmethod
     def set(max_rows):
-        '''
+        """
         An index that stores unique values of the specified expression (no more than max_rows rows,
         or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
         on a block of data.
-        '''
-        return 'set(%d)' % max_rows
+        """
+        return "set(%d)" % max_rows

     @staticmethod
     def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
-        '''
+        """
         An index that stores a Bloom filter containing all ngrams from a block of data.
         Works only with strings. Can be used for optimization of equals, like and in expressions.
@@ -90,12 +90,12 @@ class Index:
           for example 256 or 512, because it can be compressed well).
         - `number_of_hash_functions` The number of hash functions used in the Bloom filter.
         - `random_seed` The seed for Bloom filter hash functions.
-        '''
-        return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
+        """
+        return "ngrambf_v1(%d, %d, %d, %d)" % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)

     @staticmethod
     def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
-        '''
+        """
         An index that stores a Bloom filter containing string tokens. Tokens are sequences
         separated by non-alphanumeric characters.
@@ -103,24 +103,24 @@ class Index:
          for example 256 or 512, because it can be compressed well).
        - `number_of_hash_functions` - The number of hash functions used in the Bloom filter.
        - `random_seed` - The seed for Bloom filter hash functions.
        """
        return "tokenbf_v1(%d, %d, %d)" % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)

    @staticmethod
    def bloom_filter(false_positive=0.025):
        """
        An index that stores a Bloom filter containing values of the index expression.
        - `false_positive` - the probability (between 0 and 1) of receiving a false positive
          response from the filter
        """
        return "bloom_filter(%f)" % false_positive
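
# A minimal usage sketch: a data-skipping index is declared as a class attribute
# on a model, and the metaclass fills in its `name`. The model, fields and engine
# arguments below are illustrative assumptions, not part of this library's code:
#
#     class Event(Model):
#         timestamp = DateTimeField()
#         url = StringField()
#         # Bloom filter over 3-grams of `url`, checked once per 4 index granules
#         url_idx = Index("url", type=Index.ngrambf_v1(3, 256, 2, 0), granularity=4)
#         engine = MergeTree(order_by=("timestamp",))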
class ModelBase(type):
    """
    A metaclass for ORM models. It adds the _fields list to model classes.
    """

    ad_hoc_model_cache = {}
@@ -168,7 +168,7 @@ class ModelBase(type):
            _indexes=indexes,
            _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
            _defaults=defaults,
            _has_funcs_as_defaults=has_funcs_as_defaults,
        )
        model = super(ModelBase, metacls).__new__(metacls, str(name), bases, attrs)
@@ -180,11 +180,11 @@ class ModelBase(type):
        return model

    @classmethod
    def create_ad_hoc_model(metacls, fields, model_name="AdHocModel"):
        # fields is a list of tuples (name, db_type)
        # Check if model exists in cache
        fields = list(fields)
        cache_key = model_name + " " + str(fields)
        if cache_key in metacls.ad_hoc_model_cache:
            return metacls.ad_hoc_model_cache[cache_key]
        # Create an ad hoc model class
@@ -201,58 +201,55 @@ class ModelBase(type):
        import clickhouse_orm.fields as orm_fields

        # Enums
        if db_type.startswith("Enum"):
            return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
        # DateTime with timezone
        if db_type.startswith("DateTime("):
            timezone = db_type[9:-1]
            return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
        # DateTime64
        if db_type.startswith("DateTime64("):
            precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
            return orm_fields.DateTime64Field(
                precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
            )
        # Arrays
        if db_type.startswith("Array"):
            inner_field = metacls.create_ad_hoc_field(db_type[6:-1])
            return orm_fields.ArrayField(inner_field)
        # Tuples (poor man's version - convert to array)
        if db_type.startswith("Tuple"):
            types = [s.strip() for s in db_type[6:-1].split(",")]
            assert len(set(types)) == 1, "No support for mixed types in tuples - " + db_type
            inner_field = metacls.create_ad_hoc_field(types[0])
            return orm_fields.ArrayField(inner_field)
        # FixedString
        if db_type.startswith("FixedString"):
            length = int(db_type[12:-1])
            return orm_fields.FixedStringField(length)
        # Decimal / Decimal32 / Decimal64 / Decimal128
        if db_type.startswith("Decimal"):
            p = db_type.index("(")
            args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
            field_class = getattr(orm_fields, db_type[:p] + "Field")
            return field_class(*args)
        # Nullable
        if db_type.startswith("Nullable"):
            inner_field = metacls.create_ad_hoc_field(db_type[9:-1])
            return orm_fields.NullableField(inner_field)
        # LowCardinality
        if db_type.startswith("LowCardinality"):
            inner_field = metacls.create_ad_hoc_field(db_type[15:-1])
            return orm_fields.LowCardinalityField(inner_field)
        # Simple fields
        name = db_type + "Field"
        if not hasattr(orm_fields, name):
            raise NotImplementedError("No field class for %s" % db_type)
        return getattr(orm_fields, name)()
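
# A minimal sketch of the ad-hoc path, assuming (name, db_type) pairs such as
# those returned by a DESCRIBE TABLE query:
#
#     fields = [("id", "UInt64"), ("tags", "Array(String)"), ("score", "Nullable(Float32)")]
#     AdHoc = ModelBase.create_ad_hoc_model(fields)
#     # AdHoc.id is a UInt64Field, AdHoc.tags an ArrayField(StringField()),
#     # and AdHoc.score a NullableField(Float32Field())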
class Model(metaclass=ModelBase):
    """
    A base class for ORM models. Each model class represents a ClickHouse table. For example:

        class CPUStats(Model):
@@ -260,7 +257,7 @@ class Model(metaclass=ModelBase):
            cpu_id = UInt16Field()
            cpu_percent = Float32Field()
            engine = Memory()
    """

    engine = None
@@ -273,12 +270,12 @@ class Model(metaclass=ModelBase):
    _database = None

    def __init__(self, **kwargs):
        """
        Creates a model instance, using keyword arguments as field values.
        Since values are immediately converted to their Pythonic type,
        invalid values will cause a `ValueError` to be raised.
        Unrecognized field names will cause an `AttributeError`.
        """
        super(Model, self).__init__()
        # Assign default values
        self.__dict__.update(self._defaults)
@@ -288,13 +285,13 @@ class Model(metaclass=ModelBase):
            if field:
                setattr(self, name, value)
            else:
                raise AttributeError("%s does not have a field called %s" % (self.__class__.__name__, name))

    def __setattr__(self, name, value):
        """
        When setting a field value, converts the value to its Pythonic type and validates it.
        This may raise a `ValueError`.
        """
        field = self.get_field(name)
        if field and (value != NO_VALUE):
            try:
@@ -307,77 +304,78 @@ class Model(metaclass=ModelBase):
        super(Model, self).__setattr__(name, value)

    def set_database(self, db):
        """
        Sets the `Database` that this model instance belongs to.
        This is done automatically when the instance is read from the database or written to it.
        """
        # This can not be imported globally due to circular import
        from .database import Database

        assert isinstance(db, Database), "database must be database.Database instance"
        self._database = db

    def get_database(self):
        """
        Gets the `Database` that this model instance belongs to.
        Returns `None` unless the instance was read from the database or written to it.
        """
        return self._database

    def get_field(self, name):
        """
        Gets a `Field` instance given its name, or `None` if not found.
        """
        return self._fields.get(name)

    @classmethod
    def table_name(cls):
        """
        Returns the model's database table name. By default this is the
        class name converted to lowercase. Override this if you want to use
        a different table name.
        """
        return cls.__name__.lower()

    @classmethod
    def has_funcs_as_defaults(cls):
        """
        Return True if some of the model's fields use a function expression
        as a default value. This requires special handling when inserting instances.
        """
        return cls._has_funcs_as_defaults

    @classmethod
    def create_table_sql(cls, db):
        """
        Returns the SQL statement for creating a table for this model.
        """
        parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
        # Fields
        items = []
        for name, field in cls.fields().items():
            items.append("    %s %s" % (name, field.get_sql(db=db)))
        # Constraints
        for c in cls._constraints.values():
            items.append("    %s" % c.create_table_sql())
        # Indexes
        for i in cls._indexes.values():
            items.append("    %s" % i.create_table_sql())
        parts.append(",\n".join(items))
        # Engine
        parts.append(")")
        parts.append("ENGINE = " + cls.engine.create_table_sql(db))
        return "\n".join(parts)

    @classmethod
    def drop_table_sql(cls, db):
        """
        Returns the SQL command for deleting this model's table.
        """
        return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
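
# For the CPUStats example above, and an assumed database named "demo", these two
# methods produce SQL along these lines:
#
#     CREATE TABLE IF NOT EXISTS `demo`.`cpustats` (
#         cpu_id UInt16,
#         cpu_percent Float32
#     )
#     ENGINE = Memory
#
#     DROP TABLE IF EXISTS `demo`.`cpustats`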
    @classmethod
    def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
        """
        Create a model instance from a tab-separated line. The line may or may not include a newline.
        The `field_names` list must match the fields defined in the model, but does not have to include all of them.
@@ -385,12 +383,12 @@ class Model(metaclass=ModelBase):
        - `field_names`: names of the model fields in the data.
        - `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
        - `database`: if given, sets the database that this instance belongs to.
        """
        values = iter(parse_tsv(line))
        kwargs = {}
        for name in field_names:
            field = getattr(cls, name)
            field_timezone = getattr(field, "timezone", None) or timezone_in_use
            kwargs[name] = field.to_python(next(values), field_timezone)

        obj = cls(**kwargs)
@@ -400,45 +398,45 @@ class Model(metaclass=ModelBase):
        return obj

    def to_tsv(self, include_readonly=True):
        """
        Returns the instance's column values as a tab-separated line. A newline is not included.
        - `include_readonly`: if false, returns only fields that can be inserted into the database.
        """
        data = self.__dict__
        fields = self.fields(writable=not include_readonly)
        return "\t".join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
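
# A rough illustration, using the CPUStats example model from above:
#
#     >>> CPUStats(cpu_id=1, cpu_percent=42.5).to_tsv()
#     '1\t42.5'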
    def to_tskv(self, include_readonly=True):
        """
        Returns the instance's column keys and values as a tab-separated line. A newline is not included.
        Fields that were not assigned a value are omitted.
        - `include_readonly`: if false, returns only fields that can be inserted into the database.
        """
        data = self.__dict__
        fields = self.fields(writable=not include_readonly)
        parts = []
        for name, field in fields.items():
            if data[name] != NO_VALUE:
                parts.append(name + "=" + field.to_db_string(data[name], quote=False))
        return "\t".join(parts)

    def to_db_string(self):
        """
        Returns the instance as a bytestring ready to be inserted into the database.
        """
        s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
        s += "\n"
        return s.encode("utf-8")

    def to_dict(self, include_readonly=True, field_names=None):
        """
        Returns the instance's column values as a dict.
        - `include_readonly`: if false, returns only fields that can be inserted into the database.
        - `field_names`: an iterable of field names to return (optional)
        """
        fields = self.fields(writable=not include_readonly)
        if field_names is not None:
@@ -449,56 +447,58 @@ class Model(metaclass=ModelBase):
    @classmethod
    def objects_in(cls, database):
        """
        Returns a `QuerySet` for selecting instances of this model class.
        """
        return QuerySet(cls, database)
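
# A typical read path, as a sketch; the database name and the filter lookup are
# assumed for illustration:
#
#     db = Database("demo")
#     for row in CPUStats.objects_in(db).filter(cpu_percent__gt=90):
#         print(row.cpu_id, row.cpu_percent)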
    @classmethod
    def fields(cls, writable=False):
        """
        Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
        If `writable` is true, only writable fields are included.
        Callers should not modify the dictionary.
        """
        # noinspection PyProtectedMember,PyUnresolvedReferences
        return cls._writable_fields if writable else cls._fields

    @classmethod
    def is_read_only(cls):
        """
        Returns true if the model is marked as read only.
        """
        return cls._readonly

    @classmethod
    def is_system_model(cls):
        """
        Returns true if the model represents a system table.
        """
        return cls._system


class BufferModel(Model):
    @classmethod
    def create_table_sql(cls, db):
        """
        Returns the SQL statement for creating a table for this model.
        """
        parts = [
            "CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
            % (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
        ]
        engine_str = cls.engine.create_table_sql(db)
        parts.append(engine_str)
        return " ".join(parts)
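
# A sketch of pairing a buffer table with its main model, assuming a `Buffer`
# engine class and the CPUStats example model from above:
#
#     class CPUStatsBuffer(BufferModel, CPUStats):
#         engine = Buffer(CPUStats)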
class MergeModel(Model):
    """
    Model for the Merge engine.
    Predefines a virtual `_table` column and ensures that rows can't be inserted into this table type.
    https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
    """

    readonly = True

    # Virtual fields can't be inserted into the database
@@ -506,19 +506,20 @@ class MergeModel(Model):
    @classmethod
    def create_table_sql(cls, db):
        """
        Returns the SQL statement for creating a table for this model.
        """
        assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
        parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
        cols = []
        for name, field in cls.fields().items():
            if name != "_table":
                cols.append("    %s %s" % (name, field.get_sql(db=db)))
        parts.append(",\n".join(cols))
        parts.append(")")
        parts.append("ENGINE = " + cls.engine.create_table_sql(db))
        return "\n".join(parts)


# TODO: base class for models that require specific engine
@@ -529,10 +530,10 @@ class DistributedModel(Model):
    """

    def set_database(self, db):
        """
        Sets the `Database` that this model instance belongs to.
        This is done automatically when the instance is read from the database or written to it.
        """
        assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
        res = super(DistributedModel, self).set_database(db)
        return res
@@ -575,33 +576,37 @@ class DistributedModel(Model):
            return

        # find out all the superclasses of the Model that store any data
        storage_models = [b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)]
        if not storage_models:
            raise TypeError(
                "When defining Distributed engine without the table_name "
                "ensure that your model has a parent model"
            )

        if len(storage_models) > 1:
            raise TypeError(
                "When defining Distributed engine without the table_name "
                "ensure that your model has exactly one non-distributed superclass"
            )

        # enable correct SQL for engine
        cls.engine.table = storage_models[0]

    @classmethod
    def create_table_sql(cls, db):
        """
        Returns the SQL statement for creating a table for this model.
        """
        assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
        cls.fix_engine_table()

        parts = [
            "CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
                db.db_name, cls.table_name(), cls.engine.table_name
            ),
            "ENGINE = " + cls.engine.create_table_sql(db),
        ]
        return "\n".join(parts)
# Expose only relevant classes in import *


@@ -16,7 +16,8 @@ class SystemPart(Model):
    This model operates only on the fields described in the reference. Other fields are ignored.
    https://clickhouse.tech/docs/en/system_tables/system.parts/
    """

    OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})

    _readonly = True
    _system = True
@@ -51,12 +52,13 @@ class SystemPart(Model):
    @classmethod
    def table_name(cls):
        return "parts"

    """
    The following methods return SQL for operations that can be performed on partitions:
    https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
    """

    def _partition_operation_sql(self, operation, settings=None, from_part=None):
        """
        Performs the given operation on this partition
@@ -83,7 +85,7 @@ class SystemPart(Model):
        Returns: SQL Query
        """
        return self._partition_operation_sql("DETACH", settings=settings)

    def drop(self, settings=None):
        """
@@ -93,7 +95,7 @@ class SystemPart(Model):
        Returns: SQL Query
        """
        return self._partition_operation_sql("DROP", settings=settings)

    def attach(self, settings=None):
        """
@@ -103,7 +105,7 @@ class SystemPart(Model):
        Returns: SQL Query
        """
        return self._partition_operation_sql("ATTACH", settings=settings)

    def freeze(self, settings=None):
        """
@@ -113,7 +115,7 @@ class SystemPart(Model):
        Returns: SQL Query
        """
        return self._partition_operation_sql("FREEZE", settings=settings)

    def fetch(self, zookeeper_path, settings=None):
        """
@@ -124,7 +126,7 @@ class SystemPart(Model):
        Returns: SQL Query
        """
        return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)

    @classmethod
    def get(cls, database, conditions=""):
@@ -140,9 +142,12 @@ class SystemPart(Model):
        assert isinstance(conditions, str), "conditions must be a string"
        if conditions:
            conditions += " AND"
        field_names = ",".join(cls.fields())
        return database.select(
            "SELECT %s FROM `system`.%s WHERE %s database='%s'"
            % (field_names, cls.table_name(), conditions, database.db_name),
            model_class=cls,
        )

    @classmethod
    def get_active(cls, database, conditions=""):
@@ -155,8 +160,8 @@ class SystemPart(Model):
        Returns: A list of SystemPart objects
        """
        if conditions:
            conditions += " AND "
        conditions += "active"
        return SystemPart.get(database, conditions=conditions)
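
# Usage sketch, with an assumed database name: freeze every active part by
# running the generated SQL against the server.
#
#     db = Database("demo")
#     for part in SystemPart.get_active(db):
#         db.raw(part.freeze())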


@@ -4,26 +4,18 @@ import pkgutil
import re
from datetime import date, datetime, tzinfo, timedelta

SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"}

SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")


def escape(value, quote=True):
    """
    If the value is a string, escapes any special characters and optionally
    surrounds it with single quotes. If the value is not a string (e.g. a number),
    converts it to one.
    """

    def escape_one(match):
        return SPECIAL_CHARS[match.group(0)]
@@ -35,11 +27,11 @@ def escape(value, quote=True):
def unescape(value):
    return codecs.escape_decode(value)[0].decode("utf-8")


def string_or_func(obj):
    return obj.to_sql() if hasattr(obj, "to_sql") else obj


def arg_to_sql(arg):
@@ -49,6 +41,7 @@ def arg_to_sql(arg):
    None, numbers, timezones, arrays/iterables.
    """
    from clickhouse_orm import Field, StringField, DateTimeField, F, QuerySet

    if isinstance(arg, F):
        return arg.to_sql()
    if isinstance(arg, Field):
@@ -66,22 +59,22 @@ def arg_to_sql(arg):
    if isinstance(arg, tzinfo):
        return StringField().to_db_string(arg.tzname(None))
    if arg is None:
        return "NULL"
    if isinstance(arg, QuerySet):
        return "(%s)" % arg
    if isinstance(arg, tuple):
        return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
    if is_iterable(arg):
        return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
    return str(arg)


def parse_tsv(line):
    if isinstance(line, bytes):
        line = line.decode()
    if line and line[-1] == "\n":
        line = line[:-1]
    return [unescape(value) for value in line.split("\t")]
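
# For example, a tab-separated line whose first value contains an escaped newline:
#
#     >>> parse_tsv(b"hello\\nworld\t42\n")
#     ['hello\nworld', '42']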
def parse_array(array_string):
@@ -91,17 +84,17 @@ def parse_array(array_string):
    "(1,2,3)" ==> [1, 2, 3]
    """
    # Sanity check
    if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
        raise ValueError('Invalid array string: "%s"' % array_string)
    # Drop opening brace
    array_string = array_string[1:]
    # Go over the string, lopping off each value at the beginning until nothing is left
    values = []
    while True:
        if array_string in "])":
            # End of array
            return values
        elif array_string[0] in ", ":
            # In between values
            array_string = array_string[1:]
        elif array_string[0] == "'":
@@ -110,12 +103,12 @@ def parse_array(array_string):
            if match is None:
                raise ValueError('Missing closing quote: "%s"' % array_string)
            values.append(array_string[1 : match.start() + 1])
            array_string = array_string[match.end() :]
        else:
            # Start of non-quoted value, find its end
            match = re.search(r",|\]", array_string)
            values.append(array_string[0 : match.start()])
            array_string = array_string[match.end() - 1 :]
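
# For example, values come back as strings in both bracket styles:
#
#     >>> parse_array("[1,2,3]")
#     ['1', '2', '3']
#     >>> parse_array("('a','b')")
#     ['a', 'b']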
def import_submodules(package_name):
@@ -124,7 +117,7 @@ def import_submodules(package_name):
    """
    package = importlib.import_module(package_name)
    return {
        name: importlib.import_module(package_name + "." + name)
        for _, name, _ in pkgutil.iter_modules(package.__path__)
    }
@@ -134,9 +127,9 @@ def comma_join(items, stringify=False):
    Joins an iterable of strings with commas.
    """
    if stringify:
        return ", ".join(str(item) for item in items)
    else:
        return ", ".join(items)
def is_iterable(obj):
@@ -152,16 +145,18 @@ def is_iterable(obj):
def get_subclass_names(locals, base_class):
    from inspect import isclass

    return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]


class NoValue:
    """
    A sentinel for fields with an expression for a default value,
    that were not assigned a value yet.
    """

    def __repr__(self):
        return "NO_VALUE"


NO_VALUE = NoValue()