From 6596517b253a5bcc826aab56c96c974652044b8e Mon Sep 17 00:00:00 2001
From: sswest
Date: Wed, 1 Jun 2022 19:21:04 +0800
Subject: [PATCH] experiment: support async client

---
 src/clickhouse_orm/aio/__init__.py |   0
 src/clickhouse_orm/aio/database.py | 377 +++++++++++++++++++++++++++++
 src/clickhouse_orm/query.py        |   4 +
 3 files changed, 381 insertions(+)
 create mode 100644 src/clickhouse_orm/aio/__init__.py
 create mode 100644 src/clickhouse_orm/aio/database.py

diff --git a/src/clickhouse_orm/aio/__init__.py b/src/clickhouse_orm/aio/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/clickhouse_orm/aio/database.py b/src/clickhouse_orm/aio/database.py
new file mode 100644
index 0000000..ef15f51
--- /dev/null
+++ b/src/clickhouse_orm/aio/database.py
@@ -0,0 +1,377 @@
+import datetime
+import logging
+from math import ceil
+from typing import Type, Optional, AsyncGenerator
+
+import httpx
+import pytz
+
+from clickhouse_orm.models import MODEL, ModelBase
+from clickhouse_orm.utils import parse_tsv, import_submodules
+from clickhouse_orm.database import Database, ServerError, DatabaseException, logger, Page
+
+
+class AioDatabase(Database):
+    """
+    Asynchronous counterpart of `Database`, backed by `httpx.AsyncClient`.
+
+    Because the connection checks cannot run in `__init__`, `db_check()`
+    must be awaited once before the instance is used.
+    """
+
+    def __init__(
+        self, db_name, db_url='http://localhost:8123/', username=None,
+        password=None, readonly=False, auto_create=True, timeout=60,
+        verify_ssl_cert=True, log_statements=False
+    ):
+        self.db_name = db_name
+        self.db_url = db_url
+        self.readonly = False
+        self._readonly = readonly
+        self.auto_create = auto_create
+        self.timeout = timeout
+        self.request_session = httpx.AsyncClient(verify=verify_ssl_cert, timeout=timeout)
+        if username:
+            self.request_session.auth = (username, password or '')
+        self.log_statements = log_statements
+        self.settings = {}
+        self._db_check = False
+        self.db_exists = False
+
+    async def db_check(self):
+        """
+        Runs the connection checks that the synchronous `Database` performs
+        in its constructor. Must be awaited once before the database is used.
+        """
+        if self._db_check:
+            return
+        self.db_exists = await self._is_existing_database()
+        if self._readonly:
+            if not self.db_exists:
+                raise DatabaseException(
+                    'Database does not exist, and cannot be created under readonly connection'
+                )
+            self.connection_readonly = await self._is_connection_readonly()
+            self.readonly = True
+        elif self.auto_create and not self.db_exists:
+            await self.create_database()
+        self.server_version = await self._get_server_version()
+        if self.server_version > (1, 1, 53981):
+            self.server_timezone = await self._get_server_timezone()
+        else:
+            self.server_timezone = pytz.utc
+        self.has_codec_support = self.server_version >= (19, 1, 16)
+        self.has_low_cardinality_support = self.server_version >= (19, 0)
+        self._db_check = True
+
+    async def close(self):
+        """
+        Closes the underlying HTTP session.
+        """
+        await self.request_session.aclose()
+
+    async def _send(self, data, settings=None, stream=False):
+        r = await super()._send(data, settings, stream)
+        if r.status_code != 200:
+            raise ServerError(r.text)
+        return r
+
+    async def count(
+        self,
+        model_class,
+        conditions=None
+    ) -> int:
+        """
+        Counts the number of records in the model's table.
+
+        - `model_class`: the model to count.
+        - `conditions`: optional SQL conditions (contents of the WHERE clause).
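+
+        A minimal usage sketch (`Person` stands in for any model class):
+
+            db = AioDatabase('demo_db')
+            await db.db_check()
+            tall_count = await db.count(Person, 'height > 1.70')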
+        """
+        from clickhouse_orm.query import Q
+
+        if not self._db_check:
+            raise DatabaseException(
+                '`db_check` must be awaited on this AioDatabase before it can be used'
+            )
+
+        query = 'SELECT count() FROM $table'
+        if conditions:
+            if isinstance(conditions, Q):
+                conditions = conditions.to_sql(model_class)
+            query += ' WHERE ' + str(conditions)
+        query = self._substitute(query, model_class)
+        r = await self._send(query)
+        return int(r.text) if r.text else 0
+
+    async def create_database(self):
+        """
+        Creates the database on the ClickHouse server if it does not already exist.
+        """
+        # No `db_check` guard here: `db_check` itself calls this method
+        # before the check flag has been set.
+        await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
+        self.db_exists = True
+
+    async def drop_database(self):
+        """
+        Deletes the database on the ClickHouse server.
+        """
+        if not self._db_check:
+            raise DatabaseException(
+                '`db_check` must be awaited on this AioDatabase before it can be used'
+            )
+
+        await self._send('DROP DATABASE `%s`' % self.db_name)
+        self.db_exists = False
+
+    async def create_table(self, model_class: Type[MODEL]) -> None:
+        """
+        Creates a table for the given model class, if it does not exist already.
+        """
+        if not self._db_check:
+            raise DatabaseException(
+                '`db_check` must be awaited on this AioDatabase before it can be used'
+            )
+
+        if model_class.is_system_model():
+            raise DatabaseException("You can't create a system table")
+        if getattr(model_class, 'engine') is None:
+            raise DatabaseException("%s class must define an engine" % model_class.__name__)
+        await self._send(model_class.create_table_sql(self))
+
+    async def drop_table(self, model_class: Type[MODEL]) -> None:
+        """
+        Drops the database table of the given model class, if it exists.
+        """
+        if not self._db_check:
+            raise DatabaseException(
+                '`db_check` must be awaited on this AioDatabase before it can be used'
+            )
+
+        if model_class.is_system_model():
+            raise DatabaseException("You can't drop a system table")
+        await self._send(model_class.drop_table_sql(self))
+
+    async def does_table_exist(self, model_class: Type[MODEL]) -> bool:
+        """
+        Checks whether a table for the given model class already exists.
+        Note that this only checks for existence of a table with the expected name.
+        """
+        if not self._db_check:
+            raise DatabaseException(
+                '`db_check` must be awaited on this AioDatabase before it can be used'
+            )
+
+        sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
+        r = await self._send(sql % (self.db_name, model_class.table_name()))
+        return r.text.strip() == '1'
+
+    async def get_model_for_table(
+        self,
+        table_name: str,
+        system_table: bool = False
+    ):
+        """
+        Generates a model class from an existing table in the database.
+        This can be used for querying tables which don't have a corresponding model class,
+        for example system tables.
+
+        - `table_name`: the table to create a model for
+        - `system_table`: whether the table is a system table, or belongs to the current database
+        """
+        db_name = 'system' if system_table else self.db_name
+        sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
+        r = await self._send(sql)
+        fields = [parse_tsv(line)[:2] async for line in r.aiter_lines()]
+        model = ModelBase.create_ad_hoc_model(fields, table_name)
+        if system_table:
+            model._system = model._readonly = True
+        return model
+
+    async def insert(self, model_instances, batch_size=1000):
+        """
+        Insert records into the database.
+
+        - `model_instances`: any iterable containing instances of a single model class.
+        - `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
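+
+        A minimal usage sketch (`Person` and its fields are illustrative):
+
+            await db.insert([
+                Person(first_name='Eliza', last_name='Kane'),
+                Person(first_name='Dean', last_name='Foss'),
+            ])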
+        """
+        from io import BytesIO
+
+        i = iter(model_instances)
+        try:
+            first_instance = next(i)
+        except StopIteration:
+            return  # model_instances is empty
+        model_class = first_instance.__class__
+
+        if first_instance.is_read_only() or first_instance.is_system_model():
+            raise DatabaseException("You can't insert into read-only or system tables")
+
+        fields_list = ','.join(
+            ['`%s`' % name for name in first_instance.fields(writable=True)])
+        fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
+        query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
+
+        async def gen():
+            # An async generator, so httpx.AsyncClient can stream the request body
+            buf = BytesIO()
+            buf.write(self._substitute(query, model_class).encode('utf-8'))
+            first_instance.set_database(self)
+            buf.write(first_instance.to_db_string())
+            # Collect lines in batches of batch_size
+            lines = 2
+            for instance in i:
+                instance.set_database(self)
+                buf.write(instance.to_db_string())
+                lines += 1
+                if lines >= batch_size:
+                    # Return the current batch of lines
+                    yield buf.getvalue()
+                    # Start a new batch
+                    buf = BytesIO()
+                    lines = 0
+            # Return any remaining lines in partial batch
+            if lines:
+                yield buf.getvalue()
+        await self._send(gen())
+
+    async def select(
+        self,
+        query: str,
+        model_class: Optional[Type[MODEL]] = None,
+        settings: Optional[dict] = None
+    ) -> AsyncGenerator[MODEL, None]:
+        """
+        Performs a query and returns an async generator of model instances.
+
+        - `query`: the SQL query to execute.
+        - `model_class`: the model class matching the query's table,
+          or `None` for getting back instances of an ad-hoc model.
+        - `settings`: query settings to send as HTTP GET parameters
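+
+        A minimal usage sketch (`Person` is an illustrative model class):
+
+            async for person in db.select('SELECT * FROM $table', Person):
+                print(person.first_name)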
+        """
+        query += ' FORMAT TabSeparatedWithNamesAndTypes'
+        query = self._substitute(query, model_class)
+        r = await self._send(query, settings, True)
+        try:
+            field_names, field_types = None, None
+            async for line in r.aiter_lines():
+                if not field_names:
+                    field_names = parse_tsv(line)
+                elif not field_types:
+                    field_types = parse_tsv(line)
+                    model_class = model_class or ModelBase.create_ad_hoc_model(
+                        zip(field_names, field_types))
+                # skip blank lines (e.g. left by the WITH TOTALS modifier)
+                elif line:
+                    yield model_class.from_tsv(line, field_names, self.server_timezone, self)
+        finally:
+            await r.aclose()
+
+    async def raw(self, query: str, settings: Optional[dict] = None, stream: bool = False) -> str:
+        """
+        Performs a query and returns its output as text.
+
+        - `query`: the SQL query to execute.
+        - `settings`: query settings to send as HTTP GET parameters
+        - `stream`: if true, the HTTP response from ClickHouse will be streamed.
+        """
+        query = self._substitute(query, None)
+        return (await self._send(query, settings=settings, stream=stream)).text
+
+    async def paginate(
+        self,
+        model_class: Type[MODEL],
+        order_by: str,
+        page_num: int = 1,
+        page_size: int = 100,
+        conditions=None,
+        settings: Optional[dict] = None
+    ):
+        """
+        Selects records and returns a single page of model instances.
+
+        - `model_class`: the model class matching the query's table,
+          or `None` for getting back instances of an ad-hoc model.
+        - `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
+        - `page_num`: the page number (1-based), or -1 to get the last page.
+        - `page_size`: number of records to return per page.
+        - `conditions`: optional SQL conditions (contents of the WHERE clause).
+        - `settings`: query settings to send as HTTP GET parameters
+
+        The result is a namedtuple containing `objects` (list), `number_of_objects`,
+        `pages_total`, `number` (of the current page), and `page_size`.
+        """
+        from clickhouse_orm.query import Q
+
+        count = await self.count(model_class, conditions)
+        pages_total = int(ceil(count / float(page_size)))
+        if page_num == -1:
+            page_num = max(pages_total, 1)
+        elif page_num < 1:
+            raise ValueError('Invalid page number: %d' % page_num)
+        offset = (page_num - 1) * page_size
+        query = 'SELECT * FROM $table'
+        if conditions:
+            if isinstance(conditions, Q):
+                conditions = conditions.to_sql(model_class)
+            query += ' WHERE ' + str(conditions)
+        query += ' ORDER BY %s' % order_by
+        query += ' LIMIT %d, %d' % (offset, page_size)
+        query = self._substitute(query, model_class)
+        return Page(
+            objects=[r async for r in self.select(query, model_class, settings)] if count else [],
+            number_of_objects=count,
+            pages_total=pages_total,
+            number=page_num,
+            page_size=page_size
+        )
+
+    async def migrate(self, migrations_package_name, up_to=9999):
+        """
+        Executes schema migrations.
+
+        - `migrations_package_name` - fully qualified name of the Python package
+          containing the migrations.
+        - `up_to` - number of the last migration to apply.
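+
+        A minimal usage sketch (the package name is illustrative):
+
+            await db.migrate('myapp.migrations', up_to=42)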
+        """
+        from ..migrations import MigrationHistory
+
+        logger = logging.getLogger('migrations')
+        applied_migrations = await self._get_applied_migrations(migrations_package_name)
+        modules = import_submodules(migrations_package_name)
+        unapplied_migrations = set(modules.keys()) - applied_migrations
+        for name in sorted(unapplied_migrations):
+            logger.info('Applying migration %s...', name)
+            for operation in modules[name].operations:
+                # NOTE: existing migration operations invoke the database
+                # synchronously; they are not yet async-aware (experimental).
+                operation.apply(self)
+            await self.insert([MigrationHistory(
+                package_name=migrations_package_name,
+                module_name=name,
+                applied=datetime.date.today()
+            )])
+            if int(name[:4]) >= up_to:
+                break
+
+    async def _is_existing_database(self):
+        r = await self._send(
+            "SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
+        )
+        return r.text.strip() == '1'
+
+    async def _is_connection_readonly(self):
+        r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
+        return r.text.strip() != '0'
+
+    async def _get_server_timezone(self):
+        try:
+            r = await self._send('SELECT timezone()')
+            return pytz.timezone(r.text.strip())
+        except ServerError as e:
+            logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
+            return pytz.utc
+
+    async def _get_server_version(self, as_tuple=True):
+        try:
+            r = await self._send('SELECT version();')
+            ver = r.text
+        except ServerError as e:
+            logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
+            ver = '1.1.0'
+        return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
+
+    async def _get_applied_migrations(self, migrations_package_name):
+        from ..migrations import MigrationHistory
+
+        await self.create_table(MigrationHistory)
+        query = "SELECT module_name FROM $table WHERE package_name = '%s'" % migrations_package_name
+        query = self._substitute(query, MigrationHistory)
+        return set(obj.module_name async for obj in self.select(query))
diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py
index b2486dd..8b6cb09 100644
--- a/src/clickhouse_orm/query.py
+++ b/src/clickhouse_orm/query.py
@@ -349,6 +349,10 @@ class QuerySet(object):
         """
         return self._database.select(self.as_sql(), self._model_cls)
 
+    async def __aiter__(self):
+        async for r in self._database.select(self.as_sql(), self._model_cls):
+            yield r
+
     def __bool__(self):
         """
         Returns true if this queryset matches any rows.