experimentn: support async client

This commit is contained in:
sswest 2022-06-01 19:21:04 +08:00
parent 9ade7fa6a5
commit 6596517b25
3 changed files with 381 additions and 0 deletions

View File

View File

@ -0,0 +1,377 @@
import datetime
import logging
from math import ceil
from typing import Type, Optional, Generator
import httpx
import pytz
from clickhouse_orm.models import MODEL, ModelBase
from clickhouse_orm.utils import parse_tsv, import_submodules
from clickhouse_orm.database import Database, ServerError, DatabaseException, logger, Page
class AioDatabase(Database):
def __init__(
self, db_name, db_url='http://localhost:18123/', username=None,
password=None, readonly=False, auto_create=True, timeout=60,
verify_ssl_cert=True, log_statements=False
):
self.db_name = db_name
self.db_url = db_url
self.readonly = False
self._readonly = readonly
self.auto_create = auto_create
self.timeout = timeout
self.request_session = httpx.AsyncClient(verify=verify_ssl_cert, timeout=timeout)
if username:
self.request_session.auth = (username, password or '')
self.log_statements = log_statements
self.settings = {}
self._db_check = False
self.db_exists = False
async def db_check(self):
if self._db_check:
return
self.db_exists = await self._is_existing_database()
if self._readonly:
if not self.db_exists:
raise DatabaseException(
'Database does not exist, and cannot be created under readonly connection'
)
self.connection_readonly = await self._is_connection_readonly()
self.readonly = True
elif self.auto_create and not self.db_exists:
await self.create_database()
self.server_version = await self._get_server_version()
if self.server_version > (1, 1, 53981):
self.server_timezone = await self._get_server_timezone()
else:
self.server_timezone = pytz.utc
self.has_codec_support = self.server_version >= (19, 1, 16)
self.has_low_cardinality_support = self.server_version >= (19, 0)
self._db_check = True
async def close(self):
await self.request_session.aclose()
async def _send(self, data, settings=None, stream=False):
r = await super()._send(data, settings, stream)
if r.status_code != 200:
raise ServerError(r.text)
return r
async def count(
self,
model_class,
conditions=None
) -> int:
"""
Counts the number of records in the model's table.
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
"""
from clickhouse_orm.query import Q
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
query = 'SELECT count() FROM $table'
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query = self._substitute(query, model_class)
r = await self._send(query)
return int(r.text) if r.text else 0
async def create_database(self):
"""
Creates the database on the ClickHouse server if it does not already exist.
"""
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
self.db_exists = True
async def drop_database(self):
"""
Deletes the database on the ClickHouse server.
"""
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
await self._send('DROP DATABASE `%s`' % self.db_name)
self.db_exists = False
async def create_table(self, model_class: Type[MODEL]) -> None:
"""
Creates a table for the given model class, if it does not exist already.
"""
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if getattr(model_class, 'engine') is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
await self._send(model_class.create_table_sql(self))
async def drop_table(self, model_class: Type[MODEL]) -> None:
"""
Drops the database table of the given model class, if it exists.
"""
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
if model_class.is_system_model():
raise DatabaseException("You can't drop system table")
await self._send(model_class.drop_table_sql(self))
async def does_table_exist(self, model_class: Type[MODEL]) -> bool:
"""
Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name.
"""
if not self._db_check:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
)
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = await self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1'
async def get_model_for_table(
self,
table_name: str,
system_table: bool = False
):
"""
Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class,
for example system tables.
- `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database
"""
db_name = 'system' if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = await self._send(sql)
fields = [parse_tsv(line)[:2] async for line in lines.aiter_lines()]
model = ModelBase.create_ad_hoc_model(fields, table_name)
if system_table:
model._system = model._readonly = True
return model
async def insert(self, model_instances, batch_size=1000):
"""
Insert records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
"""
from io import BytesIO
i = iter(model_instances)
try:
first_instance = next(i)
except StopIteration:
return # model_instances is empty
model_class = first_instance.__class__
if first_instance.is_read_only() or first_instance.is_system_model():
raise DatabaseException("You can't insert into read only and system tables")
fields_list = ','.join(
['`%s`' % name for name in first_instance.fields(writable=True)])
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
def gen():
buf = BytesIO()
buf.write(self._substitute(query, model_class).encode('utf-8'))
first_instance.set_database(self)
buf.write(first_instance.to_db_string())
# Collect lines in batches of batch_size
lines = 2
for instance in i:
instance.set_database(self)
buf.write(instance.to_db_string())
lines += 1
if lines >= batch_size:
# Return the current batch of lines
yield buf.getvalue()
# Start a new batch
buf = BytesIO()
lines = 0
# Return any remaining lines in partial batch
if lines:
yield buf.getvalue()
await self._send(gen())
async def select(
self,
query: str,
model_class: Optional[Type[MODEL]] = None,
settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
"""
Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters
"""
query += ' FORMAT TabSeparatedWithNamesAndTypes'
query = self._substitute(query, model_class)
r = await self._send(query, settings, True)
try:
field_names, field_types = None, None
async for line in r.aiter_lines():
# skip blank line left by WITH TOTALS modifier
if not field_names:
field_names = parse_tsv(line)
elif not field_types:
field_types = parse_tsv(line)
model_class = model_class or ModelBase.create_ad_hoc_model(
zip(field_names, field_types))
elif line:
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
finally:
await r.aclose()
async def raw(self, query: str, settings: Optional[dict] = None, stream: bool = False) -> str:
"""
Performs a query and returns its output as text.
- `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
"""
query = self._substitute(query, None)
return (await self._send(query, settings=settings, stream=stream)).text
async def paginate(
self,
model_class: Type[MODEL],
order_by: str,
page_num: int = 1,
page_size: int = 100,
conditions=None,
settings: Optional[dict] = None
):
"""
Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
- `page_num`: the page number (1-based), or -1 to get the last page.
- `page_size`: number of records to return per page.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
- `settings`: query settings to send as HTTP GET parameters
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
"""
from clickhouse_orm.query import Q
count = await self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = max(pages_total, 1)
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
offset = (page_num - 1) * page_size
query = 'SELECT * FROM $table'
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size)
query = self._substitute(query, model_class)
return Page(
objects=[r async for r in self.select(query, model_class, settings)] if count else [],
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
)
async def migrate(self, migrations_package_name, up_to=9999):
"""
Executes schema migrations.
- `migrations_package_name` - fully qualified name of the Python package
containing the migrations.
- `up_to` - number of the last migration to apply.
"""
from ..migrations import MigrationHistory
logger = logging.getLogger('migrations')
applied_migrations = await self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name)
unapplied_migrations = set(modules.keys()) - applied_migrations
for name in sorted(unapplied_migrations):
logger.info('Applying migration %s...', name)
for operation in modules[name].operations:
operation.apply(self)
await self.insert([MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today()
)])
if int(name[:4]) >= up_to:
break
async def _is_existing_database(self):
r = await self._send(
"SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
)
return r.text.strip() == '1'
async def _is_connection_readonly(self):
r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
return r.text.strip() != '0'
async def _get_server_timezone(self):
try:
r = await self._send('SELECT timezone()')
return pytz.timezone(r.text.strip())
except ServerError as e:
logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
return pytz.utc
async def _get_server_version(self, as_tuple=True):
try:
r = await self._send('SELECT version();')
ver = r.text
except ServerError as e:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
async def _get_applied_migrations(self, migrations_package_name):
from ..migrations import MigrationHistory
await self.create_table(MigrationHistory)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
return set(obj.module_name async for obj in self.select(query))

View File

@ -349,6 +349,10 @@ class QuerySet(object):
"""
return self._database.select(self.as_sql(), self._model_cls)
async def __aiter__(self):
async for r in self._database.select(self.as_sql(), self._model_cls):
yield r
def __bool__(self):
"""
Returns true if this queryset matches any rows.