Migrate code style to Black

sw 2022-06-04 21:25:34 +08:00
parent df2d778919
commit d22683f28c
13 changed files with 1188 additions and 1548 deletions

View File

@@ -1,526 +0,0 @@
Class Reference
===============
infi.clickhouse_orm.database
----------------------------
### Database
#### Database(db_name, db_url="http://localhost:8123/", username=None, password=None, readonly=False)
Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist.
- `db_name`: name of the database to connect to.
- `db_url`: URL of the ClickHouse server.
- `username`: optional connection credentials.
- `password`: optional connection credentials.
- `readonly`: use a read-only connection.
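For example, a minimal connection sketch (the database name, URL and credentials are placeholders):

```python
from infi.clickhouse_orm.database import Database

# Creates the "demo" database on the server if it does not exist yet
db = Database("demo", db_url="http://localhost:8123/", username="default", password="")
```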
#### count(model_class, conditions=None)
Counts the number of records in the model's table.
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
#### create_database()
Creates the database on the ClickHouse server if it does not already exist.
#### create_table(model_class)
Creates a table for the given model class, if it does not exist already.
#### drop_database()
Deletes the database on the ClickHouse server.
#### drop_table(model_class)
Drops the database table of the given model class, if it exists.
#### insert(model_instances, batch_size=1000)
Inserts records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
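A sketch of a bulk insert, assuming the hypothetical `Person` model defined under `infi.clickhouse_orm.models` below:

```python
db.insert(
    [
        Person(first_name="Alice", birthday="1990-01-01"),
        Person(first_name="Bob", birthday="1985-06-15"),
    ],
    batch_size=500,  # lower this if individual records are large
)
```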
#### migrate(migrations_package_name, up_to=9999)
Executes schema migrations.
- `migrations_package_name`: fully qualified name of the Python package containing the migrations.
- `up_to`: number of the last migration to apply.
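For instance, assuming a hypothetical `myapp.migrations` package whose modules are named `0001_initial.py`, `0002_...` and so on:

```python
# Applies any unapplied migrations, in order, stopping after number 0005
db.migrate("myapp.migrations", up_to=5)
```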
#### paginate(model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None)
Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
- `page_num`: the page number (1-based), or -1 to get the last page.
- `page_size`: number of records to return per page.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
- `settings`: query settings to send as HTTP GET parameters.
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
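A usage sketch, again with the hypothetical `Person` model:

```python
page = db.paginate(Person, "birthday", page_num=1, page_size=50)
print(page.number_of_objects, page.pages_total)
for person in page.objects:
    print(person.first_name)
```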
#### raw(query, settings=None, stream=False)
Performs a query and returns its output as text.
- `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters.
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
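For example:

```python
# Returns the raw text of the response; the output format is up to the query
text = db.raw("SELECT count() FROM system.tables FORMAT TabSeparated")
```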
#### select(query, model_class=None, settings=None)
Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters.
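A sketch of both modes (`Person` is the hypothetical model from the examples above):

```python
# Typed query: $table expands to Person's table name
for person in db.select("SELECT * FROM $table WHERE height > 1.70", model_class=Person):
    print(person.first_name)

# Ad-hoc query: field names and types are taken from the response itself
for row in db.select("SELECT name, engine FROM system.tables"):
    print(row.name, row.engine)
```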
### DatabaseException
Extends Exception
Raised when a database operation fails.
infi.clickhouse_orm.models
--------------------------
### Model
A base class for ORM models.
#### Model(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
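A sketch of a concrete model; this hypothetical `Person` class (fields and engines are documented in the following sections) is the one assumed by the examples in this reference:

```python
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import StringField, DateField, Float32Field
from infi.clickhouse_orm.engines import MergeTree


class Person(Model):
    first_name = StringField()
    birthday = DateField()
    height = Float32Field(default=0)

    engine = MergeTree("birthday", ("first_name", "birthday"))


person = Person(first_name="Alice", birthday="1990-01-01", height=1.70)
Person(first_name="Alice", birthday="not a date")  # raises ValueError
Person(last_name="Smith")  # raises AttributeError
```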
#### Model.create_table_sql(db)
Returns the SQL command for creating a table for this model.
#### Model.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### Model.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
Creates a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
- `database`: if given, sets the database that this instance belongs to.
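For example:

```python
line = "Alice\t1990-01-01"
person = Person.from_tsv(line, field_names=["first_name", "birthday"])
# Fields missing from the data (here: height) keep their default values
```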
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### Model.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### Model.table_name()
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
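A sketch of overriding the default name:

```python
class Person(Model):
    ...

    @classmethod
    def table_name(cls):
        return "people"  # the default would be "person"
```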
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
- `field_names`: an iterable of field names to return (optional).
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
### BufferModel
Extends Model
#### BufferModel(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
#### BufferModel.create_table_sql(db)
Returns the SQL command for creating a table for this model.
#### BufferModel.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### BufferModel.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
Creates a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
- `database`: if given, sets the database that this instance belongs to.
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### BufferModel.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### BufferModel.table_name()
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
- `field_names`: an iterable of field names to return (optional).
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
infi.clickhouse_orm.fields
--------------------------
### Field
Abstract base class for all field types.
#### Field(default=None, alias=None, materialized=None)
### StringField
Extends Field
#### StringField(default=None, alias=None, materialized=None)
### DateField
Extends Field
#### DateField(default=None, alias=None, materialized=None)
### DateTimeField
Extends Field
#### DateTimeField(default=None, alias=None, materialized=None)
### BaseIntField
Extends Field
Abstract base class for all integer-type fields.
#### BaseIntField(default=None, alias=None, materialized=None)
### BaseFloatField
Extends Field
Abstract base class for all float-type fields.
#### BaseFloatField(default=None, alias=None, materialized=None)
### BaseEnumField
Extends Field
Abstract base class for all enum-type fields.
#### BaseEnumField(enum_cls, default=None, alias=None, materialized=None)
### ArrayField
Extends Field
#### ArrayField(inner_field, default=None, alias=None, materialized=None)
### FixedStringField
Extends StringField
#### FixedStringField(length, default=None, alias=None, materialized=None)
### UInt8Field
Extends BaseIntField
#### UInt8Field(default=None, alias=None, materialized=None)
### UInt16Field
Extends BaseIntField
#### UInt16Field(default=None, alias=None, materialized=None)
### UInt32Field
Extends BaseIntField
#### UInt32Field(default=None, alias=None, materialized=None)
### UInt64Field
Extends BaseIntField
#### UInt64Field(default=None, alias=None, materialized=None)
### Int8Field
Extends BaseIntField
#### Int8Field(default=None, alias=None, materialized=None)
### Int16Field
Extends BaseIntField
#### Int16Field(default=None, alias=None, materialized=None)
### Int32Field
Extends BaseIntField
#### Int32Field(default=None, alias=None, materialized=None)
### Int64Field
Extends BaseIntField
#### Int64Field(default=None, alias=None, materialized=None)
### Float32Field
Extends BaseFloatField
#### Float32Field(default=None, alias=None, materialized=None)
### Float64Field
Extends BaseFloatField
#### Float64Field(default=None, alias=None, materialized=None)
### Enum8Field
Extends BaseEnumField
#### Enum8Field(enum_cls, default=None, alias=None, materialized=None)
### Enum16Field
Extends BaseEnumField
#### Enum16Field(enum_cls, default=None, alias=None, materialized=None)
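The composite field types wrap an inner field or an enum class. A sketch extending the hypothetical `Person` model (imports from the earlier example omitted):

```python
import enum

from infi.clickhouse_orm.fields import ArrayField, DateField, Enum8Field, StringField


class Gender(enum.Enum):
    male = 0
    female = 1


class Person(Model):
    first_name = StringField()
    birthday = DateField()
    gender = Enum8Field(Gender, default=Gender.male)
    nicknames = ArrayField(StringField())

    engine = MergeTree("birthday", ("first_name", "birthday"))
```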
infi.clickhouse_orm.engines
---------------------------
### Engine
### TinyLog
Extends Engine
### Log
Extends Engine
### Memory
Extends Engine
### MergeTree
Extends Engine
#### MergeTree(date_col, key_cols, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
### Buffer
Extends Engine
Implements the Buffer engine, which accumulates inserted rows in memory and periodically flushes them to the main model's table.
Read more at https://clickhouse.tech/reference_en.html#Buffer
#### Buffer(main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000, min_bytes=10000000, max_bytes=100000000)
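A sketch of buffering inserts to the hypothetical `Person` table, following the `Buffer` signature above:

```python
from infi.clickhouse_orm.models import BufferModel
from infi.clickhouse_orm.engines import Buffer


class PersonBuffer(BufferModel, Person):
    engine = Buffer(Person)


db.create_table(PersonBuffer)  # the main Person table must already exist
db.insert([PersonBuffer(first_name="Eve", birthday="2000-12-31")])
```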
### CollapsingMergeTree
Extends MergeTree
#### CollapsingMergeTree(date_col, key_cols, sign_col, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
### SummingMergeTree
Extends MergeTree
#### SummingMergeTree(date_col, key_cols, summing_cols=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
### ReplacingMergeTree
Extends MergeTree
#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
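A sketch of one of the specialized variants, using a sign column as `CollapsingMergeTree` expects (model and column names are illustrative):

```python
from infi.clickhouse_orm.fields import Int8Field
from infi.clickhouse_orm.engines import CollapsingMergeTree


class PersonState(Model):
    first_name = StringField()
    birthday = DateField()
    sign = Int8Field()  # +1 registers a state, -1 cancels it

    engine = CollapsingMergeTree("birthday", ("first_name", "birthday"), "sign")
```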
infi.clickhouse_orm.query
-------------------------
### QuerySet
#### QuerySet(model_cls, database)
#### conditions_as_sql(prewhere=True)
Returns the contents of the queryset's WHERE or PREWHERE clause.
#### count()
Returns the number of matching model instances.
#### exclude(**kwargs)
Returns a new QuerySet instance that excludes all rows matching the conditions.
#### filter(**kwargs)
Returns a new QuerySet instance that includes only rows matching the conditions.
#### only(*field_names)
Limits the query to return only the specified field names.
Useful when there are large fields that are not needed,
or for creating a subquery to use with an IN operator.
#### order_by(*field_names)
Returns a new QuerySet instance with the ordering changed.
#### order_by_as_sql()
Returns the contents of the queryset's ORDER BY clause.
#### query()
Returns the queryset as SQL.
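A sketch of chaining these methods (field names follow the hypothetical `Person` model):

```python
qs = Person.objects_in(db)
qs = qs.filter(first_name="Alice").exclude(height=0).order_by("-birthday")
print(qs.query())  # inspect the generated SQL
print(qs.count())
for person in qs.only("first_name", "birthday"):
    print(person.first_name, person.birthday)
```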

View File

@@ -15,6 +15,7 @@ from clickhouse_orm.database import Database, ServerError, DatabaseException, lo

+
 # pylint: disable=C0116
 class AioDatabase(Database):
     _client_class = httpx.AsyncClient
@@ -25,7 +26,7 @@ class AioDatabase(Database):
         if self._readonly:
             if not self.db_exists:
                 raise DatabaseException(
-                    'Database does not exist, and cannot be created under readonly connection'
+                    "Database does not exist, and cannot be created under readonly connection"
                 )
             self.connection_readonly = await self._is_connection_readonly()
             self.readonly = True
@@ -44,10 +45,7 @@ class AioDatabase(Database):
         await self.request_session.aclose()

     async def _send(
-        self,
-        data: str | bytes | AsyncGenerator,
-        settings: dict = None,
-        stream: bool = False
+        self, data: str | bytes | AsyncGenerator, settings: dict = None, stream: bool = False
     ):
         r = await super()._send(data, settings, stream)
         if r.status_code != 200:
@@ -55,11 +53,7 @@ class AioDatabase(Database):
             raise ServerError(r.text)
         return r

-    async def count(
-        self,
-        model_class: type[MODEL],
-        conditions=None
-    ) -> int:
+    async def count(self, model_class: type[MODEL], conditions=None) -> int:
         """
         Counts the number of records in the model's table.
@@ -70,14 +64,14 @@ class AioDatabase(Database):
         if not self._init:
             raise DatabaseException(
-                'The AioDatabase object must execute the init method before it can be used'
+                "The AioDatabase object must execute the init method before it can be used"
             )

-        query = 'SELECT count() FROM $table'
+        query = "SELECT count() FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
+            query += " WHERE " + str(conditions)
         query = self._substitute(query, model_class)
         r = await self._send(query)
         return int(r.text) if r.text else 0
@@ -86,14 +80,14 @@ class AioDatabase(Database):
         """
         Creates the database on the ClickHouse server if it does not already exist.
         """
-        await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
+        await self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
         self.db_exists = True

     async def drop_database(self):
         """
         Deletes the database on the ClickHouse server.
         """
-        await self._send('DROP DATABASE `%s`' % self.db_name)
+        await self._send("DROP DATABASE `%s`" % self.db_name)
         self.db_exists = False

     async def create_table(self, model_class: type[MODEL]) -> None:
@@ -102,7 +96,7 @@ class AioDatabase(Database):
         """
         if not self._init:
             raise DatabaseException(
-                'The AioDatabase object must execute the init method before it can be used'
+                "The AioDatabase object must execute the init method before it can be used"
             )
         if model_class.is_system_model():
             raise DatabaseException("You can't create system table")
@@ -110,7 +104,7 @@ class AioDatabase(Database):
             raise DatabaseException(
                 "Creating a temporary table must be within the lifetime of a session "
             )
-        if getattr(model_class, 'engine') is None:
+        if getattr(model_class, "engine") is None:
             raise DatabaseException(f"%s class must define an engine" % model_class.__name__)
         await self._send(model_class.create_table_sql(self))
@@ -121,7 +115,7 @@ class AioDatabase(Database):
         """
         if not self._init:
             raise DatabaseException(
-                'The AioDatabase object must execute the init method before it can be used'
+                "The AioDatabase object must execute the init method before it can be used"
             )
         await self._send(model_class.create_temporary_table_sql(self, table_name))
@@ -132,7 +126,7 @@ class AioDatabase(Database):
         """
         if not self._init:
             raise DatabaseException(
-                'The AioDatabase object must execute the init method before it can be used'
+                "The AioDatabase object must execute the init method before it can be used"
             )

         if model_class.is_system_model():
@@ -146,18 +140,14 @@ class AioDatabase(Database):
         """
         if not self._init:
             raise DatabaseException(
-                'The AioDatabase object must execute the init method before it can be used'
+                "The AioDatabase object must execute the init method before it can be used"
             )

         sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
         r = await self._send(sql % (self.db_name, model_class.table_name()))
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

-    async def get_model_for_table(
-        self,
-        table_name: str,
-        system_table: bool = False
-    ):
+    async def get_model_for_table(self, table_name: str, system_table: bool = False):
         """
         Generates a model class from an existing table in the database.
         This can be used for querying tables which don't have a corresponding model class,
@@ -166,7 +156,7 @@ class AioDatabase(Database):
         - `table_name`: the table to create a model for
         - `system_table`: whether the table is a system table, or belongs to the current database
         """
-        db_name = 'system' if system_table else self.db_name
+        db_name = "system" if system_table else self.db_name
         sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
         lines = await self._send(sql)
         fields = [parse_tsv(line)[:2] async for line in lines.aiter_lines()]
@@ -192,14 +182,13 @@ class AioDatabase(Database):
         if first_instance.is_read_only() or first_instance.is_system_model():
             raise DatabaseException("You can't insert into read only and system tables")

-        fields_list = ','.join(
-            ['`%s`' % name for name in first_instance.fields(writable=True)])
-        fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
-        query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
+        fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
+        fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
+        query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)

         async def gen():
             buf = BytesIO()
-            buf.write(self._substitute(query, model_class).encode('utf-8'))
+            buf.write(self._substitute(query, model_class).encode("utf-8"))
             first_instance.set_database(self)
             buf.write(first_instance.to_db_string())
             # Collect lines in batches of batch_size
@@ -217,13 +206,11 @@ class AioDatabase(Database):
             # Return any remaining lines in partial batch
             if lines:
                 yield buf.getvalue()

         await self._send(gen())

     async def select(
-        self,
-        query: str,
-        model_class: Optional[type[MODEL]] = None,
-        settings: Optional[dict] = None
+        self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
     ) -> AsyncGenerator[MODEL, None]:
         """
         Performs a query and returns a generator of model instances.
@@ -233,7 +220,7 @@ class AioDatabase(Database):
         or `None` for getting back instances of an ad-hoc model.
         - `settings`: query settings to send as HTTP GET parameters
         """
-        query += ' FORMAT TabSeparatedWithNamesAndTypes'
+        query += " FORMAT TabSeparatedWithNamesAndTypes"
         query = self._substitute(query, model_class)
         r = await self._send(query, settings, True)
         try:
@@ -245,7 +232,8 @@ class AioDatabase(Database):
                 elif not field_types:
                     field_types = parse_tsv(line)
                     model_class = model_class or ModelBase.create_ad_hoc_model(
-                        zip(field_names, field_types))
+                        zip(field_names, field_types)
+                    )
                 elif line.strip():
                     yield model_class.from_tsv(line, field_names, self.server_timezone, self)
         except StopIteration:
@@ -271,7 +259,7 @@ class AioDatabase(Database):
         page_num: int = 1,
         page_size: int = 100,
         conditions=None,
-        settings: Optional[dict] = None
+        settings: Optional[dict] = None,
     ):
         """
         Selects records and returns a single page of model instances.
@@ -294,22 +282,22 @@ class AioDatabase(Database):
         if page_num == -1:
             page_num = max(pages_total, 1)
         elif page_num < 1:
-            raise ValueError('Invalid page number: %d' % page_num)
+            raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
-        query = 'SELECT * FROM $table'
+        query = "SELECT * FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
-        query += ' ORDER BY %s' % order_by
-        query += ' LIMIT %d, %d' % (offset, page_size)
+            query += " WHERE " + str(conditions)
+        query += " ORDER BY %s" % order_by
+        query += " LIMIT %d, %d" % (offset, page_size)
         query = self._substitute(query, model_class)
         return Page(
             objects=[r async for r in self.select(query, model_class, settings)] if count else [],
             number_of_objects=count,
             pages_total=pages_total,
             number=page_num,
-            page_size=page_size
+            page_size=page_size,
         )

     async def migrate(self, migrations_package_name, up_to=9999):
@@ -322,19 +310,23 @@ class AioDatabase(Database):
         """
         from ..migrations import MigrationHistory

-        logger = logging.getLogger('migrations')
+        logger = logging.getLogger("migrations")
         applied_migrations = await self._get_applied_migrations(migrations_package_name)
         modules = import_submodules(migrations_package_name)
         unapplied_migrations = set(modules.keys()) - applied_migrations
         for name in sorted(unapplied_migrations):
-            logger.info('Applying migration %s...', name)
+            logger.info("Applying migration %s...", name)
             for operation in modules[name].operations:
                 operation.apply(self)
-            await self.insert([MigrationHistory(
-                package_name=migrations_package_name,
-                module_name=name,
-                applied=datetime.date.today()
-            )])
+            await self.insert(
+                [
+                    MigrationHistory(
+                        package_name=migrations_package_name,
+                        module_name=name,
+                        applied=datetime.date.today(),
+                    )
+                ]
+            )
             if int(name[:4]) >= up_to:
                 break
@@ -342,28 +334,28 @@ class AioDatabase(Database):
         r = await self._send(
             "SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
         )
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

     async def _is_connection_readonly(self):
         r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
-        return r.text.strip() != '0'
+        return r.text.strip() != "0"

     async def _get_server_timezone(self):
         try:
-            r = await self._send('SELECT timezone()')
+            r = await self._send("SELECT timezone()")
             return pytz.timezone(r.text.strip())
         except ServerError as err:
-            logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
+            logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
             return pytz.utc

     async def _get_server_version(self, as_tuple=True):
         try:
-            r = await self._send('SELECT version();')
+            r = await self._send("SELECT version();")
             ver = r.text
         except ServerError as err:
-            logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
-            ver = '1.1.0'
-        return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
+            logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
+            ver = "1.1.0"
+        return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver

     async def _get_applied_migrations(self, migrations_package_name):
         from ..migrations import MigrationHistory

View File

@@ -11,10 +11,10 @@ class Point:
         self.y = float(y)

     def __repr__(self):
-        return f'<Point x={self.x} y={self.y}>'
+        return f"<Point x={self.x} y={self.y}>"

     def to_db_string(self):
-        return f'({self.x},{self.y})'
+        return f"({self.x},{self.y})"


 class Ring:
@@ -29,16 +29,16 @@ class Ring:
         return len(self.array)

     def __repr__(self):
-        return f'<Ring {self.to_db_string()}>'
+        return f"<Ring {self.to_db_string()}>"

     def to_db_string(self):
         return f'[{",".join(pt.to_db_string() for pt in self.array)}]'


 def parse_point(array_string: str) -> Point:
-    if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')':
+    if len(array_string) < 2 or array_string[0] != "(" or array_string[-1] != ")":
         raise ValueError('Invalid point string: "%s"' % array_string)
-    x, y = array_string.strip('()').split(',')
+    x, y = array_string.strip("()").split(",")
     return Point(x, y)
@@ -47,14 +47,14 @@ def parse_ring(array_string: str) -> Ring:
         raise ValueError('Invalid ring string: "%s"' % array_string)
     ring = []
     for point in POINT_REGEX.finditer(array_string):
-        x, y = point.group('x'), point.group('y')
+        x, y = point.group("x"), point.group("y")
         ring.append(Point(x, y))
     return Ring(ring)


 class PointField(Field):
     class_default = Point(0, 0)
-    db_type = 'Point'
+    db_type = "Point"

     def __init__(
         self,
@@ -63,7 +63,7 @@ class PointField(Field):
         materialized: Optional[Union[F, str]] = None,
         readonly: bool = None,
         codec: Optional[str] = None,
-        db_column: Optional[str] = None
+        db_column: Optional[str] = None,
     ):
         super().__init__(default, alias, materialized, readonly, codec, db_column)
         self.inner_field = Float64Field()
@@ -73,10 +73,10 @@ class PointField(Field):
             value = parse_point(value)
         elif isinstance(value, (tuple, list)):
             if len(value) != 2:
-                raise ValueError('PointField takes 2 value, but %s were given' % len(value))
+                raise ValueError("PointField takes 2 value, but %s were given" % len(value))
             value = Point(value[0], value[1])
         if not isinstance(value, Point):
-            raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
+            raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
         return value

     def validate(self, value):
@@ -91,7 +91,7 @@ class PointField(Field):

 class RingField(Field):
     class_default = [Point(0, 0)]
-    db_type = 'Ring'
+    db_type = "Ring"

     def to_python(self, value, timezone_in_use):
         if isinstance(value, str):
@@ -100,11 +100,11 @@ class RingField(Field):
             ring = []
             for point in value:
                 if len(point) != 2:
-                    raise ValueError('Point takes 2 value, but %s were given' % len(value))
+                    raise ValueError("Point takes 2 value, but %s were given" % len(value))
                 ring.append(Point(point[0], point[1]))
             value = Ring(ring)
         if not isinstance(value, Ring):
-            raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
+            raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
         return value

     def to_db_string(self, value, quote=True):

View File

@@ -16,8 +16,8 @@ from .utils import parse_tsv, import_submodules
 from .session import ctx_session_id, ctx_session_timeout

-logger = logging.getLogger('clickhouse_orm')
+logger = logging.getLogger("clickhouse_orm")

-Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
+Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")


 class DatabaseException(Exception):
@@ -30,6 +30,7 @@ class ServerError(DatabaseException):
     """
     Raised when a server returns an error.
     """
+
     def __init__(self, message):
         self.code = None
         processed = self.get_error_code_msg(message)
@@ -43,21 +44,30 @@ class ServerError(DatabaseException):
     ERROR_PATTERNS = (
         # ClickHouse prior to v19.3.3
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
             \ e.what\(\)\ =\ (?P<type2>[^ \n]+)
-            ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
         # ClickHouse v19.3.3+
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
-            ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
         # ClickHouse v21+
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+).
             \ (?P<type1>[^ \n]+):\ (?P<msg>.+)
-            ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
     )

     @classmethod
@@ -72,7 +82,7 @@ class ServerError(DatabaseException):
             match = pattern.match(full_error_message)
             if match:
                 # assert match.group('type1') == match.group('type2')
-                return int(match.group('code')), match.group('msg').strip()
+                return int(match.group("code")), match.group("msg").strip()
         return 0, full_error_message
@@ -86,11 +96,21 @@ class Database:
     Database instances connect to a specific ClickHouse database for running queries,
     inserting data and other operations.
     """
+
     _client_class = httpx.Client

-    def __init__(self, db_name, db_url='http://localhost:8123/',
-                 username=None, password=None, readonly=False, auto_create=True,
-                 timeout=60, verify_ssl_cert=True, log_statements=False):
+    def __init__(
+        self,
+        db_name,
+        db_url="http://localhost:8123/",
+        username=None,
+        password=None,
+        readonly=False,
+        auto_create=True,
+        timeout=60,
+        verify_ssl_cert=True,
+        log_statements=False,
+    ):
         """
         Initializes a database instance. Unless it's readonly, the database will be
         created on the ClickHouse server if it does not already exist.
@@ -114,7 +134,7 @@ class Database:
         self.timeout = timeout
         self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout)
         if username:
-            self.request_session.auth = (username, password or '')
+            self.request_session.auth = (username, password or "")
         self.log_statements = log_statements
         self.settings = {}
         self.db_exists = False  # this is required before running _is_existing_database
@@ -134,7 +154,7 @@ class Database:
         if self._readonly:
             if not self.db_exists:
                 raise DatabaseException(
-                    'Database does not exist, and cannot be created under readonly connection'
+                    "Database does not exist, and cannot be created under readonly connection"
                 )
             self.connection_readonly = self._is_connection_readonly()
             self.readonly = True
@@ -155,14 +175,14 @@ class Database:
         """
         Creates the database on the ClickHouse server if it does not already exist.
         """
-        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
+        self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
         self.db_exists = True

     def drop_database(self):
         """
         Deletes the database on the ClickHouse server.
         """
-        self._send('DROP DATABASE `%s`' % self.db_name)
+        self._send("DROP DATABASE `%s`" % self.db_name)
         self.db_exists = False

     def create_table(self, model_class: type[MODEL]) -> None:
@@ -171,7 +191,7 @@ class Database:
         """
         if model_class.is_system_model():
             raise DatabaseException("You can't create system table")
-        if getattr(model_class, 'engine') is None:
+        if getattr(model_class, "engine") is None:
             raise DatabaseException("%s class must define an engine" % model_class.__name__)
         self._send(model_class.create_table_sql(self))
@@ -190,13 +210,9 @@ class Database:
         """
         sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
         r = self._send(sql % (self.db_name, model_class.table_name()))
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

-    def get_model_for_table(
-        self,
-        table_name: str,
-        system_table: bool = False
-    ):
+    def get_model_for_table(self, table_name: str, system_table: bool = False):
         """
         Generates a model class from an existing table in the database.
         This can be used for querying tables which don't have a corresponding model class,
@@ -205,7 +221,7 @@ class Database:
         - `table_name`: the table to create a model for
         - `system_table`: whether the table is a system table, or belongs to the current database
         """
-        db_name = 'system' if system_table else self.db_name
+        db_name = "system" if system_table else self.db_name
         sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
         lines = self._send(sql).iter_lines()
         fields = [parse_tsv(line)[:2] for line in lines]
@@ -222,7 +238,7 @@ class Database:
         The name must be string, and the value is converted to string in case
         it isn't. To remove a setting, pass `None` as the value.
         """
-        assert isinstance(name, str), 'Setting name must be a string'
+        assert isinstance(name, str), "Setting name must be a string"
         if value is None:
             self.settings.pop(name, None)
         else:
@@ -246,14 +262,13 @@ class Database:
         if first_instance.is_read_only() or first_instance.is_system_model():
             raise DatabaseException("You can't insert into read only and system tables")

-        fields_list = ','.join(
-            ['`%s`' % name for name in first_instance.fields(writable=True)])
-        fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
-        query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
+        fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
+        fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
+        query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)

         def gen():
             buf = BytesIO()
-            buf.write(self._substitute(query, model_class).encode('utf-8'))
+            buf.write(self._substitute(query, model_class).encode("utf-8"))
             first_instance.set_database(self)
             buf.write(first_instance.to_db_string())
             # Collect lines in batches of batch_size
@@ -271,12 +286,11 @@ class Database:
             # Return any remaining lines in partial batch
             if lines:
                 yield buf.getvalue()

         self._send(gen())

     def count(
-        self,
-        model_class: Optional[type[MODEL]],
-        conditions: Optional[Union[str, 'Q']] = None
+        self, model_class: Optional[type[MODEL]], conditions: Optional[Union[str, "Q"]] = None
     ) -> int:
         """
         Counts the number of records in the model's table.
@@ -286,20 +300,17 @@ class Database:
         """
         from clickhouse_orm.query import Q

-        query = 'SELECT count() FROM $table'
+        query = "SELECT count() FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
+            query += " WHERE " + str(conditions)
         query = self._substitute(query, model_class)
         r = self._send(query)
         return int(r.text) if r.text else 0

     def select(
-        self,
-        query: str,
-        model_class: Optional[type[MODEL]] = None,
-        settings: Optional[dict] = None
+        self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
     ) -> Generator[MODEL, None, None]:
         """
         Performs a query and returns a generator of model instances.
@@ -309,7 +320,7 @@ class Database:
         or `None` for getting back instances of an ad-hoc model.
         - `settings`: query settings to send as HTTP GET parameters
         """
-        query += ' FORMAT TabSeparatedWithNamesAndTypes'
+        query += " FORMAT TabSeparatedWithNamesAndTypes"
         query = self._substitute(query, model_class)
         r = self._send(query, settings, True)
         try:
@@ -345,7 +356,7 @@ class Database:
         page_num: int = 1,
         page_size: int = 100,
         conditions=None,
-        settings: Optional[dict] = None
+        settings: Optional[dict] = None,
     ):
         """
         Selects records and returns a single page of model instances.
@@ -362,27 +373,28 @@ class Database:
         `pages_total`, `number` (of the current page), and `page_size`.
         """
         from clickhouse_orm.query import Q
+
         count = self.count(model_class, conditions)
         pages_total = int(ceil(count / float(page_size)))
         if page_num == -1:
             page_num = max(pages_total, 1)
         elif page_num < 1:
-            raise ValueError('Invalid page number: %d' % page_num)
+            raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
-        query = 'SELECT * FROM $table'
+        query = "SELECT * FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
-        query += ' ORDER BY %s' % order_by
-        query += ' LIMIT %d, %d' % (offset, page_size)
+            query += " WHERE " + str(conditions)
+        query += " ORDER BY %s" % order_by
+        query += " LIMIT %d, %d" % (offset, page_size)
         query = self._substitute(query, model_class)
         return Page(
             objects=list(self.select(query, model_class, settings)) if count else [],
             number_of_objects=count,
             pages_total=pages_total,
             number=page_num,
-            page_size=page_size
+            page_size=page_size,
         )

     def migrate(self, migrations_package_name, up_to=9999):
@@ -395,19 +407,23 @@ class Database:
         """
         from .migrations import MigrationHistory  # pylint: disable=C0415

-        logger = logging.getLogger('migrations')
+        logger = logging.getLogger("migrations")
         applied_migrations = self._get_applied_migrations(migrations_package_name)
         modules = import_submodules(migrations_package_name)
         unapplied_migrations = set(modules.keys()) - applied_migrations
         for name in sorted(unapplied_migrations):
-            logger.info('Applying migration %s...', name)
+            logger.info("Applying migration %s...", name)
             for operation in modules[name].operations:
                 operation.apply(self)
-            self.insert([MigrationHistory(
-                package_name=migrations_package_name,
-                module_name=name,
-                applied=datetime.date.today())
-            ])
+            self.insert(
+                [
+                    MigrationHistory(
+                        package_name=migrations_package_name,
+                        module_name=name,
+                        applied=datetime.date.today(),
+                    )
+                ]
+            )
             if int(name[:4]) >= up_to:
                 break
@@ -432,19 +448,14 @@ class Database:
         query = self._substitute(query, MigrationHistory)
         return set(obj.module_name for obj in self.select(query))

-    def _send(
-        self,
-        data: str | bytes | Generator,
-        settings: dict = None,
-        stream: bool = False
-    ):
+    def _send(self, data: str | bytes | Generator, settings: dict = None, stream: bool = False):
         if isinstance(data, str):
-            data = data.encode('utf-8')
+            data = data.encode("utf-8")
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
         request = self.request_session.build_request(
-            method='POST', url=self.db_url, content=data, params=params
+            method="POST", url=self.db_url, content=data, params=params
         )
         r = self.request_session.send(request, stream=stream)
         if isinstance(r, httpx.Response) and r.status_code != 200:
@@ -457,52 +468,52 @@ class Database:
         params.update(self.settings)
         params.update(self._context_params)
         if self.db_exists:
-            params['database'] = self.db_name
+            params["database"] = self.db_name
         # Send the readonly flag, unless the connection is already readonly (to prevent db error)
         if self.readonly and not self.connection_readonly:
-            params['readonly'] = '1'
+            params["readonly"] = "1"
         return params

     def _substitute(self, query, model_class=None):
         """
         Replaces $db and $table placeholders in the query.
         """
-        if '$' in query:
+        if "$" in query:
             mapping = dict(db="`%s`" % self.db_name)
             if model_class:
                 if model_class.is_system_model():
-                    mapping['table'] = "`system`.`%s`" % model_class.table_name()
+                    mapping["table"] = "`system`.`%s`" % model_class.table_name()
                 elif model_class.is_temporary_model():
-                    mapping['table'] = "`%s`" % model_class.table_name()
+                    mapping["table"] = "`%s`" % model_class.table_name()
                 else:
-                    mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
+                    mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
             query = Template(query).safe_substitute(mapping)
         return query

     def _get_server_timezone(self):
         try:
-            r = self._send('SELECT timezone()')
+            r = self._send("SELECT timezone()")
             return pytz.timezone(r.text.strip())
         except ServerError as err:
-            logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
+            logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
             return pytz.utc

     def _get_server_version(self, as_tuple=True):
         try:
-            r = self._send('SELECT version();')
+            r = self._send("SELECT version();")
             ver = r.text
         except ServerError as err:
-            logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
-            ver = '1.1.0'
-        return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
+            logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
+            ver = "1.1.0"
+        return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver

     def _is_existing_database(self):
         r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"

     def _is_connection_readonly(self):
         r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
-        return r.text.strip() != '0'
+        return r.text.strip() != "0"


 # Expose only relevant classes in import *

View File

@ -11,35 +11,30 @@ if TYPE_CHECKING:
from clickhouse_orm.models import Model from clickhouse_orm.models import Model
from clickhouse_orm.funcs import F from clickhouse_orm.funcs import F
logger = logging.getLogger('clickhouse_orm') logger = logging.getLogger("clickhouse_orm")
class Engine: class Engine:
def create_table_sql(self, db: Database) -> str: def create_table_sql(self, db: Database) -> str:
raise NotImplementedError() # pragma: no cover raise NotImplementedError() # pragma: no cover
class TinyLog(Engine): class TinyLog(Engine):
def create_table_sql(self, db): def create_table_sql(self, db):
return 'TinyLog' return "TinyLog"
class Log(Engine): class Log(Engine):
def create_table_sql(self, db): def create_table_sql(self, db):
return 'Log' return "Log"
class Memory(Engine): class Memory(Engine):
def create_table_sql(self, db): def create_table_sql(self, db):
return 'Memory' return "Memory"
class MergeTree(Engine): class MergeTree(Engine):
def __init__( def __init__(
self, self,
date_col: Optional[str] = None, date_col: Optional[str] = None,
@ -49,22 +44,27 @@ class MergeTree(Engine):
replica_table_path: Optional[str] = None, replica_table_path: Optional[str] = None,
replica_name: Optional[str] = None, replica_name: Optional[str] = None,
partition_key: Optional[Union[list, tuple]] = None, partition_key: Optional[Union[list, tuple]] = None,
primary_key: Optional[Union[list, tuple]] = None primary_key: Optional[Union[list, tuple]] = None,
): ):
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple' assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present' assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
assert primary_key is None or type(primary_key) in (list, tuple), \ assert primary_key is None or type(primary_key) in (
'primary_key must be a list or tuple' list,
assert partition_key is None or type(partition_key) in (list, tuple),\ tuple,
'partition_key must be tuple or list if present' ), "primary_key must be a list or tuple"
assert (replica_table_path is None) == (replica_name is None), \ assert partition_key is None or type(partition_key) in (
'both replica_table_path and replica_name must be specified' list,
tuple,
), "partition_key must be tuple or list if present"
assert (replica_table_path is None) == (
replica_name is None
), "both replica_table_path and replica_name must be specified"
# These values conflict with each other (old and new syntax of table engines. # These values conflict with each other (old and new syntax of table engines.
# So let's control only one of them is given. # So let's control only one of them is given.
assert date_col or partition_key, "You must set either date_col or partition_key" assert date_col or partition_key, "You must set either date_col or partition_key"
self.date_col = date_col self.date_col = date_col
self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,) self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
self.primary_key = primary_key self.primary_key = primary_key
self.order_by = order_by self.order_by = order_by
@ -76,28 +76,33 @@ class MergeTree(Engine):
# I changed field name for new reality and syntax # I changed field name for new reality and syntax
@property @property
def key_cols(self): def key_cols(self):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. ' logger.warning(
'Use `order_by` attribute instead') "`key_cols` attribute is deprecated and may be removed in future. "
"Use `order_by` attribute instead"
)
return self.order_by return self.order_by
@key_cols.setter @key_cols.setter
def key_cols(self, value): def key_cols(self, value):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. ' logger.warning(
'Use `order_by` attribute instead') "`key_cols` attribute is deprecated and may be removed in future. "
"Use `order_by` attribute instead"
)
self.order_by = value self.order_by = value
def create_table_sql(self, db: Database) -> str: def create_table_sql(self, db: Database) -> str:
name = self.__class__.__name__ name = self.__class__.__name__
if self.replica_name: if self.replica_name:
name = 'Replicated' + name name = "Replicated" + name
# In ClickHouse 1.1.54310 custom partitioning key was introduced # In ClickHouse 1.1.54310 custom partitioning key was introduced
# https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/ # https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
# Let's check version and use new syntax if available # Let's check version and use new syntax if available
if db.server_version >= (1, 1, 54310): if db.server_version >= (1, 1, 54310):
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \ partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
% (comma_join(self.partition_key, stringify=True), comma_join(self.partition_key, stringify=True),
comma_join(self.order_by, stringify=True)) comma_join(self.order_by, stringify=True),
)
if self.primary_key: if self.primary_key:
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True) partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
@@ -110,16 +115,17 @@ class MergeTree(Engine):
        elif not self.date_col:
            # Can't import it globally due to circular import
            from clickhouse_orm.database import DatabaseException

            raise DatabaseException(
                "Custom partitioning is not supported before ClickHouse 1.1.54310. "
                "Please update your server or use date_col syntax. "
                "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
            )
        else:
            partition_sql = ""

        params = self._build_sql_params(db)
        return "%s(%s) %s" % (name, comma_join(params), partition_sql)
    def _build_sql_params(self, db: Database) -> list[str]:
        params = []
@@ -134,22 +140,34 @@ class MergeTree(Engine):
            params.append(self.date_col)
        if self.sampling_expr:
            params.append(self.sampling_expr)
        params.append("(%s)" % comma_join(self.order_by, stringify=True))
        params.append(str(self.index_granularity))

        return params
class CollapsingMergeTree(MergeTree):
    def __init__(
        self,
        date_col=None,
        order_by=(),
        sign_col="sign",
        sampling_expr=None,
        index_granularity=8192,
        replica_table_path=None,
        replica_name=None,
        partition_key=None,
        primary_key=None,
    ):
        super(CollapsingMergeTree, self).__init__(
            date_col,
            order_by,
            sampling_expr,
            index_granularity,
            replica_table_path,
            replica_name,
            partition_key,
            primary_key,
        )
        self.sign_col = sign_col
@@ -160,37 +178,63 @@ class CollapsingMergeTree(MergeTree):

class SummingMergeTree(MergeTree):
    def __init__(
        self,
        date_col=None,
        order_by=(),
        summing_cols=None,
        sampling_expr=None,
        index_granularity=8192,
        replica_table_path=None,
        replica_name=None,
        partition_key=None,
        primary_key=None,
    ):
        super(SummingMergeTree, self).__init__(
            date_col,
            order_by,
            sampling_expr,
            index_granularity,
            replica_table_path,
            replica_name,
            partition_key,
            primary_key,
        )
        assert summing_cols is None or type(summing_cols) in (
            list,
            tuple,
        ), "summing_cols must be a list or tuple"
        self.summing_cols = summing_cols
    def _build_sql_params(self, db: Database) -> list[str]:
        params = super(SummingMergeTree, self)._build_sql_params(db)
        if self.summing_cols:
            params.append("(%s)" % comma_join(self.summing_cols))
        return params
class ReplacingMergeTree(MergeTree):
    def __init__(
        self,
        date_col=None,
        order_by=(),
        ver_col=None,
        sampling_expr=None,
        index_granularity=8192,
        replica_table_path=None,
        replica_name=None,
        partition_key=None,
        primary_key=None,
    ):
        super(ReplacingMergeTree, self).__init__(
            date_col,
            order_by,
            sampling_expr,
            index_granularity,
            replica_table_path,
            replica_name,
            partition_key,
            primary_key,
        )
        self.ver_col = ver_col
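For context, none of this reformatting changes how the engines are declared on a model. A minimal sketch under assumed import paths (the `PageView` model and its columns are hypothetical):

```python
from clickhouse_orm.models import Model
from clickhouse_orm.fields import DateField, StringField, UInt32Field
from clickhouse_orm.engines import MergeTree


class PageView(Model):
    day = DateField()
    url = StringField()
    hits = UInt32Field()

    # Old-style syntax: a date column plus the ORDER BY columns.
    engine = MergeTree("day", ("url",))
    # New-style syntax (ClickHouse >= 1.1.54310) would be:
    # engine = MergeTree(order_by=("url",), partition_key=("toYYYYMM(day)",))
```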
@@ -217,7 +261,7 @@ class Buffer(Engine):
        min_rows: int = 10000,
        max_rows: int = 1000000,
        min_bytes: int = 10000000,
        max_bytes: int = 100000000,
    ):
        self.main_model = main_model
        self.num_layers = num_layers
@@ -231,11 +275,17 @@ class Buffer(Engine):
    def create_table_sql(self, db: Database) -> str:
        # Overridden create_table_sql example:
        # sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
        sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
            db.db_name,
            self.main_model.table_name(),
            self.num_layers,
            self.min_time,
            self.max_time,
            self.min_rows,
            self.max_rows,
            self.min_bytes,
            self.max_bytes,
        )
        return sql
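For reference, a buffer engine always wraps an existing model; a common pattern (the one the `BufferModel` check in the migrations module further down relies on) is a mixin class whose engine points at the main model. A sketch relying on the default thresholds shown above, reusing the hypothetical `PageView` from the earlier example:

```python
from clickhouse_orm.models import BufferModel
from clickhouse_orm.engines import Buffer


class PageViewBuffer(BufferModel, PageView):
    # Writes land in the buffer table and are flushed to PageView's table
    # once the time/row/byte thresholds are reached.
    engine = Buffer(PageView)
```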
@@ -265,6 +315,7 @@ class Distributed(Engine):
    See full documentation here
    https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
    """

    def __init__(self, cluster, table=None, sharding_key=None):
        """
        - `cluster`: what cluster to access data from
@@ -292,12 +343,15 @@ class Distributed(Engine):
    def create_table_sql(self, db: Database) -> str:
        name = self.__class__.__name__
        params = self._build_sql_params(db)
        return "%s(%s)" % (name, ", ".join(params))

    def _build_sql_params(self, db: Database) -> list[str]:
        if self.table_name is None:
            raise ValueError(
                "Cannot create {} engine: specify an underlying table".format(
                    self.__class__.__name__
                )
            )

        params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
        if self.sharding_key:
@@ -21,13 +21,14 @@ if TYPE_CHECKING:
    from clickhouse_orm.models import Model
    from clickhouse_orm.database import Database

logger = getLogger("clickhouse_orm")


class Field(FunctionOperatorsMixin):
    """
    Abstract base class for all field types.
    """

    name: str = None  # this is set by the parent model
    parent: type["Model"] = None  # this is set by the parent model
    creation_counter: int = 0  # used for keeping the model fields ordered
@@ -41,21 +42,29 @@ class Field(FunctionOperatorsMixin):
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        assert [default, alias, materialized].count(
            None
        ) >= 2, "Only one of default, alias and materialized parameters can be given"
        assert (
            alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
        ), "Alias parameter must be a string or function object, if given"
        assert (
            materialized is None
            or isinstance(materialized, F)
            or isinstance(materialized, str)
            and materialized != ""
        ), "Materialized parameter must be a string or function object, if given"
        assert (
            readonly is None or type(readonly) is bool
        ), "readonly parameter must be bool if given"
        assert (
            codec is None or isinstance(codec, str) and codec != ""
        ), "Codec field must be string, if given"
        assert (
            db_column is None or isinstance(db_column, str) and db_column != ""
        ), "db_column field must be string, if given"

        self.creation_counter = Field.creation_counter
        Field.creation_counter += 1
@@ -70,7 +79,7 @@ class Field(FunctionOperatorsMixin):
        return self.name

    def __repr__(self):
        return "<%s>" % self.__class__.__name__

    def to_python(self, value, timezone_in_use):
        """
@@ -92,9 +101,10 @@ class Field(FunctionOperatorsMixin):
        Utility method to check that the given value is between min_value and max_value.
        """
        if value < min_value or value > max_value:
            raise ValueError(
                "%s out of range - %s is not between %s and %s"
                % (self.__class__.__name__, value, min_value, max_value)
            )

    def to_db_string(self, value, quote=True):
        """
@@ -114,7 +124,7 @@ class Field(FunctionOperatorsMixin):
        sql = self.db_type
        args = self.get_db_type_args()
        if args:
            sql += "(%s)" % comma_join(args)
        if with_default_expression:
            sql += self._extra_params(db)
        return sql
@@ -124,18 +134,18 @@ class Field(FunctionOperatorsMixin):
        return []

    def _extra_params(self, db: Database) -> str:
        sql = ""
        if self.alias:
            sql += " ALIAS %s" % string_or_func(self.alias)
        elif self.materialized:
            sql += " MATERIALIZED %s" % string_or_func(self.materialized)
        elif isinstance(self.default, F):
            sql += " DEFAULT %s" % self.default.to_sql()
        elif self.default:
            default = self.to_db_string(self.default)
            sql += " DEFAULT %s" % default
        if self.codec and db and db.has_codec_support and not self.alias:
            sql += " CODEC(%s)" % self.codec
        return sql

    def isinstance(self, types) -> bool:
@@ -149,28 +159,27 @@ class Field(FunctionOperatorsMixin):
        """
        if isinstance(self, types):
            return True
        inner_field = getattr(self, "inner_field", None)
        while inner_field:
            if isinstance(inner_field, types):
                return True
            inner_field = getattr(inner_field, "inner_field", None)
        return False
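A quick illustration of the traversal above: the check unwraps `inner_field` chains, so a wrapped field still matches its inner type. A sketch using the field classes defined later in this file:

```python
from clickhouse_orm.fields import Int8Field, NullableField

field = NullableField(Int8Field())
field.isinstance(NullableField)  # True - direct isinstance match
field.isinstance(Int8Field)      # True - found by walking inner_field
```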
class StringField(Field):
    class_default = ""
    db_type = "String"

    def to_python(self, value, timezone_in_use) -> str:
        if isinstance(value, str):
            return value
        if isinstance(value, bytes):
            return value.decode("UTF-8")
        raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
class FixedStringField(StringField):
    def __init__(
        self,
        length: int,
@@ -178,22 +187,22 @@ class FixedStringField(StringField):
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: Optional[bool] = None,
        db_column: Optional[str] = None,
    ):
        self._length = length
        self.db_type = "FixedString(%d)" % length
        super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)

    def to_python(self, value, timezone_in_use) -> str:
        value = super(FixedStringField, self).to_python(value, timezone_in_use)
        return value.rstrip("\0")

    def validate(self, value):
        if isinstance(value, str):
            value = value.encode("UTF-8")
        if len(value) > self._length:
            raise ValueError(
                f"Value of {len(value)} bytes is too long for FixedStringField({self._length})"
            )
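Note that the length check counts encoded bytes, not characters, so multi-byte UTF-8 input can overflow even when the character count fits. A short sketch:

```python
from clickhouse_orm.fields import FixedStringField

f = FixedStringField(8)
f.validate("abcdefgh")                # OK - exactly 8 bytes
f.to_python(b"abc\0\0\0\0\0", None)   # "abc" - trailing NUL padding stripped
f.validate("日本語")                  # raises ValueError - 9 UTF-8 bytes > 8
```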
@@ -201,7 +210,7 @@ class DateField(Field):
    min_value = datetime.date(1970, 1, 1)
    max_value = datetime.date(2105, 12, 31)
    class_default = min_value
    db_type = "Date"

    def to_python(self, value, timezone_in_use) -> datetime.date:
        if isinstance(value, datetime.datetime):
@@ -211,10 +220,10 @@ class DateField(Field):
        if isinstance(value, int):
            return DateField.class_default + datetime.timedelta(days=value)
        if isinstance(value, str):
            if value == "0000-00-00":
                return DateField.min_value
            return datetime.datetime.strptime(value, "%Y-%m-%d").date()
        raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def validate(self, value):
        self._range_check(value, DateField.min_value, DateField.max_value)
@@ -225,7 +234,7 @@

class DateTimeField(Field):
    class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
    db_type = "DateTime"

    def __init__(
        self,
@@ -235,7 +244,7 @@ class DateTimeField(Field):
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
        timezone: Optional[Union[BaseTzInfo, str]] = None,
    ):
        super().__init__(default, alias, materialized, readonly, codec, db_column)
        # assert not timezone, 'Temporarily field timezone is not supported'
@@ -257,7 +266,7 @@ class DateTimeField(Field):
        if isinstance(value, int):
            return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
        if isinstance(value, str):
            if value == "0000-00-00 00:00:00":
                return self.class_default
            if len(value) == 10:
                try:
@@ -275,14 +284,14 @@ class DateTimeField(Field):
            if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
                dt = timezone_in_use.localize(dt)
            return dt
        raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True) -> str:
        return escape("%010d" % timegm(value.utctimetuple()), quote)


class DateTime64Field(DateTimeField):
    db_type = "DateTime64"

    """
@@ -303,10 +312,10 @@ class DateTime64Field(DateTimeField):
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
        timezone: Optional[Union[BaseTzInfo, str]] = None,
        precision: int = 6,
    ):
        super().__init__(default, alias, materialized, readonly, codec, db_column, timezone)
        assert precision is None or isinstance(precision, int), "Precision must be int type"
        self.precision = precision

    def get_db_type_args(self):
@@ -322,11 +331,10 @@ class DateTime64Field(DateTimeField):
        Returns a string in 0000000000.000000 format, where the number of fractional digits
        equals the field's precision.
        """
        return escape(
            "{timestamp:0{width}.{precision}f}".format(
                timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
            ),
            quote,
        )

    def to_python(self, value, timezone_in_use) -> datetime.datetime:
@@ -336,8 +344,8 @@ class DateTime64Field(DateTimeField):
        if isinstance(value, (int, float)):
            return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
        if isinstance(value, str):
            left_part = value.split(".")[0]
            if left_part == "0000-00-00 00:00:00":
                return self.class_default
            if len(left_part) == 10:
                try:
@@ -357,7 +365,7 @@ class BaseIntField(Field):
        try:
            return int(value)
        except:
            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True) -> str:
        # There's no need to call escape since numbers do not contain
@@ -370,50 +378,50 @@

class UInt8Field(BaseIntField):
    min_value = 0
    max_value = 2**8 - 1
    db_type = "UInt8"


class UInt16Field(BaseIntField):
    min_value = 0
    max_value = 2**16 - 1
    db_type = "UInt16"


class UInt32Field(BaseIntField):
    min_value = 0
    max_value = 2**32 - 1
    db_type = "UInt32"


class UInt64Field(BaseIntField):
    min_value = 0
    max_value = 2**64 - 1
    db_type = "UInt64"


class Int8Field(BaseIntField):
    min_value = -(2**7)
    max_value = 2**7 - 1
    db_type = "Int8"


class Int16Field(BaseIntField):
    min_value = -(2**15)
    max_value = 2**15 - 1
    db_type = "Int16"


class Int32Field(BaseIntField):
    min_value = -(2**31)
    max_value = 2**31 - 1
    db_type = "Int32"


class Int64Field(BaseIntField):
    min_value = -(2**63)
    max_value = 2**63 - 1
    db_type = "Int64"
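To make the bounds concrete: `to_python` coerces input through `int()`, while range enforcement happens separately (presumably via `Field._range_check` and the `min_value`/`max_value` class attributes; `validate` itself is not shown in this hunk). A short sketch:

```python
from clickhouse_orm.fields import UInt8Field

f = UInt8Field()
f.to_python("200", None)   # 200 - string input is coerced with int()
f.to_python("abc", None)   # raises ValueError (not a valid integer)
# 300 would pass to_python but fail range validation: UInt8 allows 0..255.
```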
class BaseFloatField(Field):
@@ -425,7 +433,7 @@ class BaseFloatField(Field):
        try:
            return float(value)
        except:
            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True) -> str:
        # There's no need to call escape since numbers do not contain
@@ -434,11 +442,11 @@ class BaseFloatField(Field):

class Float32Field(BaseFloatField):
    db_type = "Float32"


class Float64Field(BaseFloatField):
    db_type = "Float64"


class DecimalField(Field):
@@ -454,13 +462,13 @@ class DecimalField(Field):
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        db_column: Optional[str] = None,
    ):
        assert 1 <= precision <= 38, "Precision must be between 1 and 38"
        assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
        self.precision = precision
        self.scale = scale
        self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
        with localcontext() as ctx:
            ctx.prec = 38
            self.exp = Decimal(10) ** -self.scale  # for rounding to the required scale
@@ -473,9 +481,9 @@ class DecimalField(Field):
        try:
            value = Decimal(value)
        except:
            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
        if not value.is_finite():
            raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
        return self._round(value)

    def to_db_string(self, value, quote=True) -> str:
@@ -498,14 +506,13 @@ class Decimal32Field(DecimalField):
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        db_column: Optional[str] = None,
    ):
        super().__init__(9, scale, default, alias, materialized, readonly, db_column)
        self.db_type = "Decimal32(%d)" % scale


class Decimal64Field(DecimalField):
    def __init__(
        self,
        scale: int,
@@ -513,14 +520,13 @@ class Decimal64Field(DecimalField):
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        db_column: Optional[str] = None,
    ):
        super().__init__(18, scale, default, alias, materialized, readonly, db_column)
        self.db_type = "Decimal64(%d)" % scale


class Decimal128Field(DecimalField):
    def __init__(
        self,
        scale: int,
@@ -528,10 +534,10 @@ class Decimal128Field(DecimalField):
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        db_column: Optional[str] = None,
    ):
        super().__init__(38, scale, default, alias, materialized, readonly, db_column)
        self.db_type = "Decimal128(%d)" % scale
class BaseEnumField(Field):
@@ -547,7 +553,7 @@ class BaseEnumField(Field):
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        self.enum_cls = enum_cls
        if default is None:
@@ -564,7 +570,7 @@ class BaseEnumField(Field):
        except Exception:
            return self.enum_cls(value)
        if isinstance(value, bytes):
            decoded = value.decode("UTF-8")
            try:
                return self.enum_cls[decoded]
            except Exception:
@@ -573,13 +579,13 @@ class BaseEnumField(Field):
                return self.enum_cls(value)
        except (KeyError, ValueError):
            pass
        raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))

    def to_db_string(self, value, quote=True) -> str:
        return escape(value.name, quote)

    def get_db_type_args(self):
        return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]

    @classmethod
    def create_ad_hoc_field(cls, db_type) -> BaseEnumField:
@@ -590,17 +596,17 @@ class BaseEnumField(Field):
        members = {}
        for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
            members[match.group(1)] = int(match.group(2))
        enum_cls = Enum("AdHocEnum", members)
        field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
        return field_class(enum_cls)
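To make the regex above concrete: given a ClickHouse column type string, `create_ad_hoc_field` builds a Python `Enum` on the fly and wraps it in the matching field class. Roughly:

```python
from clickhouse_orm.fields import BaseEnumField

field = BaseEnumField.create_ad_hoc_field("Enum8('red' = 1, 'green' = 2, 'blue' = 3)")
# field is an Enum8Field whose enum_cls has members red/green/blue
field.to_python("green", None).value   # 2 - looked up by member name
```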
class Enum8Field(BaseEnumField):
    db_type = "Enum8"


class Enum16Field(BaseEnumField):
    db_type = "Enum16"


class ArrayField(Field):
@@ -614,12 +620,14 @@ class ArrayField(Field):
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        assert isinstance(
            inner_field, Field
        ), "The first argument of ArrayField must be a Field instance"
        assert not isinstance(
            inner_field, ArrayField
        ), "Multidimensional array fields are not supported by the ORM"
        self.inner_field = inner_field
        super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column)
@@ -627,9 +635,9 @@ class ArrayField(Field):
        if isinstance(value, str):
            value = parse_array(value)
        elif isinstance(value, bytes):
            value = parse_array(value.decode("UTF-8"))
        elif not isinstance(value, (list, tuple)):
            raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
        return [self.inner_field.to_python(v, timezone_in_use) for v in value]

    def validate(self, value):
@@ -638,12 +646,12 @@ class ArrayField(Field):
    def to_db_string(self, value, quote=True) -> str:
        array = [self.inner_field.to_db_string(v, quote=True) for v in value]
        return "[" + comma_join(array) + "]"

    def get_sql(self, with_default_expression=True, db=None) -> str:
        sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += " CODEC(%s)" % self.codec
        return sql
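Putting the three methods together, a round trip for an array column looks roughly like this (a sketch; the exact literal parsing is delegated to `parse_array`):

```python
import datetime

from clickhouse_orm.fields import ArrayField, DateField

arr = ArrayField(DateField())
arr.to_python("['2022-06-04','2022-06-05']", None)
# -> [datetime.date(2022, 6, 4), datetime.date(2022, 6, 5)]
arr.to_db_string([datetime.date(2022, 6, 4)])   # -> "['2022-06-04']"
arr.get_sql(with_default_expression=False)      # -> "Array(Date)"
```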
@@ -658,17 +666,19 @@ class TupleField(Field):
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        self.names = {}
        self.inner_fields = []
        for (name, field) in name_fields:
            if name in self.names:
                raise ValueError("The Field name conflict")
            assert isinstance(
                field, Field
            ), "The first argument of TupleField must be a Field instance"
            assert not isinstance(
                field, (ArrayField, TupleField)
            ), "Multidimensional array fields are not supported by the ORM"
            self.names[name] = field
            self.inner_fields.append(field)
        self.class_default = tuple(field.class_default for field in self.inner_fields)
@@ -677,16 +687,19 @@ class TupleField(Field):
    def to_python(self, value, timezone_in_use) -> tuple:
        if isinstance(value, str):
            value = parse_array(value)
            value = (
                self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
            )
        elif isinstance(value, bytes):
            value = parse_array(value.decode("UTF-8"))
            value = (
                self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
            )
        elif not isinstance(value, (list, tuple)):
            raise ValueError("TupleField expects list or tuple, not %s" % type(value))
        return tuple(
            self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
        )
    def validate(self, value):
        for i, v in enumerate(value):
@@ -694,21 +707,22 @@ class TupleField(Field):
    def to_db_string(self, value, quote=True) -> str:
        array = [self.inner_fields[i].to_db_string(v, quote=True) for i, v in enumerate(value)]
        return "(" + comma_join(array) + ")"

    def get_sql(self, with_default_expression=True, db=None) -> str:
        inner_sql = ", ".join(
            "%s %s" % (name, field.get_sql(False)) for name, field in self.names.items()
        )
        sql = "Tuple(%s)" % inner_sql
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += " CODEC(%s)" % self.codec
        return sql


class UUIDField(Field):
    class_default = UUID(int=0)
    db_type = "UUID"
    def to_python(self, value, timezone_in_use) -> UUID:
        if isinstance(value, UUID):
@@ -722,7 +736,7 @@
        elif isinstance(value, tuple):
            return UUID(fields=value)
        else:
            raise ValueError("Invalid value for UUIDField: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)
@@ -730,7 +744,7 @@

class IPv4Field(Field):
    class_default = 0
    db_type = "IPv4"

    def to_python(self, value, timezone_in_use) -> IPv4Address:
        if isinstance(value, IPv4Address):
@@ -738,7 +752,7 @@ class IPv4Field(Field):
        elif isinstance(value, (bytes, str, int)):
            return IPv4Address(value)
        else:
            raise ValueError("Invalid value for IPv4Address: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)
@@ -746,7 +760,7 @@ class IPv4Field(Field):

class IPv6Field(Field):
    class_default = 0
    db_type = "IPv6"

    def to_python(self, value, timezone_in_use) -> IPv6Address:
        if isinstance(value, IPv6Address):
@@ -754,7 +768,7 @@ class IPv6Field(Field):
        elif isinstance(value, (bytes, str, int)):
            return IPv6Address(value)
        else:
            raise ValueError("Invalid value for IPv6Address: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)
@@ -771,11 +785,13 @@ class NullableField(Field):
        materialized: Optional[Union[F, str]] = None,
        extra_null_values: Optional[Iterable] = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        assert isinstance(
            inner_field, Field
        ), "The first argument of NullableField must be a Field instance." " Not: {}".format(
            inner_field
        )
        self.inner_field = inner_field
        self._null_values = [None]
        if extra_null_values:
@@ -785,7 +801,7 @@ class NullableField(Field):
        )

    def to_python(self, value, timezone_in_use):
        if value == "\\N" or value in self._null_values:
            return None
        return self.inner_field.to_python(value, timezone_in_use)
@@ -794,18 +810,17 @@ class NullableField(Field):
    def to_db_string(self, value, quote=True):
        if value in self._null_values:
            return "\\N"
        return self.inner_field.to_db_string(value, quote=quote)

    def get_sql(self, with_default_expression=True, db=None):
        sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
        if with_default_expression:
            sql += self._extra_params(db)
        return sql
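The `\N` marker is ClickHouse's TSV representation of NULL; the wrapper maps it to Python `None` in both directions and otherwise delegates to the inner field. A short sketch:

```python
from clickhouse_orm.fields import Int8Field, NullableField

n = NullableField(Int8Field())
n.to_python("\\N", None)                   # None
n.to_python("17", None)                    # 17, via the inner Int8Field
n.to_db_string(None)                       # "\\N"
n.get_sql(with_default_expression=False)   # "Nullable(Int8)"
```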
class LowCardinalityField(Field):
    def __init__(
        self,
        inner_field: Field,
@@ -814,16 +829,20 @@ class LowCardinalityField(Field):
        materialized: Optional[Union[F, str]] = None,
        readonly: Optional[bool] = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        assert isinstance(
            inner_field, Field
        ), "The first argument of LowCardinalityField must be a Field instance." " Not: {}".format(
            inner_field
        )
        assert not isinstance(
            inner_field, LowCardinalityField
        ), "LowCardinality inner fields are not supported by the ORM"
        assert not isinstance(inner_field, ArrayField), (
            "Array field inside LowCardinality are not supported by the ORM."
            " Use Array(LowCardinality) instead"
        )
        self.inner_field = inner_field
        self.class_default = self.inner_field.class_default
        super().__init__(default, alias, materialized, readonly, codec, db_column)
@@ -839,12 +858,12 @@ class LowCardinalityField(Field):
    def get_sql(self, with_default_expression=True, db=None):
        if db and db.has_low_cardinality_support:
            sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
        else:
            sql = self.inner_field.get_sql(with_default_expression=False)
            logger.warning(
                "LowCardinalityField not supported on clickhouse-server version < 19.0"
                " using {} as fallback".format(self.inner_field.__class__.__name__)
            )
        if with_default_expression:
            sql += self._extra_params(db)
(File diff suppressed because it is too large.)
@@ -5,7 +5,7 @@ from .utils import get_subclass_names

import logging

logger = logging.getLogger("migrations")


class Operation:
@@ -14,7 +14,7 @@ class Operation:
    """

    def apply(self, database):
        raise NotImplementedError()  # pragma: no cover
class ModelOperation(Operation):
@@ -30,9 +30,9 @@ class ModelOperation(Operation):
        self.table_name = model_class.table_name()

    def _alter_table(self, database, cmd):
        """
        Utility for running ALTER TABLE commands.
        """
        cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
        logger.debug(cmd)
        database.raw(cmd)
@@ -44,7 +44,7 @@ class CreateTable(ModelOperation):
    """

    def apply(self, database):
        logger.info(" Create table %s", self.table_name)
        if issubclass(self.model_class, BufferModel):
            database.create_table(self.model_class.engine.main_model)
        database.create_table(self.model_class)
@@ -65,7 +65,7 @@ class AlterTable(ModelOperation):
        return [(row.name, row.type) for row in database.select(query)]

    def apply(self, database):
        logger.info(" Alter table %s", self.table_name)

        # Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
        # ADD COLUMN ... AFTER doesn't affect it
@@ -74,8 +74,8 @@ class AlterTable(ModelOperation):
        # Identify fields that were deleted from the model
        deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
        for name in deleted_fields:
            logger.info(" Drop column %s", name)
            self._alter_table(database, "DROP COLUMN %s" % name)
            del table_fields[name]

        # Identify fields that were added to the model
@@ -83,13 +83,13 @@ class AlterTable(ModelOperation):
        for name, field in self.model_class.fields().items():
            is_regular_field = not (field.materialized or field.alias)
            if name not in table_fields:
                logger.info(" Add column %s", name)
                cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
                if is_regular_field:
                    if prev_name:
                        cmd += " AFTER %s" % prev_name
                    else:
                        cmd += " FIRST"
                self._alter_table(database, cmd)

            if is_regular_field:
@@ -101,16 +101,24 @@ class AlterTable(ModelOperation):
        # The order of class attributes can be changed any time, so we can't count on it.
        # Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC,
        # so we can't expect them to preserve attribute position.
        # Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47
        model_fields = {
            name: field.get_sql(with_default_expression=False, db=database)
            for name, field in self.model_class.fields().items()
        }
        for field_name, field_sql in self._get_table_fields(database):
            # All fields must have been created and dropped by this moment
            assert field_name in model_fields, "Model fields and table columns in disagreement"

            if field_sql != model_fields[field_name]:
                logger.info(
                    " Change type of column %s from %s to %s",
                    field_name,
                    field_sql,
                    model_fields[field_name],
                )
                self._alter_table(
                    database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name])
                )
class AlterTableWithBuffer(ModelOperation):
@@ -135,7 +143,7 @@ class DropTable(ModelOperation):
    """

    def apply(self, database):
        logger.info(" Drop table %s", self.table_name)
        database.drop_table(self.model_class)
@@ -148,28 +156,29 @@ class AlterConstraints(ModelOperation):
    """

    def apply(self, database):
        logger.info(" Alter constraints for %s", self.table_name)
        existing = self._get_constraint_names(database)
        # Go over constraints in the model
        for constraint in self.model_class._constraints.values():
            # Check if it's a new constraint
            if constraint.name not in existing:
                logger.info(" Add constraint %s", constraint.name)
                self._alter_table(database, "ADD %s" % constraint.create_table_sql())
            else:
                existing.remove(constraint.name)
        # Remaining constraints in `existing` are obsolete
        for name in existing:
            logger.info(" Drop constraint %s", name)
            self._alter_table(database, "DROP CONSTRAINT `%s`" % name)

    def _get_constraint_names(self, database):
        """
        Returns a set containing the names of existing constraints in the table.
        """
        import re

        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
        matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
        return set(matches)
@@ -191,33 +200,34 @@ class AlterIndexes(ModelOperation):
        self.reindex = reindex

    def apply(self, database):
        logger.info(" Alter indexes for %s", self.table_name)
        existing = self._get_index_names(database)
        logger.info(existing)
        # Go over indexes in the model
        for index in self.model_class._indexes.values():
            # Check if it's a new index
            if index.name not in existing:
                logger.info(" Add index %s", index.name)
                self._alter_table(database, "ADD %s" % index.create_table_sql())
            else:
                existing.remove(index.name)
        # Remaining indexes in `existing` are obsolete
        for name in existing:
            logger.info(" Drop index %s", name)
            self._alter_table(database, "DROP INDEX `%s`" % name)
        # Reindex
        if self.reindex:
            logger.info(" Build indexes on table")
            database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)

    def _get_index_names(self, database):
        """
        Returns a set containing the names of existing indexes in the table.
        """
        import re

        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
        matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
        return set(matches)
@@ -225,16 +235,17 @@ class RunPython(Operation):
    """
    A migration operation that executes a Python function.
    """

    def __init__(self, func):
        """
        Initializer. The given Python function will be called with a single
        argument - the Database instance to apply the migration to.
        """
        assert callable(func), "'func' argument must be function"
        self._func = func

    def apply(self, database):
        logger.info(" Executing python operation %s", self._func.__name__)
        self._func(database)
@@ -244,17 +255,17 @@ class RunSQL(Operation):
    """

    def __init__(self, sql):
        """
        Initializer. The given sql argument must be a valid SQL statement or
        list of statements.
        """
        if isinstance(sql, str):
            sql = [sql]

        assert isinstance(sql, list), "'sql' argument must be string or list of strings"
        self._sql = sql

    def apply(self, database):
        logger.info(" Executing raw SQL operations")
        for item in self._sql:
            database.raw(item)
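For reference, these operation classes are what a migration module exposes through an `operations` list, which `Database.migrate` applies in order. A minimal sketch of a hypothetical migration file (module name and import path assumed):

```python
# e.g. mypackage/migrations/0002_cleanup.py (hypothetical)
from clickhouse_orm import migrations


def forward(database):
    # Arbitrary Python, called with the target Database instance
    database.raw("OPTIMIZE TABLE $db.`page_view` FINAL")


operations = [
    migrations.RunSQL("ALTER TABLE $db.`page_view` DROP COLUMN obsolete_col"),
    migrations.RunPython(forward),
]
```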
@@ -268,11 +279,11 @@ class MigrationHistory(Model):
    module_name = StringField()
    applied = DateField()

    engine = MergeTree("applied", ("package_name", "module_name"))

    @classmethod
    def table_name(cls):
        return "infi_clickhouse_orm_migrations"


# Expose only relevant classes in import *
@@ -17,7 +17,7 @@ from .engines import Merge, Distributed, Memory

if TYPE_CHECKING:
    from clickhouse_orm.database import Database

logger = getLogger("clickhouse_orm")


class Constraint:
@@ -38,7 +38,7 @@ class Constraint:
        """
        Returns the SQL statement for defining this constraint during table creation.
        """
        return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))


class Index:
@@ -66,8 +66,11 @@ class Index:
        """
        Returns the SQL statement for defining this index during table creation.
        """
        return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (
            self.name,
            arg_to_sql(self.expr),
            self.type,
            self.granularity,
        )

    @staticmethod
@@ -76,7 +79,7 @@ class Index:
        An index that stores extremes of the specified expression (if the expression is a tuple,
        it stores extremes for each element of the tuple). The stored info is used for skipping
        blocks of data like the primary key.
        """
        return "minmax"

    @staticmethod
    def set(max_rows: int) -> str:
@@ -85,11 +88,12 @@
        or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
        on a block of data.
        """
        return "set(%d)" % max_rows
def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int, def ngrambf_v1(
number_of_hash_functions: int, random_seed: int) -> str: n: int, size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
) -> str:
""" """
An index that stores a Bloom filter containing all ngrams from a block of data. An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions. Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -100,13 +104,17 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter. - `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions. - `random_seed` The seed for Bloom filter hash functions.
""" """
return 'ngrambf_v1(%d, %d, %d, %d)' % ( return "ngrambf_v1(%d, %d, %d, %d)" % (
n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed n,
size_of_bloom_filter_in_bytes,
number_of_hash_functions,
random_seed,
) )
@staticmethod @staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, def tokenbf_v1(
random_seed: int) -> str: size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
) -> str:
""" """
An index that stores a Bloom filter containing string tokens. Tokens are sequences An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters. separated by non-alphanumeric characters.
@ -116,8 +124,10 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter. - `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions. - `random_seed` The seed for Bloom filter hash functions.
""" """
return 'tokenbf_v1(%d, %d, %d)' % ( return "tokenbf_v1(%d, %d, %d)" % (
size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed size_of_bloom_filter_in_bytes,
number_of_hash_functions,
random_seed,
) )
@staticmethod @staticmethod
@ -128,7 +138,7 @@ class Index:
- `false_positive` - the probability (between 0 and 1) of receiving a false positive - `false_positive` - the probability (between 0 and 1) of receiving a false positive
response from the filter response from the filter
""" """
return 'bloom_filter(%f)' % false_positive return "bloom_filter(%f)" % false_positive
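For orientation, a minimal sketch of how these index factory methods are declared on a model. The model, field names and granularity values here are illustrative, not part of this commit:

```python
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateTimeField, Int32Field, StringField
from clickhouse_orm.models import Index, Model


class PageView(Model):
    url = StringField()
    visits = Int32Field()
    timestamp = DateTimeField()

    # Data-skipping indexes built from the factory methods above
    url_idx = Index(url, type=Index.ngrambf_v1(3, 256, 2, 0), granularity=2)
    visits_idx = Index(visits, type=Index.minmax(), granularity=1)

    engine = MergeTree("timestamp", ("timestamp",))
```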
 class ModelBase(type):
@@ -183,23 +193,23 @@ class ModelBase(type):
             _indexes=indexes,
             _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
             _defaults=defaults,
-            _has_funcs_as_defaults=has_funcs_as_defaults
+            _has_funcs_as_defaults=has_funcs_as_defaults,
         )
         model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs)

         # Let each field, constraint and index know its parent and its own name
         for n, obj in chain(fields, constraints.items(), indexes.items()):
-            setattr(obj, 'parent', model)
-            setattr(obj, 'name', n)
+            setattr(obj, "parent", model)
+            setattr(obj, "name", n)

         return model

     @classmethod
-    def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):
+    def create_ad_hoc_model(cls, fields, model_name="AdHocModel"):
         # fields is a list of tuples (name, db_type)
         # Check if model exists in cache
         fields = list(fields)
-        cache_key = model_name + ' ' + str(fields)
+        cache_key = model_name + " " + str(fields)
         if cache_key in cls.ad_hoc_model_cache:
             return cls.ad_hoc_model_cache[cache_key]
         # Create an ad hoc model class
@@ -217,28 +227,25 @@ class ModelBase(type):
         import clickhouse_orm.contrib.geo.fields as geo_fields

         # Enums
-        if db_type.startswith('Enum'):
+        if db_type.startswith("Enum"):
             return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
         # DateTime with timezone
-        if db_type.startswith('DateTime('):
+        if db_type.startswith("DateTime("):
             timezone = db_type[9:-1]
-            return orm_fields.DateTimeField(
-                timezone=timezone[1:-1] if timezone else None
-            )
+            return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
         # DateTime64
-        if db_type.startswith('DateTime64('):
-            precision, *timezone = [s.strip() for s in db_type[11:-1].split(',')]
+        if db_type.startswith("DateTime64("):
+            precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
             return orm_fields.DateTime64Field(
-                precision=int(precision),
-                timezone=timezone[0][1:-1] if timezone else None
+                precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
             )
         # Arrays
-        if db_type.startswith('Array'):
+        if db_type.startswith("Array"):
             inner_field = cls.create_ad_hoc_field(db_type[6:-1])
             return orm_fields.ArrayField(inner_field)
         # Tuples
-        if db_type.startswith('Tuple'):
-            types = [s.strip().split(' ') for s in db_type[6:-1].split(',')]
+        if db_type.startswith("Tuple"):
+            types = [s.strip().split(" ") for s in db_type[6:-1].split(",")]
             name_fields = []
             for i, tp in enumerate(types):
                 if len(tp) == 2:
@@ -247,27 +254,27 @@ class ModelBase(type):
                     name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
             return orm_fields.TupleField(name_fields=name_fields)
         # FixedString
-        if db_type.startswith('FixedString'):
+        if db_type.startswith("FixedString"):
             length = int(db_type[12:-1])
             return orm_fields.FixedStringField(length)
         # Decimal / Decimal32 / Decimal64 / Decimal128
-        if db_type.startswith('Decimal'):
-            p = db_type.index('(')
-            args = [int(n.strip()) for n in db_type[p + 1 : -1].split(',')]
-            field_class = getattr(orm_fields, db_type[:p] + 'Field')
+        if db_type.startswith("Decimal"):
+            p = db_type.index("(")
+            args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
+            field_class = getattr(orm_fields, db_type[:p] + "Field")
             return field_class(*args)
         # Nullable
-        if db_type.startswith('Nullable'):
-            inner_field = cls.create_ad_hoc_field(db_type[9 : -1])
+        if db_type.startswith("Nullable"):
+            inner_field = cls.create_ad_hoc_field(db_type[9:-1])
             return orm_fields.NullableField(inner_field)
         # LowCardinality
-        if db_type.startswith('LowCardinality'):
-            inner_field = cls.create_ad_hoc_field(db_type[15 : -1])
+        if db_type.startswith("LowCardinality"):
+            inner_field = cls.create_ad_hoc_field(db_type[15:-1])
             return orm_fields.LowCardinalityField(inner_field)
         # Simple fields
-        name = db_type + 'Field'
+        name = db_type + "Field"
         if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)):
-            raise NotImplementedError('No field class for %s' % db_type)
+            raise NotImplementedError("No field class for %s" % db_type)
         field_class = getattr(orm_fields, name, None) or getattr(geo_fields, name, None)
         return field_class()
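Each branch above peels one layer off the database type string and recurses on the inner type, bottoming out at a simple field class. A rough interactive sketch (not part of the diff):

```python
from clickhouse_orm.models import ModelBase

# Array -> Nullable -> FixedString(16), resolved inside-out
field = ModelBase.create_ad_hoc_field("Array(Nullable(FixedString(16)))")
# roughly equivalent to: ArrayField(NullableField(FixedStringField(16)))
```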
@@ -282,6 +289,7 @@ class Model(metaclass=ModelBase):
         cpu_percent = Float32Field()
         engine = Memory()
     """
+
     _has_funcs_as_defaults: bool
     _constraints: dict[str, Constraint]
     _indexes: dict[str, Index]
@@ -318,7 +326,7 @@ class Model(metaclass=ModelBase):
                 setattr(self, name, value)
             else:
                 raise AttributeError(
-                    '%s does not have a field called %s' % (self.__class__.__name__, name)
+                    "%s does not have a field called %s" % (self.__class__.__name__, name)
                 )

     def __setattr__(self, name, value):
@@ -383,29 +391,29 @@ class Model(metaclass=ModelBase):
         """
         Returns the SQL statement for creating a table for this model.
         """
-        parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
+        parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
         # Fields
         items = []
         for name, field in cls.fields().items():
-            items.append('    %s %s' % (name, field.get_sql(db=db)))
+            items.append("    %s %s" % (name, field.get_sql(db=db)))
         # Constraints
         for c in cls._constraints.values():
-            items.append('    %s' % c.create_table_sql())
+            items.append("    %s" % c.create_table_sql())
         # Indexes
         for i in cls._indexes.values():
-            items.append('    %s' % i.create_table_sql())
-        parts.append(',\n'.join(items))
+            items.append("    %s" % i.create_table_sql())
+        parts.append(",\n".join(items))
         # Engine
-        parts.append(')')
-        parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
-        return '\n'.join(parts)
+        parts.append(")")
+        parts.append("ENGINE = " + cls.engine.create_table_sql(db))
+        return "\n".join(parts)

     @classmethod
     def drop_table_sql(cls, db: Database) -> str:
         """
         Returns the SQL command for deleting this model's table.
         """
-        return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
+        return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
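A worked example of the SQL this method emits. This is a sketch assuming default table naming and a local server; the exact engine clause comes from the engine's own create_table_sql():

```python
from clickhouse_orm.database import Database
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import DateField, Float32Field, StringField
from clickhouse_orm.models import Model


class CPUStats(Model):
    date = DateField()
    cpu_id = StringField()
    cpu_percent = Float32Field()
    engine = Memory()


db = Database("demo")  # assumes ClickHouse at the default http://localhost:8123/
print(CPUStats.create_table_sql(db))
# CREATE TABLE IF NOT EXISTS `demo`.`cpustats` (
#     date Date,
#     cpu_id String,
#     cpu_percent Float32
# )
# ENGINE = Memory
```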
     @classmethod
     def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
@@ -422,7 +430,7 @@ class Model(metaclass=ModelBase):
         kwargs = {}
         for name in field_names:
             field = getattr(cls, name)
-            field_timezone = getattr(field, 'timezone', None) or timezone_in_use
+            field_timezone = getattr(field, "timezone", None) or timezone_in_use
             kwargs[name] = field.to_python(next(values), field_timezone)

         obj = cls(**kwargs)
@@ -439,7 +447,9 @@ class Model(metaclass=ModelBase):
         """
         data = self.__dict__
         fields = self.fields(writable=not include_readonly)
-        return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
+        return "\t".join(
+            field.to_db_string(data[name], quote=False) for name, field in fields.items()
+        )

     def to_tskv(self, include_readonly=True):
         """
@@ -453,16 +463,16 @@ class Model(metaclass=ModelBase):
         parts = []
         for name, field in fields.items():
             if data[name] != NO_VALUE:
-                parts.append(name + '=' + field.to_db_string(data[name], quote=False))
-        return '\t'.join(parts)
+                parts.append(name + "=" + field.to_db_string(data[name], quote=False))
+        return "\t".join(parts)

     def to_db_string(self) -> bytes:
         """
         Returns the instance as a bytestring ready to be inserted into the database.
         """
         s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
-        s += '\n'
-        return s.encode('utf-8')
+        s += "\n"
+        return s.encode("utf-8")
     def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]:
         """
@@ -519,19 +529,18 @@

 class BufferModel(Model):

     @classmethod
     def create_table_sql(cls, db: Database) -> str:
         """
         Returns the SQL statement for creating a table for this model.
         """
         parts = [
-            'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (
-                db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
+            "CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
+            % (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
         ]
         engine_str = cls.engine.create_table_sql(db)
         parts.append(engine_str)
-        return ' '.join(parts)
+        return " ".join(parts)
 class MergeModel(Model):
@@ -540,6 +549,7 @@ class MergeModel(Model):
     Predefines a virtual _table column and ensures that rows can't be inserted into this table type
     https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
     """
+
     readonly = True

     # Virtual fields can't be inserted into database
@@ -551,15 +561,16 @@ class MergeModel(Model):
         Returns the SQL statement for creating a table for this model.
         """
         assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
-        parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
+        parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
         cols = []
         for name, field in cls.fields().items():
-            if name != '_table':
-                cols.append('    %s %s' % (name, field.get_sql(db=db)))
-        parts.append(',\n'.join(cols))
-        parts.append(')')
-        parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
-        return '\n'.join(parts)
+            if name != "_table":
+                cols.append("    %s %s" % (name, field.get_sql(db=db)))
+        parts.append(",\n".join(cols))
+        parts.append(")")
+        parts.append("ENGINE = " + cls.engine.create_table_sql(db))
+        return "\n".join(parts)


 # TODO: base class for models that require specific engine
@@ -574,8 +585,9 @@ class DistributedModel(Model):
         Sets the `Database` that this model instance belongs to.
         This is done automatically when the instance is read from the database or written to it.
         """
-        assert isinstance(self.engine, Distributed),\
-            "engine must be an instance of engines.Distributed"
+        assert isinstance(
+            self.engine, Distributed
+        ), "engine must be an instance of engines.Distributed"
         super().set_database(db)

     @classmethod
@@ -616,15 +628,20 @@ class DistributedModel(Model):
             return

         # find out all the superclasses of the Model that store any data
-        storage_models = [b for b in cls.__bases__ if issubclass(b, Model)
-                          and not issubclass(b, DistributedModel)]
+        storage_models = [
+            b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)
+        ]
         if not storage_models:
-            raise TypeError("When defining Distributed engine without the table_name "
-                            "ensure that your model has a parent model")
+            raise TypeError(
+                "When defining Distributed engine without the table_name "
+                "ensure that your model has a parent model"
+            )

         if len(storage_models) > 1:
-            raise TypeError("When defining Distributed engine without the table_name "
-                            "ensure that your model has exactly one non-distributed superclass")
+            raise TypeError(
+                "When defining Distributed engine without the table_name "
+                "ensure that your model has exactly one non-distributed superclass"
+            )

         # enable correct SQL for engine
         cls.engine.table = storage_models[0]
@@ -637,10 +654,12 @@ class DistributedModel(Model):
         assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
         parts = [
-            'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format(
-                db.db_name, cls.table_name(), cls.engine.table_name),
-            'ENGINE = ' + cls.engine.create_table_sql(db)]
-        return '\n'.join(parts)
+            "CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
+                db.db_name, cls.table_name(), cls.engine.table_name
+            ),
+            "ENGINE = " + cls.engine.create_table_sql(db),
+        ]
+        return "\n".join(parts)
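A sketch of the parent-model arrangement this logic expects. The cluster name is hypothetical:

```python
from clickhouse_orm.engines import Distributed, MergeTree
from clickhouse_orm.fields import DateField, StringField
from clickhouse_orm.models import DistributedModel, Model


class Event(Model):
    date = DateField()
    name = StringField()
    engine = MergeTree("date", ("date", "name"))


# No table_name is passed to Distributed, so fix_engine_table() resolves the
# single non-distributed superclass (Event) as the underlying storage table.
class DistributedEvent(Event, DistributedModel):
    engine = Distributed("my_cluster")
```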
 class TemporaryModel(Model):
@@ -657,30 +676,31 @@ class TemporaryModel(Model):
     https://clickhouse.com/docs/en/sql-reference/statements/create/table/#temporary-tables
     """
+
     _temporary = True

     @classmethod
     def create_table_sql(cls, db: Database) -> str:
         assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance"
-        parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()]
+        parts = ["CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (" % cls.table_name()]
         # Fields
         items = []
         for name, field in cls.fields().items():
-            items.append('    %s %s' % (name, field.get_sql(db=db)))
+            items.append("    %s %s" % (name, field.get_sql(db=db)))
         # Constraints
         for c in cls._constraints.values():
-            items.append('    %s' % c.create_table_sql())
+            items.append("    %s" % c.create_table_sql())
         # Indexes
         for i in cls._indexes.values():
-            items.append('    %s' % i.create_table_sql())
-        parts.append(',\n'.join(items))
+            items.append("    %s" % i.create_table_sql())
+        parts.append(",\n".join(items))
         # Engine
-        parts.append(')')
-        parts.append('ENGINE = Memory')
-        return '\n'.join(parts)
+        parts.append(")")
+        parts.append("ENGINE = Memory")
+        return "\n".join(parts)


 # Expose only relevant classes in import *
-MODEL = TypeVar('MODEL', bound=Model)
+MODEL = TypeVar("MODEL", bound=Model)
 __all__ = get_subclass_names(locals(), (Model, Constraint, Index))

View File

@@ -11,7 +11,7 @@ from typing import (
     Generic,
     TypeVar,
     AsyncIterator,
-    Iterator
+    Iterator,
 )

 import pytz
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from clickhouse_orm.models import Model
     from clickhouse_orm.database import Database, Page

-MODEL = TypeVar('MODEL', bound='Model')
+MODEL = TypeVar("MODEL", bound="Model")


 class Operator:
@@ -59,9 +59,9 @@ class SimpleOperator(Operator):
     def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
         field = getattr(model_cls, field_name)
         value = self._value_to_sql(field, value)
-        if value == '\\N' and self._sql_for_null is not None:
-            return ' '.join([field_name, self._sql_for_null])
-        return ' '.join([field.name, self._sql_operator, value])
+        if value == "\\N" and self._sql_for_null is not None:
+            return " ".join([field_name, self._sql_for_null])
+        return " ".join([field.name, self._sql_operator, value])


 class InOperator(Operator):
@@ -81,7 +81,7 @@ class InOperator(Operator):
             pass
         else:
             value = comma_join([self._value_to_sql(field, v) for v in value])
-        return '%s IN (%s)' % (field.name, value)
+        return "%s IN (%s)" % (field.name, value)


 class GlobalInOperator(Operator):
@@ -95,7 +95,7 @@ class GlobalInOperator(Operator):
             pass
         else:
             value = comma_join([self._value_to_sql(field, v) for v in value])
-        return '%s GLOBAL IN (%s)' % (field.name, value)
+        return "%s GLOBAL IN (%s)" % (field.name, value)


 class LikeOperator(Operator):
@@ -111,11 +111,11 @@ class LikeOperator(Operator):
     def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
         field = getattr(model_cls, field_name)
         value = self._value_to_sql(field, value, quote=False)
-        value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
+        value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
         pattern = self._pattern.format(value)
         if self._case_sensitive:
-            return '%s LIKE \'%s\'' % (field.name, pattern)
-        return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
+            return "%s LIKE '%s'" % (field.name, pattern)
+        return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field.name, pattern)


 class IExactOperator(Operator):
@@ -126,7 +126,7 @@ class IExactOperator(Operator):
     def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
         field = getattr(model_cls, field_name)
         value = self._value_to_sql(field, value)
-        return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
+        return "lowerUTF8(%s) = lowerUTF8(%s)" % (field.name, value)


 class NotOperator(Operator):
@@ -139,7 +139,7 @@ class NotOperator(Operator):
     def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
         # Negate the base operator
-        return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
+        return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
 class BetweenOperator(Operator):
@@ -154,16 +154,22 @@ class BetweenOperator(Operator):
     def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
         field = getattr(model_cls, field_name)
-        value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(
-            str(value[0])) > 0 else None
-        value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
-            str(value[1])) > 0 else None
+        value0 = (
+            self._value_to_sql(field, value[0])
+            if value[0] is not None or len(str(value[0])) > 0
+            else None
+        )
+        value1 = (
+            self._value_to_sql(field, value[1])
+            if value[1] is not None or len(str(value[1])) > 0
+            else None
+        )
         if value0 and value1:
-            return '%s BETWEEN %s AND %s' % (field.name, value0, value1)
+            return "%s BETWEEN %s AND %s" % (field.name, value0, value1)
         if value0 and not value1:
-            return ' '.join([field.name, '>=', value0])
+            return " ".join([field.name, ">=", value0])
         if value1 and not value0:
-            return ' '.join([field.name, '<=', value1])
+            return " ".join([field.name, "<=", value1])


 # Define the set of builtin operators
@@ -175,24 +181,24 @@ def register_operator(name: str, sql: Operator):
     _operators[name] = sql


-register_operator('eq', SimpleOperator('=', 'IS NULL'))
-register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
-register_operator('gt', SimpleOperator('>'))
-register_operator('gte', SimpleOperator('>='))
-register_operator('lt', SimpleOperator('<'))
-register_operator('lte', SimpleOperator('<='))
-register_operator('between', BetweenOperator())
-register_operator('in', InOperator())
-register_operator('gin', GlobalInOperator())
-register_operator('not_in', NotOperator(InOperator()))
-register_operator('not_gin', NotOperator(GlobalInOperator()))
-register_operator('contains', LikeOperator('%{}%'))
-register_operator('startswith', LikeOperator('{}%'))
-register_operator('endswith', LikeOperator('%{}'))
-register_operator('icontains', LikeOperator('%{}%', False))
-register_operator('istartswith', LikeOperator('{}%', False))
-register_operator('iendswith', LikeOperator('%{}', False))
-register_operator('iexact', IExactOperator())
+register_operator("eq", SimpleOperator("=", "IS NULL"))
+register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
+register_operator("gt", SimpleOperator(">"))
+register_operator("gte", SimpleOperator(">="))
+register_operator("lt", SimpleOperator("<"))
+register_operator("lte", SimpleOperator("<="))
+register_operator("between", BetweenOperator())
+register_operator("in", InOperator())
+register_operator("gin", GlobalInOperator())
+register_operator("not_in", NotOperator(InOperator()))
+register_operator("not_gin", NotOperator(GlobalInOperator()))
+register_operator("contains", LikeOperator("%{}%"))
+register_operator("startswith", LikeOperator("{}%"))
+register_operator("endswith", LikeOperator("%{}"))
+register_operator("icontains", LikeOperator("%{}%", False))
+register_operator("istartswith", LikeOperator("{}%", False))
+register_operator("iendswith", LikeOperator("%{}", False))
+register_operator("iexact", IExactOperator())
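In queryset filters, the part after the double underscore selects one of these registered operators. Roughly, using the PageView model and db from the sketches above:

```python
qs = PageView.objects_in(db)
qs.filter(url__startswith="https://")   # url LIKE 'https://%'
qs.filter(visits__between=(10, 100))    # visits BETWEEN 10 AND 100
qs.filter(url__not_in=("/", "/ping"))   # NOT (url IN ('/', '/ping'))
```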
 class Cond:
@@ -214,8 +220,8 @@ class FieldCond(Cond):
         self._operator = _operators.get(operator)
         if self._operator is None:
             # The field name contains __ like my__field
-            self._field_name = field_name + '__' + operator
-            self._operator = _operators['eq']
+            self._field_name = field_name + "__" + operator
+            self._operator = _operators["eq"]
         self._value = value

     def to_sql(self, model_cls: type[Model]) -> str:
@@ -228,12 +234,13 @@ class FieldCond(Cond):

 class Q:

-    AND_MODE = 'AND'
-    OR_MODE = 'OR'
+    AND_MODE = "AND"
+    OR_MODE = "OR"

     def __init__(self, *filter_funcs, **filter_fields):
-        self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in
-                                            filter_fields.items()]
+        self._conds = list(filter_funcs) + [
+            self._build_cond(k, v) for k, v in filter_fields.items()
+        ]
         self._children = []
         self._negate = False
         self._mode = self.AND_MODE
@@ -263,10 +270,10 @@ class Q:
         return q

     def _build_cond(self, key, value):
-        if '__' in key:
-            field_name, operator = key.rsplit('__', 1)
+        if "__" in key:
+            field_name, operator = key.rsplit("__", 1)
         else:
-            field_name, operator = key, 'eq'
+            field_name, operator = key, "eq"
         return FieldCond(field_name, operator, value)

     def to_sql(self, model_cls: type[Model]) -> str:
@@ -280,16 +287,16 @@ class Q:

         if not condition_sql:
             # Empty Q() object returns everything
-            sql = '1'
+            sql = "1"
         elif len(condition_sql) == 1:
             # Skip not needed brackets over single condition
             sql = condition_sql[0]
         else:
             # Each condition must be enclosed in brackets, or order of operations may be wrong
-            sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql)
+            sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)

         if self._negate:
-            sql = 'NOT (%s)' % sql
+            sql = "NOT (%s)" % sql

         return sql
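Q objects compose with `&`, `|` and `~`, and to_sql() brackets each branch so operator precedence survives. A rough sketch, again with the hypothetical PageView model:

```python
from clickhouse_orm.query import Q

cond = (Q(url="/") | Q(url="/home")) & ~Q(visits__lt=10)
print(cond.to_sql(PageView))
# something like: ((url = '/') OR (url = '/home')) AND (NOT (visits < 10))
```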
@@ -400,16 +407,16 @@ class QuerySet(Generic[MODEL]):
     def __getitem__(self, s):
         if isinstance(s, int):
             # Single index
-            assert s >= 0, 'negative indexes are not supported'
+            assert s >= 0, "negative indexes are not supported"
             queryset = self._clone()
             queryset._limits = (s, 1)
             return next(iter(queryset))
         # Slice
-        assert s.step in (None, 1), 'step is not supported in slices'
+        assert s.step in (None, 1), "step is not supported in slices"
         start = s.start or 0
-        stop = s.stop or 2 ** 63 - 1
-        assert start >= 0 and stop >= 0, 'negative indexes are not supported'
-        assert start <= stop, 'start of slice cannot be smaller than its end'
+        stop = s.stop or 2**63 - 1
+        assert start >= 0 and stop >= 0, "negative indexes are not supported"
+        assert start <= stop, "start of slice cannot be smaller than its end"
         queryset = self._clone()
         queryset._limits = (start, stop - start)
         return queryset
@@ -425,7 +432,7 @@ class QuerySet(Generic[MODEL]):
             offset_limit = (0, offset_limit)
         offset = offset_limit[0]
         limit = offset_limit[1]
-        assert offset >= 0 and limit >= 0, 'negative limits are not supported'
+        assert offset >= 0 and limit >= 0, "negative limits are not supported"
         queryset = self._clone()
         queryset._limit_by = (offset, limit)
         queryset._limit_by_fields = fields_or_expr
@@ -435,44 +442,44 @@ class QuerySet(Generic[MODEL]):
         """
         Returns the selected fields or expressions as a SQL string.
         """
-        fields = '*'
+        fields = "*"
         if self._fields:
-            fields = comma_join('`%s`' % field for field in self._fields)
+            fields = comma_join("`%s`" % field for field in self._fields)
         return fields

     def as_sql(self) -> str:
         """
         Returns the whole query as a SQL string.
         """
-        distinct = 'DISTINCT ' if self._distinct else ''
-        final = ' FINAL' if self._final else ''
-        table_name = '`%s`' % self._model_cls.table_name()
+        distinct = "DISTINCT " if self._distinct else ""
+        final = " FINAL" if self._final else ""
+        table_name = "`%s`" % self._model_cls.table_name()
         if self._model_cls.is_system_model():
-            table_name = '`system`.' + table_name
+            table_name = "`system`." + table_name
         params = (distinct, self.select_fields_as_sql(), table_name, final)
-        sql = 'SELECT %s%s\nFROM %s%s' % params
+        sql = "SELECT %s%s\nFROM %s%s" % params

         if self._prewhere_q and not self._prewhere_q.is_empty:
-            sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
+            sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)

         if self._where_q and not self._where_q.is_empty:
-            sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
+            sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)

         if self._grouping_fields:
-            sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields)
+            sql += "\nGROUP BY %s" % comma_join("%s" % field for field in self._grouping_fields)

         if self._grouping_with_totals:
-            sql += ' WITH TOTALS'
+            sql += " WITH TOTALS"

         if self._order_by:
-            sql += '\nORDER BY ' + self.order_by_as_sql()
+            sql += "\nORDER BY " + self.order_by_as_sql()

         if self._limit_by:
-            sql += '\nLIMIT %d, %d' % self._limit_by
-            sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
+            sql += "\nLIMIT %d, %d" % self._limit_by
+            sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)

         if self._limits:
-            sql += '\nLIMIT %d, %d' % self._limits
+            sql += "\nLIMIT %d, %d" % self._limits

         return sql
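Putting the pieces together, a chained queryset renders to a single SELECT; approximately (a sketch with the hypothetical PageView model):

```python
qs = PageView.objects_in(db).filter(visits__gte=10).order_by("-visits")[:3]
print(qs.as_sql())
# SELECT *
# FROM `pageview`
# WHERE visits >= 10
# ORDER BY visits DESC
# LIMIT 0, 3
```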
@@ -480,10 +487,12 @@ class QuerySet(Generic[MODEL]):
         """
         Returns the contents of the query's `ORDER BY` clause as a string.
         """
-        return comma_join([
-            '%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field)
-            for field in self._order_by
-        ])
+        return comma_join(
+            [
+                "%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
+                for field in self._order_by
+            ]
+        )

     def conditions_as_sql(self, prewhere=False) -> str:
         """
@@ -498,7 +507,7 @@ class QuerySet(Generic[MODEL]):
         """
         if self._distinct or self._limits:
             # Use a subquery, since a simple count won't be accurate
-            sql = 'SELECT count() FROM (%s)' % self.as_sql()
+            sql = "SELECT count() FROM (%s)" % self.as_sql()
             raw = self._database.raw(sql)
             return int(raw) if raw else 0
@@ -527,8 +536,8 @@ class QuerySet(Generic[MODEL]):
     def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         from clickhouse_orm.funcs import F

-        inverse = kwargs.pop('_inverse', False)
-        prewhere = kwargs.pop('prewhere', False)
+        inverse = kwargs.pop("_inverse", False)
+        prewhere = kwargs.pop("prewhere", False)

         queryset = self._clone()
@@ -588,14 +597,14 @@ class QuerySet(Generic[MODEL]):
         if page_num == -1:
             page_num = pages_total
         elif page_num < 1:
-            raise ValueError('Invalid page number: %d' % page_num)
+            raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
         return Page(
-            objects=list(self[offset: offset + page_size]),
+            objects=list(self[offset : offset + page_size]),
             number_of_objects=count,
             pages_total=pages_total,
             number=page_num,
-            page_size=page_size
+            page_size=page_size,
         )

     def distinct(self) -> "QuerySet[MODEL]":
@@ -616,8 +625,8 @@ class QuerySet(Generic[MODEL]):
         if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
             raise TypeError(
-                'final() method can be used only with the CollapsingMergeTree'
-                ' and ReplacingMergeTree engines'
+                "final() method can be used only with the CollapsingMergeTree"
+                " and ReplacingMergeTree engines"
             )

         queryset = self._clone()
@@ -631,7 +640,7 @@ class QuerySet(Generic[MODEL]):
         """
         self._verify_mutation_allowed()
         conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
-        sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions)
+        sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
         self._database.raw(sql)
         return self
@@ -641,12 +650,14 @@ class QuerySet(Generic[MODEL]):
         Keyword arguments specify the field names and expressions to use for the update.
         Note that ClickHouse performs updates in the background, so they are not immediate.
         """
-        assert kwargs, 'No fields specified for update'
+        assert kwargs, "No fields specified for update"
         self._verify_mutation_allowed()
-        fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
+        fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
         conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
-        sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
-            self._model_cls.table_name(), fields, conditions
+        sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (
+            self._model_cls.table_name(),
+            fields,
+            conditions,
         )
         self._database.raw(sql)
         return self
@@ -655,10 +666,10 @@ class QuerySet(Generic[MODEL]):
         """
         Checks that the queryset's state allows mutations. Raises an AssertionError if not.
         """
-        assert not self._limits, 'Mutations are not allowed after slicing the queryset'
-        assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
-        assert not self._distinct, 'Mutations are not allowed after calling distinct()'
-        assert not self._final, 'Mutations are not allowed after calling final()'
+        assert not self._limits, "Mutations are not allowed after slicing the queryset"
+        assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
+        assert not self._distinct, "Mutations are not allowed after calling distinct()"
+        assert not self._final, "Mutations are not allowed after calling final()"
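Mutation usage for reference (a sketch; both calls become ALTER TABLE mutations, which ClickHouse applies asynchronously in the background):

```python
PageView.objects_in(db).filter(url="/old").update(url="/new")
PageView.objects_in(db).filter(visits=0).delete()
```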
     def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
         """
@@ -687,7 +698,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
         self,
         base_queryset: QuerySet,
         grouping_fields: tuple[Any],
-        calculated_fields: dict[str, str]
+        calculated_fields: dict[str, str],
     ):
         """
         Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`.
@@ -705,7 +716,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
         At least one calculated field is required.
         """
         super().__init__(base_queryset._model_cls, base_queryset._database)
-        assert calculated_fields, 'No calculated fields specified for aggregation'
+        assert calculated_fields, "No calculated fields specified for aggregation"
         self._fields = grouping_fields
         self._grouping_fields = grouping_fields
         self._calculated_fields = calculated_fields
@@ -734,8 +745,9 @@ class AggregateQuerySet(QuerySet[MODEL]):
         created with.
         """
         for name in args:
-            assert name in self._fields or name in self._calculated_fields, \
-                'Cannot group by `%s` since it is not included in the query' % name
+            assert name in self._fields or name in self._calculated_fields, (
+                "Cannot group by `%s` since it is not included in the query" % name
+            )
         queryset = copy(self)
         queryset._grouping_fields = args
         return queryset
@@ -750,14 +762,16 @@ class AggregateQuerySet(QuerySet[MODEL]):
         """
         This method is not supported on `AggregateQuerySet`.
         """
-        raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
+        raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")

     def select_fields_as_sql(self) -> str:
         """
         Returns the selected fields or expressions as a SQL string.
         """
-        return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
-                           self._calculated_fields.items()])
+        return comma_join(
+            [str(f) for f in self._fields]
+            + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
+        )

     def __iter__(self) -> Iterator[Model]:
         """
@@ -778,7 +792,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
         """
         Returns the number of rows after aggregation.
         """
-        sql = 'SELECT count() FROM (%s)' % self.as_sql()
+        sql = "SELECT count() FROM (%s)" % self.as_sql()
         raw = self._database.raw(sql)
         if isinstance(raw, CoroutineType):
             return raw
@@ -795,7 +809,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
         return queryset

     def _verify_mutation_allowed(self):
-        raise AssertionError('Cannot mutate an AggregateQuerySet')
+        raise AssertionError("Cannot mutate an AggregateQuerySet")
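The typical aggregation flow, sketched with the hypothetical PageView model; the calculated field can be any valid ClickHouse expression:

```python
qs = PageView.objects_in(db).aggregate("url", average_visits="avg(visits)")
# SELECT url, avg(visits) AS average_visits ... GROUP BY url
for row in qs:
    print(row.url, row.average_visits)
```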
 # Expose only relevant classes in import *

View File

@@ -2,8 +2,8 @@ import uuid
 from typing import Optional
 from contextvars import ContextVar, Token

-ctx_session_id: ContextVar[str] = ContextVar('ck.session_id')
-ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout')
+ctx_session_id: ContextVar[str] = ContextVar("ck.session_id")
+ctx_session_timeout: ContextVar[float] = ContextVar("ck.session_timeout")


 class SessionContext:

View File

@@ -16,12 +16,15 @@ class SystemPart(Model):
     This model operates only fields, described in the reference. Other fields are ignored.
     https://clickhouse.tech/docs/en/system_tables/system.parts/
     """
-    OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'})
+
+    OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})

     _readonly = True
     _system = True

-    database = StringField()  # Name of the database where the table that this part belongs to is located.
+    database = (
+        StringField()
+    )  # Name of the database where the table that this part belongs to is located.
     table = StringField()  # Name of the table that this part belongs to.
     engine = StringField()  # Name of the table engine, without parameters.
     partition = StringField()  # Name of the partition, in the format YYYYMM.
@@ -43,7 +46,9 @@ class SystemPart(Model):
     # Time the directory with the part was modified. Usually corresponds to the part's creation time.
     modification_time = DateTimeField()

-    remove_time = DateTimeField()  # For inactive parts only - the time when the part became inactive.
+    remove_time = (
+        DateTimeField()
+    )  # For inactive parts only - the time when the part became inactive.

     # The number of places where the part is used. A value greater than 2 indicates
     # that this part participates in queries or merges.
@@ -51,12 +56,13 @@ class SystemPart(Model):
     @classmethod
     def table_name(cls):
-        return 'parts'
+        return "parts"

     """
     Next methods return SQL for some operations, which can be done with partitions
     https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
     """
+
     def _partition_operation_sql(self, operation, settings=None, from_part=None):
         """
         Performs some operation over partition
@@ -68,9 +74,16 @@ class SystemPart(Model):
         Returns: Operation execution result
         """
         operation = operation.upper()
-        assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(self.OPERATIONS)
+        assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(
+            self.OPERATIONS
+        )

-        sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (self._database.db_name, self.table, operation, self.partition)
+        sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (
+            self._database.db_name,
+            self.table,
+            operation,
+            self.partition,
+        )
         if from_part is not None:
             sql += " FROM %s" % from_part
         self._database.raw(sql, settings=settings, stream=False)
@@ -83,7 +96,7 @@ class SystemPart(Model):
         Returns: SQL Query
         """
-        return self._partition_operation_sql('DETACH', settings=settings)
+        return self._partition_operation_sql("DETACH", settings=settings)

     def drop(self, settings=None):
         """
@@ -93,7 +106,7 @@ class SystemPart(Model):
         Returns: SQL Query
         """
-        return self._partition_operation_sql('DROP', settings=settings)
+        return self._partition_operation_sql("DROP", settings=settings)

     def attach(self, settings=None):
         """
@@ -103,7 +116,7 @@ class SystemPart(Model):
         Returns: SQL Query
         """
-        return self._partition_operation_sql('ATTACH', settings=settings)
+        return self._partition_operation_sql("ATTACH", settings=settings)

     def freeze(self, settings=None):
         """
@@ -113,7 +126,7 @@ class SystemPart(Model):
         Returns: SQL Query
         """
-        return self._partition_operation_sql('FREEZE', settings=settings)
+        return self._partition_operation_sql("FREEZE", settings=settings)

     def fetch(self, zookeeper_path, settings=None):
         """
@@ -124,7 +137,7 @@ class SystemPart(Model):
         Returns: SQL Query
         """
-        return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path)
+        return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)

     @classmethod
     def get(cls, database, conditions=""):
@@ -140,9 +153,12 @@ class SystemPart(Model):
         assert isinstance(conditions, str), "conditions must be a string"
         if conditions:
             conditions += " AND"
-        field_names = ','.join(cls.fields())
-        return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" %
-                               (field_names, cls.table_name(), conditions, database.db_name), model_class=cls)
+        field_names = ",".join(cls.fields())
+        return database.select(
+            "SELECT %s FROM `system`.%s WHERE %s database='%s'"
+            % (field_names, cls.table_name(), conditions, database.db_name),
+            model_class=cls,
+        )

     @classmethod
     def get_active(cls, database, conditions=""):
@@ -155,8 +171,8 @@ class SystemPart(Model):
         Returns: A list of SystemPart objects
         """
         if conditions:
-            conditions += ' AND '
-        conditions += 'active'
+            conditions += " AND "
+        conditions += "active"
         return SystemPart.get(database, conditions=conditions)
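Partition operations in practice. A sketch assuming this module's import path and a connected `db`; freeze() asks the server to create a local backup of each partition:

```python
from clickhouse_orm.system_models import SystemPart

# Back up every active partition of every table in `db`
for part in SystemPart.get_active(db):
    part.freeze()
```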

View File

@@ -10,10 +10,10 @@ SPECIAL_CHARS = {
     "\t": "\\t",
     "\0": "\\0",
     "\\": "\\\\",
-    "'": "\\'"
+    "'": "\\'",
 }

-SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
+SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")

 POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
 RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
@@ -36,11 +36,11 @@ def escape(value, quote=True):

 def unescape(value):
-    return codecs.escape_decode(value)[0].decode('utf-8')
+    return codecs.escape_decode(value)[0].decode("utf-8")


 def string_or_func(obj):
-    return obj.to_sql() if hasattr(obj, 'to_sql') else obj
+    return obj.to_sql() if hasattr(obj, "to_sql") else obj


 def arg_to_sql(arg):
@@ -50,6 +50,7 @@ def arg_to_sql(arg):
     None, numbers, timezones, arrays/iterables.
     """
     from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
+
     if isinstance(arg, F):
         return arg.to_sql()
     if isinstance(arg, Field):
@@ -67,22 +68,22 @@ def arg_to_sql(arg):
     if isinstance(arg, tzinfo):
         return StringField().to_db_string(arg.tzname(None))
     if arg is None:
-        return 'NULL'
+        return "NULL"
     if isinstance(arg, QuerySet):
         return "(%s)" % arg
     if isinstance(arg, tuple):
-        return '(' + comma_join(arg_to_sql(x) for x in arg) + ')'
+        return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
     if is_iterable(arg):
-        return '[' + comma_join(arg_to_sql(x) for x in arg) + ']'
+        return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
     return str(arg)


 def parse_tsv(line):
     if isinstance(line, bytes):
         line = line.decode()
-    if line and line[-1] == '\n':
+    if line and line[-1] == "\n":
         line = line[:-1]
-    return [unescape(value) for value in line.split(str('\t'))]
+    return [unescape(value) for value in line.split(str("\t"))]
 def parse_array(array_string):
@@ -92,17 +93,17 @@ def parse_array(array_string):
     "(1,2,3)" ==> [1, 2, 3]
     """
     # Sanity check
-    if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])':
+    if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
         raise ValueError('Invalid array string: "%s"' % array_string)
     # Drop opening brace
     array_string = array_string[1:]
     # Go over the string, lopping off each value at the beginning until nothing is left
     values = []
     while True:
-        if array_string in '])':
+        if array_string in "])":
             # End of array
             return values
-        elif array_string[0] in ', ':
+        elif array_string[0] in ", ":
             # In between values
             array_string = array_string[1:]
         elif array_string[0] == "'":
@@ -110,13 +111,13 @@ def parse_array(array_string):
             match = re.search(r"[^\\]'", array_string)
             if match is None:
                 raise ValueError('Missing closing quote: "%s"' % array_string)
-            values.append(array_string[1: match.start() + 1])
-            array_string = array_string[match.end():]
+            values.append(array_string[1 : match.start() + 1])
+            array_string = array_string[match.end() :]
         else:
             # Start of non-quoted value, find its end
             match = re.search(r",|\]|\)", array_string)
-            values.append(array_string[0: match.start()])
-            array_string = array_string[match.end() - 1:]
+            values.append(array_string[0 : match.start()])
+            array_string = array_string[match.end() - 1 :]
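For example (values come back as raw strings here; conversion to Python types happens later in the field classes):

```python
from clickhouse_orm.utils import parse_array

parse_array("[1,2,3]")      # -> ['1', '2', '3']
parse_array("['a','b,c']")  # -> ['a', 'b,c']  (commas inside quotes are kept)
```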
 def import_submodules(package_name):
@@ -124,9 +125,10 @@ def import_submodules(package_name):
     Import all submodules of a module.
     """
     import importlib, pkgutil

     package = importlib.import_module(package_name)
     return {
-        name: importlib.import_module(package_name + '.' + name)
+        name: importlib.import_module(package_name + "." + name)
         for _, name, _ in pkgutil.iter_modules(package.__path__)
     }
@@ -136,9 +138,9 @@ def comma_join(items, stringify=False):
     Joins an iterable of strings with commas.
     """
     if stringify:
-        return ', '.join(str(item) for item in items)
+        return ", ".join(str(item) for item in items)
     else:
-        return ', '.join(items)
+        return ", ".join(items)


 def is_iterable(obj):
@@ -154,6 +156,7 @@ def is_iterable(obj):

 def get_subclass_names(locals, base_class):
     from inspect import isclass

     return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
@@ -164,7 +167,7 @@ class NoValue:
     """

     def __repr__(self):
-        return 'NO_VALUE'
+        return "NO_VALUE"


 NO_VALUE = NoValue()