From d22683f28c7c59053bd009cedfe35e54e6b3aead Mon Sep 17 00:00:00 2001 From: sw <935405794@qq.com> Date: Sat, 4 Jun 2022 21:25:34 +0800 Subject: [PATCH] Migrate code style to Black --- docs/ref.md | 526 --------------- src/clickhouse_orm/aio/database.py | 112 ++-- src/clickhouse_orm/contrib/geo/fields.py | 26 +- src/clickhouse_orm/database.py | 163 ++--- src/clickhouse_orm/engines.py | 176 +++-- src/clickhouse_orm/fields.py | 333 +++++----- src/clickhouse_orm/funcs.py | 792 ++++++++++++----------- src/clickhouse_orm/migrations.py | 97 +-- src/clickhouse_orm/models.py | 202 +++--- src/clickhouse_orm/query.py | 214 +++--- src/clickhouse_orm/session.py | 4 +- src/clickhouse_orm/system_models.py | 48 +- src/clickhouse_orm/utils.py | 43 +- 13 files changed, 1188 insertions(+), 1548 deletions(-) delete mode 100644 docs/ref.md diff --git a/docs/ref.md b/docs/ref.md deleted file mode 100644 index 4679b2b..0000000 --- a/docs/ref.md +++ /dev/null @@ -1,526 +0,0 @@ -Class Reference -=============== - -infi.clickhouse_orm.database ----------------------------- - -### Database - -#### Database(db_name, db_url="http://localhost:8123/", username=None, password=None, readonly=False) - -Initializes a database instance. Unless it's readonly, the database will be -created on the ClickHouse server if it does not already exist. - -- `db_name`: name of the database to connect to. -- `db_url`: URL of the ClickHouse server. -- `username`: optional connection credentials. -- `password`: optional connection credentials. -- `readonly`: use a read-only connection. - - -#### count(model_class, conditions=None) - -Counts the number of records in the model's table. - -- `model_class`: the model to count. -- `conditions`: optional SQL conditions (contents of the WHERE clause). - - -#### create_database() - -Creates the database on the ClickHouse server if it does not already exist. - - -#### create_table(model_class) - -Creates a table for the given model class, if it does not exist already. - - -#### drop_database() - -Deletes the database on the ClickHouse server. - - -#### drop_table(model_class) - -Drops the database table of the given model class, if it exists. - - -#### insert(model_instances, batch_size=1000) - -Insert records into the database. - -- `model_instances`: any iterable containing instances of a single model class. -- `batch_size`: number of records to send per chunk (use a lower number if your records are very large). - - -#### migrate(migrations_package_name, up_to=9999) - -Executes schema migrations. - -- `migrations_package_name` - fully qualified name of the Python package -containing the migrations. -- `up_to` - number of the last migration to apply. - - -#### paginate(model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None) - -Selects records and returns a single page of model instances. - -- `model_class`: the model class matching the query's table, -or `None` for getting back instances of an ad-hoc model. -- `order_by`: columns to use for sorting the query (contents of the ORDER BY clause). -- `page_num`: the page number (1-based), or -1 to get the last page. -- `page_size`: number of records to return per page. -- `conditions`: optional SQL conditions (contents of the WHERE clause). -- `settings`: query settings to send as HTTP GET parameters - -The result is a namedtuple containing `objects` (list), `number_of_objects`, -`pages_total`, `number` (of the current page), and `page_size`. 
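Taken together, the `Database` methods above support a short end-to-end workflow. A minimal sketch, assuming a ClickHouse server at the default URL; the `Visit` model and its fields are hypothetical, defined only for illustration, and the imports use the current `clickhouse_orm` package paths rather than the legacy `infi.clickhouse_orm` ones:

```python
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, StringField
from clickhouse_orm.models import Model


class Visit(Model):
    # Hypothetical model used only to exercise the API described above
    date = DateField()
    url = StringField()
    engine = MergeTree("date", ("date",))


db = Database("demo")   # created on the server if it does not already exist
db.create_table(Visit)  # no-op if the table already exists
db.insert([Visit(date="2022-06-04", url="/home")], batch_size=1000)

# paginate() returns the namedtuple described above
page = db.paginate(Visit, order_by="date", page_num=1, page_size=100)
print(page.number_of_objects, page.pages_total)
```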
- - -#### raw(query, settings=None, stream=False) - -Performs a query and returns its output as text. - -- `query`: the SQL query to execute. -- `settings`: query settings to send as HTTP GET parameters -- `stream`: if true, the HTTP response from ClickHouse will be streamed. - - -#### select(query, model_class=None, settings=None) - -Performs a query and returns a generator of model instances. - -- `query`: the SQL query to execute. -- `model_class`: the model class matching the query's table, -or `None` for getting back instances of an ad-hoc model. -- `settings`: query settings to send as HTTP GET parameters - - -### DatabaseException - -Extends Exception - -Raised when a database operation fails. - -infi.clickhouse_orm.models --------------------------- - -### Model - -A base class for ORM models. - -#### Model(**kwargs) - -Creates a model instance, using keyword arguments as field values. -Since values are immediately converted to their Pythonic type, -invalid values will cause a `ValueError` to be raised. -Unrecognized field names will cause an `AttributeError`. - - -#### Model.create_table_sql(db) - -Returns the SQL command for creating a table for this model. - - -#### Model.drop_table_sql(db) - -Returns the SQL command for deleting this model's table. - - -#### Model.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None) - -Create a model instance from a tab-separated line. The line may or may not include a newline. -The `field_names` list must match the fields defined in the model, but does not have to include all of them. -If omitted, it is assumed to be the names of all fields in the model, in order of definition. - -- `line`: the TSV-formatted data. -- `field_names`: names of the model fields in the data. -- `timezone_in_use`: the timezone to use when parsing dates and datetimes. -- `database`: if given, sets the database that this instance belongs to. - - -#### get_database() - -Gets the `Database` that this model instance belongs to. -Returns `None` unless the instance was read from the database or written to it. - - -#### get_field(name) - -Gets a `Field` instance given its name, or `None` if not found. - - -#### Model.objects_in(database) - -Returns a `QuerySet` for selecting instances of this model class. - - -#### set_database(db) - -Sets the `Database` that this model instance belongs to. -This is done automatically when the instance is read from the database or written to it. - - -#### Model.table_name() - -Returns the model's database table name. By default this is the -class name converted to lowercase. Override this if you want to use -a different table name. - - -#### to_dict(include_readonly=True, field_names=None) - -Returns the instance's column values as a dict. - -- `include_readonly`: if false, returns only fields that can be inserted into database. -- `field_names`: an iterable of field names to return (optional) - - -#### to_tsv(include_readonly=True) - -Returns the instance's column values as a tab-separated line. A newline is not included. - -- `include_readonly`: if false, returns only fields that can be inserted into database. - - -### BufferModel - -Extends Model - -#### BufferModel(**kwargs) - -Creates a model instance, using keyword arguments as field values. -Since values are immediately converted to their Pythonic type, -invalid values will cause a `ValueError` to be raised. -Unrecognized field names will cause an `AttributeError`. 
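A brief sketch of the conversion behaviour described above, using a hypothetical `Person` model (current `clickhouse_orm` import paths; the expected outputs in the comments are illustrative):

```python
from clickhouse_orm.fields import Int32Field, StringField
from clickhouse_orm.models import Model


class Person(Model):
    # Hypothetical model; an engine is only required for create_table()
    first_name = StringField()
    height = Int32Field()


p = Person(first_name="Dan", height=180)
print(p.to_dict())  # {'first_name': 'Dan', 'height': 180}
print(p.to_tsv())   # Dan\t180

Person(height="tall")   # raises ValueError (cannot convert to int)
Person(nickname="Dan")  # raises AttributeError (unrecognized field name)
```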
- - -#### BufferModel.create_table_sql(db) - -Returns the SQL command for creating a table for this model. - - -#### BufferModel.drop_table_sql(db) - -Returns the SQL command for deleting this model's table. - - -#### BufferModel.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None) - -Create a model instance from a tab-separated line. The line may or may not include a newline. -The `field_names` list must match the fields defined in the model, but does not have to include all of them. -If omitted, it is assumed to be the names of all fields in the model, in order of definition. - -- `line`: the TSV-formatted data. -- `field_names`: names of the model fields in the data. -- `timezone_in_use`: the timezone to use when parsing dates and datetimes. -- `database`: if given, sets the database that this instance belongs to. - - -#### get_database() - -Gets the `Database` that this model instance belongs to. -Returns `None` unless the instance was read from the database or written to it. - - -#### get_field(name) - -Gets a `Field` instance given its name, or `None` if not found. - - -#### BufferModel.objects_in(database) - -Returns a `QuerySet` for selecting instances of this model class. - - -#### set_database(db) - -Sets the `Database` that this model instance belongs to. -This is done automatically when the instance is read from the database or written to it. - - -#### BufferModel.table_name() - -Returns the model's database table name. By default this is the -class name converted to lowercase. Override this if you want to use -a different table name. - - -#### to_dict(include_readonly=True, field_names=None) - -Returns the instance's column values as a dict. - -- `include_readonly`: if false, returns only fields that can be inserted into database. -- `field_names`: an iterable of field names to return (optional) - - -#### to_tsv(include_readonly=True) - -Returns the instance's column values as a tab-separated line. A newline is not included. - -- `include_readonly`: if false, returns only fields that can be inserted into database. - - -infi.clickhouse_orm.fields --------------------------- - -### Field - -Abstract base class for all field types. - -#### Field(default=None, alias=None, materialized=None) - - -### StringField - -Extends Field - -#### StringField(default=None, alias=None, materialized=None) - - -### DateField - -Extends Field - -#### DateField(default=None, alias=None, materialized=None) - - -### DateTimeField - -Extends Field - -#### DateTimeField(default=None, alias=None, materialized=None) - - -### BaseIntField - -Extends Field - -Abstract base class for all integer-type fields. - -#### BaseIntField(default=None, alias=None, materialized=None) - - -### BaseFloatField - -Extends Field - -Abstract base class for all float-type fields. - -#### BaseFloatField(default=None, alias=None, materialized=None) - - -### BaseEnumField - -Extends Field - -Abstract base class for all enum-type fields. 
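Before the constructor signatures that follow, a short sketch of how these field classes combine in a model definition. The `Event` model and `Color` enum are hypothetical, shown with the current `clickhouse_orm` import paths:

```python
import enum

from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import ArrayField, DateTimeField, Enum8Field, StringField
from clickhouse_orm.models import Model


class Color(enum.Enum):
    red = 1
    blue = 2


class Event(Model):
    # Hypothetical model combining enum and array columns
    time = DateTimeField()
    color = Enum8Field(Color, default=Color.red)
    tags = ArrayField(StringField())
    engine = Memory()
```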
- -#### BaseEnumField(enum_cls, default=None, alias=None, materialized=None) - - -### ArrayField - -Extends Field - -#### ArrayField(inner_field, default=None, alias=None, materialized=None) - - -### FixedStringField - -Extends StringField - -#### FixedStringField(length, default=None, alias=None, materialized=None) - - -### UInt8Field - -Extends BaseIntField - -#### UInt8Field(default=None, alias=None, materialized=None) - - -### UInt16Field - -Extends BaseIntField - -#### UInt16Field(default=None, alias=None, materialized=None) - - -### UInt32Field - -Extends BaseIntField - -#### UInt32Field(default=None, alias=None, materialized=None) - - -### UInt64Field - -Extends BaseIntField - -#### UInt64Field(default=None, alias=None, materialized=None) - - -### Int8Field - -Extends BaseIntField - -#### Int8Field(default=None, alias=None, materialized=None) - - -### Int16Field - -Extends BaseIntField - -#### Int16Field(default=None, alias=None, materialized=None) - - -### Int32Field - -Extends BaseIntField - -#### Int32Field(default=None, alias=None, materialized=None) - - -### Int64Field - -Extends BaseIntField - -#### Int64Field(default=None, alias=None, materialized=None) - - -### Float32Field - -Extends BaseFloatField - -#### Float32Field(default=None, alias=None, materialized=None) - - -### Float64Field - -Extends BaseFloatField - -#### Float64Field(default=None, alias=None, materialized=None) - - -### Enum8Field - -Extends BaseEnumField - -#### Enum8Field(enum_cls, default=None, alias=None, materialized=None) - - -### Enum16Field - -Extends BaseEnumField - -#### Enum16Field(enum_cls, default=None, alias=None, materialized=None) - - -infi.clickhouse_orm.engines ---------------------------- - -### Engine - -### TinyLog - -Extends Engine - -### Log - -Extends Engine - -### Memory - -Extends Engine - -### MergeTree - -Extends Engine - -#### MergeTree(date_col, key_cols, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None) - - -### Buffer - -Extends Engine - -Here we define Buffer engine -Read more here https://clickhouse.tech/reference_en.html#Buffer - -#### Buffer(main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000, min_bytes=10000000, max_bytes=100000000) - - -### CollapsingMergeTree - -Extends MergeTree - -#### CollapsingMergeTree(date_col, key_cols, sign_col, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None) - - -### SummingMergeTree - -Extends MergeTree - -#### SummingMergeTree(date_col, key_cols, summing_cols=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None) - - -### ReplacingMergeTree - -Extends MergeTree - -#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None) - - -infi.clickhouse_orm.query -------------------------- - -### QuerySet - -#### QuerySet(model_cls, database) - - -#### conditions_as_sql(prewhere=True) - -Return the contents of the queryset's WHERE or `PREWHERE` clause. - - -#### count() - -Returns the number of matching model instances. - - -#### exclude(**kwargs) - -Returns a new QuerySet instance that excludes all rows matching the conditions. - - -#### filter(**kwargs) - -Returns a new QuerySet instance that includes only rows matching the conditions. - - -#### only(*field_names) - -Limit the query to return only the specified field names. 
-Useful when there are large fields that are not needed, -or for creating a subquery to use with an IN operator. - - -#### order_by(*field_names) - -Returns a new QuerySet instance with the ordering changed. - - -#### order_by_as_sql() - -Return the contents of the queryset's ORDER BY clause. - - -#### query() - -Return the the queryset as SQL. - - diff --git a/src/clickhouse_orm/aio/database.py b/src/clickhouse_orm/aio/database.py index fd36128..be423b5 100644 --- a/src/clickhouse_orm/aio/database.py +++ b/src/clickhouse_orm/aio/database.py @@ -15,6 +15,7 @@ from clickhouse_orm.database import Database, ServerError, DatabaseException, lo # pylint: disable=C0116 + class AioDatabase(Database): _client_class = httpx.AsyncClient @@ -25,7 +26,7 @@ class AioDatabase(Database): if self._readonly: if not self.db_exists: raise DatabaseException( - 'Database does not exist, and cannot be created under readonly connection' + "Database does not exist, and cannot be created under readonly connection" ) self.connection_readonly = await self._is_connection_readonly() self.readonly = True @@ -44,10 +45,7 @@ class AioDatabase(Database): await self.request_session.aclose() async def _send( - self, - data: str | bytes | AsyncGenerator, - settings: dict = None, - stream: bool = False + self, data: str | bytes | AsyncGenerator, settings: dict = None, stream: bool = False ): r = await super()._send(data, settings, stream) if r.status_code != 200: @@ -55,11 +53,7 @@ class AioDatabase(Database): raise ServerError(r.text) return r - async def count( - self, - model_class: type[MODEL], - conditions=None - ) -> int: + async def count(self, model_class: type[MODEL], conditions=None) -> int: """ Counts the number of records in the model's table. @@ -70,14 +64,14 @@ class AioDatabase(Database): if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the init method before it can be used' + "The AioDatabase object must execute the init method before it can be used" ) - query = 'SELECT count() FROM $table' + query = "SELECT count() FROM $table" if conditions: if isinstance(conditions, Q): conditions = conditions.to_sql(model_class) - query += ' WHERE ' + str(conditions) + query += " WHERE " + str(conditions) query = self._substitute(query, model_class) r = await self._send(query) return int(r.text) if r.text else 0 @@ -86,14 +80,14 @@ class AioDatabase(Database): """ Creates the database on the ClickHouse server if it does not already exist. """ - await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) + await self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name) self.db_exists = True async def drop_database(self): """ Deletes the database on the ClickHouse server. 
""" - await self._send('DROP DATABASE `%s`' % self.db_name) + await self._send("DROP DATABASE `%s`" % self.db_name) self.db_exists = False async def create_table(self, model_class: type[MODEL]) -> None: @@ -102,7 +96,7 @@ class AioDatabase(Database): """ if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the init method before it can be used' + "The AioDatabase object must execute the init method before it can be used" ) if model_class.is_system_model(): raise DatabaseException("You can't create system table") @@ -110,7 +104,7 @@ class AioDatabase(Database): raise DatabaseException( "Creating a temporary table must be within the lifetime of a session " ) - if getattr(model_class, 'engine') is None: + if getattr(model_class, "engine") is None: raise DatabaseException(f"%s class must define an engine" % model_class.__name__) await self._send(model_class.create_table_sql(self)) @@ -121,7 +115,7 @@ class AioDatabase(Database): """ if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the init method before it can be used' + "The AioDatabase object must execute the init method before it can be used" ) await self._send(model_class.create_temporary_table_sql(self, table_name)) @@ -132,7 +126,7 @@ class AioDatabase(Database): """ if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the init method before it can be used' + "The AioDatabase object must execute the init method before it can be used" ) if model_class.is_system_model(): @@ -146,18 +140,14 @@ class AioDatabase(Database): """ if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the init method before it can be used' + "The AioDatabase object must execute the init method before it can be used" ) sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'" r = await self._send(sql % (self.db_name, model_class.table_name())) - return r.text.strip() == '1' + return r.text.strip() == "1" - async def get_model_for_table( - self, - table_name: str, - system_table: bool = False - ): + async def get_model_for_table(self, table_name: str, system_table: bool = False): """ Generates a model class from an existing table in the database. 
This can be used for querying tables which don't have a corresponding model class, @@ -166,7 +156,7 @@ class AioDatabase(Database): - `table_name`: the table to create a model for - `system_table`: whether the table is a system table, or belongs to the current database """ - db_name = 'system' if system_table else self.db_name + db_name = "system" if system_table else self.db_name sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name) lines = await self._send(sql) fields = [parse_tsv(line)[:2] async for line in lines.aiter_lines()] @@ -192,14 +182,13 @@ class AioDatabase(Database): if first_instance.is_read_only() or first_instance.is_system_model(): raise DatabaseException("You can't insert into read only and system tables") - fields_list = ','.join( - ['`%s`' % name for name in first_instance.fields(writable=True)]) - fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated' - query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt) + fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)]) + fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated" + query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt) async def gen(): buf = BytesIO() - buf.write(self._substitute(query, model_class).encode('utf-8')) + buf.write(self._substitute(query, model_class).encode("utf-8")) first_instance.set_database(self) buf.write(first_instance.to_db_string()) # Collect lines in batches of batch_size @@ -217,13 +206,11 @@ class AioDatabase(Database): # Return any remaining lines in partial batch if lines: yield buf.getvalue() + await self._send(gen()) async def select( - self, - query: str, - model_class: Optional[type[MODEL]] = None, - settings: Optional[dict] = None + self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None ) -> AsyncGenerator[MODEL, None]: """ Performs a query and returns a generator of model instances. @@ -233,7 +220,7 @@ class AioDatabase(Database): or `None` for getting back instances of an ad-hoc model. - `settings`: query settings to send as HTTP GET parameters """ - query += ' FORMAT TabSeparatedWithNamesAndTypes' + query += " FORMAT TabSeparatedWithNamesAndTypes" query = self._substitute(query, model_class) r = await self._send(query, settings, True) try: @@ -245,7 +232,8 @@ class AioDatabase(Database): elif not field_types: field_types = parse_tsv(line) model_class = model_class or ModelBase.create_ad_hoc_model( - zip(field_names, field_types)) + zip(field_names, field_types) + ) elif line.strip(): yield model_class.from_tsv(line, field_names, self.server_timezone, self) except StopIteration: @@ -271,7 +259,7 @@ class AioDatabase(Database): page_num: int = 1, page_size: int = 100, conditions=None, - settings: Optional[dict] = None + settings: Optional[dict] = None, ): """ Selects records and returns a single page of model instances. 
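For orientation, a minimal usage sketch of the async API reformatted in this file. It reuses the hypothetical `Visit` model from the earlier example and assumes, per the exception messages above, that `init()` must be awaited before any other call:

```python
import asyncio

from clickhouse_orm.aio.database import AioDatabase


async def main():
    db = AioDatabase("demo")
    await db.init()  # required before use, per the checks above
    await db.create_table(Visit)

    page = await db.paginate(Visit, order_by="date", page_num=-1)  # last page
    async for visit in db.select("SELECT * FROM $table", Visit):
        print(visit.url)


asyncio.run(main())
```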
@@ -294,22 +282,22 @@ class AioDatabase(Database): if page_num == -1: page_num = max(pages_total, 1) elif page_num < 1: - raise ValueError('Invalid page number: %d' % page_num) + raise ValueError("Invalid page number: %d" % page_num) offset = (page_num - 1) * page_size - query = 'SELECT * FROM $table' + query = "SELECT * FROM $table" if conditions: if isinstance(conditions, Q): conditions = conditions.to_sql(model_class) - query += ' WHERE ' + str(conditions) - query += ' ORDER BY %s' % order_by - query += ' LIMIT %d, %d' % (offset, page_size) + query += " WHERE " + str(conditions) + query += " ORDER BY %s" % order_by + query += " LIMIT %d, %d" % (offset, page_size) query = self._substitute(query, model_class) return Page( objects=[r async for r in self.select(query, model_class, settings)] if count else [], number_of_objects=count, pages_total=pages_total, number=page_num, - page_size=page_size + page_size=page_size, ) async def migrate(self, migrations_package_name, up_to=9999): @@ -322,19 +310,23 @@ class AioDatabase(Database): """ from ..migrations import MigrationHistory - logger = logging.getLogger('migrations') + logger = logging.getLogger("migrations") applied_migrations = await self._get_applied_migrations(migrations_package_name) modules = import_submodules(migrations_package_name) unapplied_migrations = set(modules.keys()) - applied_migrations for name in sorted(unapplied_migrations): - logger.info('Applying migration %s...', name) + logger.info("Applying migration %s...", name) for operation in modules[name].operations: operation.apply(self) - await self.insert([MigrationHistory( - package_name=migrations_package_name, - module_name=name, - applied=datetime.date.today() - )]) + await self.insert( + [ + MigrationHistory( + package_name=migrations_package_name, + module_name=name, + applied=datetime.date.today(), + ) + ] + ) if int(name[:4]) >= up_to: break @@ -342,28 +334,28 @@ class AioDatabase(Database): r = await self._send( "SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name ) - return r.text.strip() == '1' + return r.text.strip() == "1" async def _is_connection_readonly(self): r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'") - return r.text.strip() != '0' + return r.text.strip() != "0" async def _get_server_timezone(self): try: - r = await self._send('SELECT timezone()') + r = await self._send("SELECT timezone()") return pytz.timezone(r.text.strip()) except ServerError as err: - logger.exception('Cannot determine server timezone (%s), assuming UTC', err) + logger.exception("Cannot determine server timezone (%s), assuming UTC", err) return pytz.utc async def _get_server_version(self, as_tuple=True): try: - r = await self._send('SELECT version();') + r = await self._send("SELECT version();") ver = r.text except ServerError as err: - logger.exception('Cannot determine server version (%s), assuming 1.1.0', err) - ver = '1.1.0' - return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver + logger.exception("Cannot determine server version (%s), assuming 1.1.0", err) + ver = "1.1.0" + return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver async def _get_applied_migrations(self, migrations_package_name): from ..migrations import MigrationHistory diff --git a/src/clickhouse_orm/contrib/geo/fields.py b/src/clickhouse_orm/contrib/geo/fields.py index df4d2cb..b21c9e9 100644 --- a/src/clickhouse_orm/contrib/geo/fields.py +++ b/src/clickhouse_orm/contrib/geo/fields.py @@ -11,10 +11,10 @@ 
class Point:
         self.y = float(y)
 
     def __repr__(self):
-        return f''
+        return f""
 
     def to_db_string(self):
-        return f'({self.x},{self.y})'
+        return f"({self.x},{self.y})"
 
 
 class Ring:
@@ -29,16 +29,16 @@ class Ring:
         return len(self.array)
 
     def __repr__(self):
-        return f''
+        return f""
 
     def to_db_string(self):
         return f'[{",".join(pt.to_db_string() for pt in self.array)}]'
 
 
 def parse_point(array_string: str) -> Point:
-    if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')':
+    if len(array_string) < 2 or array_string[0] != "(" or array_string[-1] != ")":
         raise ValueError('Invalid point string: "%s"' % array_string)
-    x, y = array_string.strip('()').split(',')
+    x, y = array_string.strip("()").split(",")
     return Point(x, y)
 
 
@@ -47,14 +47,14 @@ def parse_ring(array_string: str) -> Ring:
         raise ValueError('Invalid ring string: "%s"' % array_string)
     ring = []
     for point in POINT_REGEX.finditer(array_string):
-        x, y = point.group('x'), point.group('y')
+        x, y = point.group("x"), point.group("y")
         ring.append(Point(x, y))
     return Ring(ring)
 
 
 class PointField(Field):
     class_default = Point(0, 0)
-    db_type = 'Point'
+    db_type = "Point"
 
     def __init__(
         self,
@@ -63,7 +63,7 @@ class PointField(Field):
         materialized: Optional[Union[F, str]] = None,
         readonly: bool = None,
         codec: Optional[str] = None,
-        db_column: Optional[str] = None
+        db_column: Optional[str] = None,
     ):
         super().__init__(default, alias, materialized, readonly, codec, db_column)
         self.inner_field = Float64Field()
@@ -73,10 +73,10 @@ class PointField(Field):
             value = parse_point(value)
         elif isinstance(value, (tuple, list)):
             if len(value) != 2:
-                raise ValueError('PointField takes 2 value, but %s were given' % len(value))
+                raise ValueError("PointField takes 2 values, but %s were given" % len(value))
             value = Point(value[0], value[1])
         if not isinstance(value, Point):
-            raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
+            raise ValueError("PointField expects list, tuple or Point, not %s" % type(value))
         return value
 
     def validate(self, value):
@@ -91,7 +91,7 @@ class PointField(Field):
 
 class RingField(Field):
     class_default = [Point(0, 0)]
-    db_type = 'Ring'
+    db_type = "Ring"
 
     def to_python(self, value, timezone_in_use):
         if isinstance(value, str):
@@ -100,11 +100,11 @@ class RingField(Field):
             ring = []
             for point in value:
                 if len(point) != 2:
-                    raise ValueError('Point takes 2 value, but %s were given' % len(value))
+                    raise ValueError("Point takes 2 values, but %s were given" % len(point))
                 ring.append(Point(point[0], point[1]))
             value = Ring(ring)
         if not isinstance(value, Ring):
-            raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
+            raise ValueError("RingField expects list, tuple or Ring, not %s" % type(value))
         return value
 
     def to_db_string(self, value, quote=True):
diff --git a/src/clickhouse_orm/database.py b/src/clickhouse_orm/database.py
index 013207f..b7ab5e4 100644
--- a/src/clickhouse_orm/database.py
+++ b/src/clickhouse_orm/database.py
@@ -16,8 +16,8 @@ from .utils import parse_tsv, import_submodules
 from .session import ctx_session_id, ctx_session_timeout
 
 
-logger = logging.getLogger('clickhouse_orm')
-Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
+logger = logging.getLogger("clickhouse_orm")
+Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")
 
 
 class DatabaseException(Exception):
@@ -30,6 +30,7 @@ class ServerError(DatabaseException):
     """
     Raised when a server returns an error.
""" + def __init__(self, message): self.code = None processed = self.get_error_code_msg(message) @@ -43,21 +44,30 @@ class ServerError(DatabaseException): ERROR_PATTERNS = ( # ClickHouse prior to v19.3.3 - re.compile(r''' + re.compile( + r""" Code:\ (?P\d+), \ e\.displayText\(\)\ =\ (?P[^ \n]+):\ (?P.+?), \ e.what\(\)\ =\ (?P[^ \n]+) - ''', re.VERBOSE | re.DOTALL), + """, + re.VERBOSE | re.DOTALL, + ), # ClickHouse v19.3.3+ - re.compile(r''' + re.compile( + r""" Code:\ (?P\d+), \ e\.displayText\(\)\ =\ (?P[^ \n]+):\ (?P.+) - ''', re.VERBOSE | re.DOTALL), + """, + re.VERBOSE | re.DOTALL, + ), # ClickHouse v21+ - re.compile(r''' + re.compile( + r""" Code:\ (?P\d+). \ (?P[^ \n]+):\ (?P.+) - ''', re.VERBOSE | re.DOTALL), + """, + re.VERBOSE | re.DOTALL, + ), ) @classmethod @@ -72,7 +82,7 @@ class ServerError(DatabaseException): match = pattern.match(full_error_message) if match: # assert match.group('type1') == match.group('type2') - return int(match.group('code')), match.group('msg').strip() + return int(match.group("code")), match.group("msg").strip() return 0, full_error_message @@ -86,11 +96,21 @@ class Database: Database instances connect to a specific ClickHouse database for running queries, inserting data and other operations. """ + _client_class = httpx.Client - def __init__(self, db_name, db_url='http://localhost:8123/', - username=None, password=None, readonly=False, auto_create=True, - timeout=60, verify_ssl_cert=True, log_statements=False): + def __init__( + self, + db_name, + db_url="http://localhost:8123/", + username=None, + password=None, + readonly=False, + auto_create=True, + timeout=60, + verify_ssl_cert=True, + log_statements=False, + ): """ Initializes a database instance. Unless it's readonly, the database will be created on the ClickHouse server if it does not already exist. @@ -114,7 +134,7 @@ class Database: self.timeout = timeout self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout) if username: - self.request_session.auth = (username, password or '') + self.request_session.auth = (username, password or "") self.log_statements = log_statements self.settings = {} self.db_exists = False # this is required before running _is_existing_database @@ -134,7 +154,7 @@ class Database: if self._readonly: if not self.db_exists: raise DatabaseException( - 'Database does not exist, and cannot be created under readonly connection' + "Database does not exist, and cannot be created under readonly connection" ) self.connection_readonly = self._is_connection_readonly() self.readonly = True @@ -155,14 +175,14 @@ class Database: """ Creates the database on the ClickHouse server if it does not already exist. """ - self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) + self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name) self.db_exists = True def drop_database(self): """ Deletes the database on the ClickHouse server. 
""" - self._send('DROP DATABASE `%s`' % self.db_name) + self._send("DROP DATABASE `%s`" % self.db_name) self.db_exists = False def create_table(self, model_class: type[MODEL]) -> None: @@ -171,7 +191,7 @@ class Database: """ if model_class.is_system_model(): raise DatabaseException("You can't create system table") - if getattr(model_class, 'engine') is None: + if getattr(model_class, "engine") is None: raise DatabaseException("%s class must define an engine" % model_class.__name__) self._send(model_class.create_table_sql(self)) @@ -190,13 +210,9 @@ class Database: """ sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'" r = self._send(sql % (self.db_name, model_class.table_name())) - return r.text.strip() == '1' + return r.text.strip() == "1" - def get_model_for_table( - self, - table_name: str, - system_table: bool = False - ): + def get_model_for_table(self, table_name: str, system_table: bool = False): """ Generates a model class from an existing table in the database. This can be used for querying tables which don't have a corresponding model class, @@ -205,7 +221,7 @@ class Database: - `table_name`: the table to create a model for - `system_table`: whether the table is a system table, or belongs to the current database """ - db_name = 'system' if system_table else self.db_name + db_name = "system" if system_table else self.db_name sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name) lines = self._send(sql).iter_lines() fields = [parse_tsv(line)[:2] for line in lines] @@ -222,7 +238,7 @@ class Database: The name must be string, and the value is converted to string in case it isn't. To remove a setting, pass `None` as the value. """ - assert isinstance(name, str), 'Setting name must be a string' + assert isinstance(name, str), "Setting name must be a string" if value is None: self.settings.pop(name, None) else: @@ -246,14 +262,13 @@ class Database: if first_instance.is_read_only() or first_instance.is_system_model(): raise DatabaseException("You can't insert into read only and system tables") - fields_list = ','.join( - ['`%s`' % name for name in first_instance.fields(writable=True)]) - fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated' - query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt) + fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)]) + fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated" + query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt) def gen(): buf = BytesIO() - buf.write(self._substitute(query, model_class).encode('utf-8')) + buf.write(self._substitute(query, model_class).encode("utf-8")) first_instance.set_database(self) buf.write(first_instance.to_db_string()) # Collect lines in batches of batch_size @@ -271,12 +286,11 @@ class Database: # Return any remaining lines in partial batch if lines: yield buf.getvalue() + self._send(gen()) def count( - self, - model_class: Optional[type[MODEL]], - conditions: Optional[Union[str, 'Q']] = None + self, model_class: Optional[type[MODEL]], conditions: Optional[Union[str, "Q"]] = None ) -> int: """ Counts the number of records in the model's table. 
@@ -286,20 +300,17 @@ class Database: """ from clickhouse_orm.query import Q - query = 'SELECT count() FROM $table' + query = "SELECT count() FROM $table" if conditions: if isinstance(conditions, Q): conditions = conditions.to_sql(model_class) - query += ' WHERE ' + str(conditions) + query += " WHERE " + str(conditions) query = self._substitute(query, model_class) r = self._send(query) return int(r.text) if r.text else 0 def select( - self, - query: str, - model_class: Optional[type[MODEL]] = None, - settings: Optional[dict] = None + self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None ) -> Generator[MODEL, None, None]: """ Performs a query and returns a generator of model instances. @@ -309,7 +320,7 @@ class Database: or `None` for getting back instances of an ad-hoc model. - `settings`: query settings to send as HTTP GET parameters """ - query += ' FORMAT TabSeparatedWithNamesAndTypes' + query += " FORMAT TabSeparatedWithNamesAndTypes" query = self._substitute(query, model_class) r = self._send(query, settings, True) try: @@ -345,7 +356,7 @@ class Database: page_num: int = 1, page_size: int = 100, conditions=None, - settings: Optional[dict] = None + settings: Optional[dict] = None, ): """ Selects records and returns a single page of model instances. @@ -362,27 +373,28 @@ class Database: `pages_total`, `number` (of the current page), and `page_size`. """ from clickhouse_orm.query import Q + count = self.count(model_class, conditions) pages_total = int(ceil(count / float(page_size))) if page_num == -1: page_num = max(pages_total, 1) elif page_num < 1: - raise ValueError('Invalid page number: %d' % page_num) + raise ValueError("Invalid page number: %d" % page_num) offset = (page_num - 1) * page_size - query = 'SELECT * FROM $table' + query = "SELECT * FROM $table" if conditions: if isinstance(conditions, Q): conditions = conditions.to_sql(model_class) - query += ' WHERE ' + str(conditions) - query += ' ORDER BY %s' % order_by - query += ' LIMIT %d, %d' % (offset, page_size) + query += " WHERE " + str(conditions) + query += " ORDER BY %s" % order_by + query += " LIMIT %d, %d" % (offset, page_size) query = self._substitute(query, model_class) return Page( objects=list(self.select(query, model_class, settings)) if count else [], number_of_objects=count, pages_total=pages_total, number=page_num, - page_size=page_size + page_size=page_size, ) def migrate(self, migrations_package_name, up_to=9999): @@ -395,19 +407,23 @@ class Database: """ from .migrations import MigrationHistory # pylint: disable=C0415 - logger = logging.getLogger('migrations') + logger = logging.getLogger("migrations") applied_migrations = self._get_applied_migrations(migrations_package_name) modules = import_submodules(migrations_package_name) unapplied_migrations = set(modules.keys()) - applied_migrations for name in sorted(unapplied_migrations): - logger.info('Applying migration %s...', name) + logger.info("Applying migration %s...", name) for operation in modules[name].operations: operation.apply(self) - self.insert([MigrationHistory( - package_name=migrations_package_name, - module_name=name, - applied=datetime.date.today()) - ]) + self.insert( + [ + MigrationHistory( + package_name=migrations_package_name, + module_name=name, + applied=datetime.date.today(), + ) + ] + ) if int(name[:4]) >= up_to: break @@ -432,19 +448,14 @@ class Database: query = self._substitute(query, MigrationHistory) return set(obj.module_name for obj in self.select(query)) - def _send( - self, - data: str | 
bytes | Generator, - settings: dict = None, - stream: bool = False - ): + def _send(self, data: str | bytes | Generator, settings: dict = None, stream: bool = False): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") if self.log_statements: logger.info(data) params = self._build_params(settings) request = self.request_session.build_request( - method='POST', url=self.db_url, content=data, params=params + method="POST", url=self.db_url, content=data, params=params ) r = self.request_session.send(request, stream=stream) if isinstance(r, httpx.Response) and r.status_code != 200: @@ -457,52 +468,52 @@ class Database: params.update(self.settings) params.update(self._context_params) if self.db_exists: - params['database'] = self.db_name + params["database"] = self.db_name # Send the readonly flag, unless the connection is already readonly (to prevent db error) if self.readonly and not self.connection_readonly: - params['readonly'] = '1' + params["readonly"] = "1" return params def _substitute(self, query, model_class=None): """ Replaces $db and $table placeholders in the query. """ - if '$' in query: + if "$" in query: mapping = dict(db="`%s`" % self.db_name) if model_class: if model_class.is_system_model(): - mapping['table'] = "`system`.`%s`" % model_class.table_name() + mapping["table"] = "`system`.`%s`" % model_class.table_name() elif model_class.is_temporary_model(): - mapping['table'] = "`%s`" % model_class.table_name() + mapping["table"] = "`%s`" % model_class.table_name() else: - mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name()) + mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name()) query = Template(query).safe_substitute(mapping) return query def _get_server_timezone(self): try: - r = self._send('SELECT timezone()') + r = self._send("SELECT timezone()") return pytz.timezone(r.text.strip()) except ServerError as err: - logger.exception('Cannot determine server timezone (%s), assuming UTC', err) + logger.exception("Cannot determine server timezone (%s), assuming UTC", err) return pytz.utc def _get_server_version(self, as_tuple=True): try: - r = self._send('SELECT version();') + r = self._send("SELECT version();") ver = r.text except ServerError as err: - logger.exception('Cannot determine server version (%s), assuming 1.1.0', err) - ver = '1.1.0' - return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver + logger.exception("Cannot determine server version (%s), assuming 1.1.0", err) + ver = "1.1.0" + return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver def _is_existing_database(self): r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name) - return r.text.strip() == '1' + return r.text.strip() == "1" def _is_connection_readonly(self): r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'") - return r.text.strip() != '0' + return r.text.strip() != "0" # Expose only relevant classes in import * diff --git a/src/clickhouse_orm/engines.py b/src/clickhouse_orm/engines.py index 7996ca5..21ba6b4 100644 --- a/src/clickhouse_orm/engines.py +++ b/src/clickhouse_orm/engines.py @@ -11,35 +11,30 @@ if TYPE_CHECKING: from clickhouse_orm.models import Model from clickhouse_orm.funcs import F -logger = logging.getLogger('clickhouse_orm') +logger = logging.getLogger("clickhouse_orm") class Engine: - def create_table_sql(self, db: Database) -> str: - raise NotImplementedError() # pragma: no cover + raise NotImplementedError() 
# pragma: no cover class TinyLog(Engine): - def create_table_sql(self, db): - return 'TinyLog' + return "TinyLog" class Log(Engine): - def create_table_sql(self, db): - return 'Log' + return "Log" class Memory(Engine): - def create_table_sql(self, db): - return 'Memory' + return "Memory" class MergeTree(Engine): - def __init__( self, date_col: Optional[str] = None, @@ -49,22 +44,27 @@ class MergeTree(Engine): replica_table_path: Optional[str] = None, replica_name: Optional[str] = None, partition_key: Optional[Union[list, tuple]] = None, - primary_key: Optional[Union[list, tuple]] = None + primary_key: Optional[Union[list, tuple]] = None, ): - assert type(order_by) in (list, tuple), 'order_by must be a list or tuple' - assert date_col is None or isinstance(date_col, str), 'date_col must be string if present' - assert primary_key is None or type(primary_key) in (list, tuple), \ - 'primary_key must be a list or tuple' - assert partition_key is None or type(partition_key) in (list, tuple),\ - 'partition_key must be tuple or list if present' - assert (replica_table_path is None) == (replica_name is None), \ - 'both replica_table_path and replica_name must be specified' + assert type(order_by) in (list, tuple), "order_by must be a list or tuple" + assert date_col is None or isinstance(date_col, str), "date_col must be string if present" + assert primary_key is None or type(primary_key) in ( + list, + tuple, + ), "primary_key must be a list or tuple" + assert partition_key is None or type(partition_key) in ( + list, + tuple, + ), "partition_key must be tuple or list if present" + assert (replica_table_path is None) == ( + replica_name is None + ), "both replica_table_path and replica_name must be specified" # These values conflict with each other (old and new syntax of table engines. # So let's control only one of them is given. assert date_col or partition_key, "You must set either date_col or partition_key" self.date_col = date_col - self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,) + self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,) self.primary_key = primary_key self.order_by = order_by @@ -76,28 +76,33 @@ class MergeTree(Engine): # I changed field name for new reality and syntax @property def key_cols(self): - logger.warning('`key_cols` attribute is deprecated and may be removed in future. ' - 'Use `order_by` attribute instead') + logger.warning( + "`key_cols` attribute is deprecated and may be removed in future. " + "Use `order_by` attribute instead" + ) return self.order_by @key_cols.setter def key_cols(self, value): - logger.warning('`key_cols` attribute is deprecated and may be removed in future. ' - 'Use `order_by` attribute instead') + logger.warning( + "`key_cols` attribute is deprecated and may be removed in future. 
" + "Use `order_by` attribute instead" + ) self.order_by = value def create_table_sql(self, db: Database) -> str: name = self.__class__.__name__ if self.replica_name: - name = 'Replicated' + name + name = "Replicated" + name # In ClickHouse 1.1.54310 custom partitioning key was introduced # https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/ # Let's check version and use new syntax if available if db.server_version >= (1, 1, 54310): - partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \ - % (comma_join(self.partition_key, stringify=True), - comma_join(self.order_by, stringify=True)) + partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % ( + comma_join(self.partition_key, stringify=True), + comma_join(self.order_by, stringify=True), + ) if self.primary_key: partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True) @@ -110,16 +115,17 @@ class MergeTree(Engine): elif not self.date_col: # Can't import it globally due to circular import from clickhouse_orm.database import DatabaseException + raise DatabaseException( "Custom partitioning is not supported before ClickHouse 1.1.54310. " "Please update your server or use date_col syntax." "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/" ) else: - partition_sql = '' + partition_sql = "" params = self._build_sql_params(db) - return '%s(%s) %s' % (name, comma_join(params), partition_sql) + return "%s(%s) %s" % (name, comma_join(params), partition_sql) def _build_sql_params(self, db: Database) -> list[str]: params = [] @@ -134,22 +140,34 @@ class MergeTree(Engine): params.append(self.date_col) if self.sampling_expr: params.append(self.sampling_expr) - params.append('(%s)' % comma_join(self.order_by, stringify=True)) + params.append("(%s)" % comma_join(self.order_by, stringify=True)) params.append(str(self.index_granularity)) return params class CollapsingMergeTree(MergeTree): - def __init__( - self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None, - index_granularity=8192, replica_table_path=None, replica_name=None, - partition_key=None, primary_key=None + self, + date_col=None, + order_by=(), + sign_col="sign", + sampling_expr=None, + index_granularity=8192, + replica_table_path=None, + replica_name=None, + partition_key=None, + primary_key=None, ): super(CollapsingMergeTree, self).__init__( - date_col, order_by, sampling_expr, index_granularity, - replica_table_path, replica_name, partition_key, primary_key + date_col, + order_by, + sampling_expr, + index_granularity, + replica_table_path, + replica_name, + partition_key, + primary_key, ) self.sign_col = sign_col @@ -160,37 +178,63 @@ class CollapsingMergeTree(MergeTree): class SummingMergeTree(MergeTree): - def __init__( - self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None, - index_granularity=8192, replica_table_path=None, replica_name=None, - partition_key=None, primary_key=None + self, + date_col=None, + order_by=(), + summing_cols=None, + sampling_expr=None, + index_granularity=8192, + replica_table_path=None, + replica_name=None, + partition_key=None, + primary_key=None, ): super(SummingMergeTree, self).__init__( - date_col, order_by, sampling_expr, index_granularity, - replica_table_path, replica_name, partition_key, primary_key + date_col, + order_by, + sampling_expr, + index_granularity, + replica_table_path, + replica_name, + partition_key, + primary_key, ) - assert type is None or type(summing_cols) in (list, tuple), \ - 'summing_cols must be a list or tuple' + assert type is None or 
type(summing_cols) in ( + list, + tuple, + ), "summing_cols must be a list or tuple" self.summing_cols = summing_cols def _build_sql_params(self, db: Database) -> list[str]: params = super(SummingMergeTree, self)._build_sql_params(db) if self.summing_cols: - params.append('(%s)' % comma_join(self.summing_cols)) + params.append("(%s)" % comma_join(self.summing_cols)) return params class ReplacingMergeTree(MergeTree): - def __init__( - self, date_col=None, order_by=(), ver_col=None, sampling_expr=None, - index_granularity=8192, replica_table_path=None, replica_name=None, - partition_key=None, primary_key=None + self, + date_col=None, + order_by=(), + ver_col=None, + sampling_expr=None, + index_granularity=8192, + replica_table_path=None, + replica_name=None, + partition_key=None, + primary_key=None, ): super(ReplacingMergeTree, self).__init__( - date_col, order_by, sampling_expr, index_granularity, - replica_table_path, replica_name, partition_key, primary_key + date_col, + order_by, + sampling_expr, + index_granularity, + replica_table_path, + replica_name, + partition_key, + primary_key, ) self.ver_col = ver_col @@ -217,7 +261,7 @@ class Buffer(Engine): min_rows: int = 10000, max_rows: int = 1000000, min_bytes: int = 10000000, - max_bytes: int = 100000000 + max_bytes: int = 100000000, ): self.main_model = main_model self.num_layers = num_layers @@ -231,11 +275,17 @@ class Buffer(Engine): def create_table_sql(self, db: Database) -> str: # Overriden create_table_sql example: # sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)' - sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % ( - db.db_name, self.main_model.table_name(), self.num_layers, - self.min_time, self.max_time, self.min_rows, - self.max_rows, self.min_bytes, self.max_bytes - ) + sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % ( + db.db_name, + self.main_model.table_name(), + self.num_layers, + self.min_time, + self.max_time, + self.min_rows, + self.max_rows, + self.min_bytes, + self.max_bytes, + ) return sql @@ -265,6 +315,7 @@ class Distributed(Engine): See full documentation here https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/ """ + def __init__(self, cluster, table=None, sharding_key=None): """ - `cluster`: what cluster to access data from @@ -292,12 +343,15 @@ class Distributed(Engine): def create_table_sql(self, db: Database) -> str: name = self.__class__.__name__ params = self._build_sql_params(db) - return '%s(%s)' % (name, ', '.join(params)) + return "%s(%s)" % (name, ", ".join(params)) def _build_sql_params(self, db: Database) -> list[str]: if self.table_name is None: - raise ValueError("Cannot create {} engine: specify an underlying table".format( - self.__class__.__name__)) + raise ValueError( + "Cannot create {} engine: specify an underlying table".format( + self.__class__.__name__ + ) + ) params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]] if self.sharding_key: diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py index 0377965..b0025e6 100644 --- a/src/clickhouse_orm/fields.py +++ b/src/clickhouse_orm/fields.py @@ -21,13 +21,14 @@ if TYPE_CHECKING: from clickhouse_orm.models import Model from clickhouse_orm.database import Database -logger = getLogger('clickhouse_orm') +logger = getLogger("clickhouse_orm") class Field(FunctionOperatorsMixin): """ Abstract base class for all field types. 
""" + name: str = None # this is set by the parent model parent: type["Model"] = None # this is set by the parent model creation_counter: int = 0 # used for keeping the model fields ordered @@ -41,21 +42,29 @@ class Field(FunctionOperatorsMixin): materialized: Optional[Union[F, str]] = None, readonly: bool = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): - assert [default, alias, materialized].count(None) >= 2, \ - "Only one of default, alias and materialized parameters can be given" - assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \ - "Alias parameter must be a string or function object, if given" - assert (materialized is None or isinstance(materialized, F) or - isinstance(materialized, str) and materialized != ""), \ - "Materialized parameter must be a string or function object, if given" - assert readonly is None or type( - readonly) is bool, "readonly parameter must be bool if given" - assert codec is None or isinstance(codec, str) and codec != "", \ - "Codec field must be string, if given" - assert db_column is None or isinstance(db_column, str) and db_column != "", \ - "db_column field must be string, if given" + assert [default, alias, materialized].count( + None + ) >= 2, "Only one of default, alias and materialized parameters can be given" + assert ( + alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "" + ), "Alias parameter must be a string or function object, if given" + assert ( + materialized is None + or isinstance(materialized, F) + or isinstance(materialized, str) + and materialized != "" + ), "Materialized parameter must be a string or function object, if given" + assert ( + readonly is None or type(readonly) is bool + ), "readonly parameter must be bool if given" + assert ( + codec is None or isinstance(codec, str) and codec != "" + ), "Codec field must be string, if given" + assert ( + db_column is None or isinstance(db_column, str) and db_column != "" + ), "db_column field must be string, if given" self.creation_counter = Field.creation_counter Field.creation_counter += 1 @@ -70,7 +79,7 @@ class Field(FunctionOperatorsMixin): return self.name def __repr__(self): - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def to_python(self, value, timezone_in_use): """ @@ -92,9 +101,10 @@ class Field(FunctionOperatorsMixin): Utility method to check that the given value is between min_value and max_value. 
""" if value < min_value or value > max_value: - raise ValueError('%s out of range - %s is not between %s and %s' % ( - self.__class__.__name__, value, min_value, max_value - )) + raise ValueError( + "%s out of range - %s is not between %s and %s" + % (self.__class__.__name__, value, min_value, max_value) + ) def to_db_string(self, value, quote=True): """ @@ -114,7 +124,7 @@ class Field(FunctionOperatorsMixin): sql = self.db_type args = self.get_db_type_args() if args: - sql += '(%s)' % comma_join(args) + sql += "(%s)" % comma_join(args) if with_default_expression: sql += self._extra_params(db) return sql @@ -124,18 +134,18 @@ class Field(FunctionOperatorsMixin): return [] def _extra_params(self, db: Database) -> str: - sql = '' + sql = "" if self.alias: - sql += ' ALIAS %s' % string_or_func(self.alias) + sql += " ALIAS %s" % string_or_func(self.alias) elif self.materialized: - sql += ' MATERIALIZED %s' % string_or_func(self.materialized) + sql += " MATERIALIZED %s" % string_or_func(self.materialized) elif isinstance(self.default, F): - sql += ' DEFAULT %s' % self.default.to_sql() + sql += " DEFAULT %s" % self.default.to_sql() elif self.default: default = self.to_db_string(self.default) - sql += ' DEFAULT %s' % default + sql += " DEFAULT %s" % default if self.codec and db and db.has_codec_support and not self.alias: - sql += ' CODEC(%s)' % self.codec + sql += " CODEC(%s)" % self.codec return sql def isinstance(self, types) -> bool: @@ -149,28 +159,27 @@ class Field(FunctionOperatorsMixin): """ if isinstance(self, types): return True - inner_field = getattr(self, 'inner_field', None) + inner_field = getattr(self, "inner_field", None) while inner_field: if isinstance(inner_field, types): return True - inner_field = getattr(inner_field, 'inner_field', None) + inner_field = getattr(inner_field, "inner_field", None) return False class StringField(Field): - class_default = '' - db_type = 'String' + class_default = "" + db_type = "String" def to_python(self, value, timezone_in_use) -> str: if isinstance(value, str): return value if isinstance(value, bytes): - return value.decode('UTF-8') - raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value)) + return value.decode("UTF-8") + raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value)) class FixedStringField(StringField): - def __init__( self, length: int, @@ -178,22 +187,22 @@ class FixedStringField(StringField): alias: Optional[Union[F, str]] = None, materialized: Optional[Union[F, str]] = None, readonly: Optional[bool] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): self._length = length - self.db_type = 'FixedString(%d)' % length + self.db_type = "FixedString(%d)" % length super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column) def to_python(self, value, timezone_in_use) -> str: value = super(FixedStringField, self).to_python(value, timezone_in_use) - return value.rstrip('\0') + return value.rstrip("\0") def validate(self, value): if isinstance(value, str): - value = value.encode('UTF-8') + value = value.encode("UTF-8") if len(value) > self._length: raise ValueError( - f'Value of {len(value)} bytes is too long for FixedStringField({self._length})' + f"Value of {len(value)} bytes is too long for FixedStringField({self._length})" ) @@ -201,7 +210,7 @@ class DateField(Field): min_value = datetime.date(1970, 1, 1) max_value = datetime.date(2105, 12, 31) class_default = min_value - db_type = 'Date' + db_type = "Date" def to_python(self, 
         if isinstance(value, datetime.datetime):
@@ -211,10 +220,10 @@ class DateField(Field):
         if isinstance(value, int):
             return DateField.class_default + datetime.timedelta(days=value)
         if isinstance(value, str):
-            if value == '0000-00-00':
+            if value == "0000-00-00":
                 return DateField.min_value
-            return datetime.datetime.strptime(value, '%Y-%m-%d').date()
-        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
+            return datetime.datetime.strptime(value, "%Y-%m-%d").date()
+        raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

     def validate(self, value):
         self._range_check(value, DateField.min_value, DateField.max_value)
@@ -225,7 +234,7 @@

 class DateTimeField(Field):
     class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
-    db_type = 'DateTime'
+    db_type = "DateTime"

     def __init__(
         self,
@@ -235,7 +244,7 @@ class DateTimeField(Field):
         readonly: bool = None,
         codec: Optional[str] = None,
         db_column: Optional[str] = None,
-        timezone: Optional[Union[BaseTzInfo, str]] = None
+        timezone: Optional[Union[BaseTzInfo, str]] = None,
     ):
         super().__init__(default, alias, materialized, readonly, codec, db_column)
         # assert not timezone, 'Temporarily field timezone is not supported'
@@ -257,7 +266,7 @@ class DateTimeField(Field):
         if isinstance(value, int):
             return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
         if isinstance(value, str):
-            if value == '0000-00-00 00:00:00':
+            if value == "0000-00-00 00:00:00":
                 return self.class_default
             if len(value) == 10:
                 try:
@@ -275,14 +284,14 @@ class DateTimeField(Field):
         if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
             dt = timezone_in_use.localize(dt)
         return dt
-        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
+        raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

     def to_db_string(self, value, quote=True) -> str:
-        return escape('%010d' % timegm(value.utctimetuple()), quote)
+        return escape("%010d" % timegm(value.utctimetuple()), quote)


 class DateTime64Field(DateTimeField):
-    db_type = 'DateTime64'
+    db_type = "DateTime64"

     """
@@ -303,10 +312,10 @@ class DateTime64Field(DateTimeField):
         codec: Optional[str] = None,
         db_column: Optional[str] = None,
         timezone: Optional[Union[BaseTzInfo, str]] = None,
-        precision: int = 6
+        precision: int = 6,
     ):
         super().__init__(default, alias, materialized, readonly, codec, db_column, timezone)
-        assert precision is None or isinstance(precision, int), 'Precision must be int type'
+        assert precision is None or isinstance(precision, int), "Precision must be int type"
         self.precision = precision

     def get_db_type_args(self):
@@ -322,11 +331,10 @@ class DateTime64Field(DateTimeField):
         Returns a string in 0000000000.000000 format, where the number of fractional
         digits equals the field's precision.
         """
         return escape(
-            '{timestamp:0{width}.{precision}f}'.format(
-                timestamp=value.timestamp(),
-                width=11 + self.precision,
-                precision=self.precision),
-            quote
+            "{timestamp:0{width}.{precision}f}".format(
+                timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
+            ),
+            quote,
         )
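The zero-padded timestamp format above is easy to sanity-check in isolation; a standalone sketch using only the standard library and pytz (the datetime value and precision are illustrative):

import datetime
import pytz

precision = 6
value = datetime.datetime(2022, 6, 4, 13, 25, 34, 123456, tzinfo=pytz.utc)
formatted = "{timestamp:0{width}.{precision}f}".format(
    timestamp=value.timestamp(), width=11 + precision, precision=precision
)
# Ten integer digits, a dot, then `precision` fractional digits:
print(formatted)  # 1654349134.123456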
left_part == "0000-00-00 00:00:00": return self.class_default if len(left_part) == 10: try: @@ -357,7 +365,7 @@ class BaseIntField(Field): try: return int(value) except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True) -> str: # There's no need to call escape since numbers do not contain @@ -370,50 +378,50 @@ class BaseIntField(Field): class UInt8Field(BaseIntField): min_value = 0 - max_value = 2 ** 8 - 1 - db_type = 'UInt8' + max_value = 2**8 - 1 + db_type = "UInt8" class UInt16Field(BaseIntField): min_value = 0 - max_value = 2 ** 16 - 1 - db_type = 'UInt16' + max_value = 2**16 - 1 + db_type = "UInt16" class UInt32Field(BaseIntField): min_value = 0 - max_value = 2 ** 32 - 1 - db_type = 'UInt32' + max_value = 2**32 - 1 + db_type = "UInt32" class UInt64Field(BaseIntField): min_value = 0 - max_value = 2 ** 64 - 1 - db_type = 'UInt64' + max_value = 2**64 - 1 + db_type = "UInt64" class Int8Field(BaseIntField): - min_value = -2 ** 7 - max_value = 2 ** 7 - 1 - db_type = 'Int8' + min_value = -(2**7) + max_value = 2**7 - 1 + db_type = "Int8" class Int16Field(BaseIntField): - min_value = -2 ** 15 - max_value = 2 ** 15 - 1 - db_type = 'Int16' + min_value = -(2**15) + max_value = 2**15 - 1 + db_type = "Int16" class Int32Field(BaseIntField): - min_value = -2 ** 31 - max_value = 2 ** 31 - 1 - db_type = 'Int32' + min_value = -(2**31) + max_value = 2**31 - 1 + db_type = "Int32" class Int64Field(BaseIntField): - min_value = -2 ** 63 - max_value = 2 ** 63 - 1 - db_type = 'Int64' + min_value = -(2**63) + max_value = 2**63 - 1 + db_type = "Int64" class BaseFloatField(Field): @@ -425,7 +433,7 @@ class BaseFloatField(Field): try: return float(value) except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) def to_db_string(self, value, quote=True) -> str: # There's no need to call escape since numbers do not contain @@ -434,11 +442,11 @@ class BaseFloatField(Field): class Float32Field(BaseFloatField): - db_type = 'Float32' + db_type = "Float32" class Float64Field(BaseFloatField): - db_type = 'Float64' + db_type = "Float64" class DecimalField(Field): @@ -454,13 +462,13 @@ class DecimalField(Field): alias: Optional[Union[F, str]] = None, materialized: Optional[Union[F, str]] = None, readonly: bool = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): - assert 1 <= precision <= 38, 'Precision must be between 1 and 38' - assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' + assert 1 <= precision <= 38, "Precision must be between 1 and 38" + assert 0 <= scale <= precision, "Scale must be between 0 and the given precision" self.precision = precision self.scale = scale - self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale) + self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale) with localcontext() as ctx: ctx.prec = 38 self.exp = Decimal(10) ** -self.scale # for rounding to the required scale @@ -473,9 +481,9 @@ class DecimalField(Field): try: value = Decimal(value) except: - raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value)) if not value.is_finite(): - raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) + raise 
ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value)) return self._round(value) def to_db_string(self, value, quote=True) -> str: @@ -498,14 +506,13 @@ class Decimal32Field(DecimalField): alias: Optional[Union[F, str]] = None, materialized: Optional[Union[F, str]] = None, readonly: bool = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): super().__init__(9, scale, default, alias, materialized, readonly, db_column) - self.db_type = 'Decimal32(%d)' % scale + self.db_type = "Decimal32(%d)" % scale class Decimal64Field(DecimalField): - def __init__( self, scale: int, @@ -513,14 +520,13 @@ class Decimal64Field(DecimalField): alias: Optional[Union[F, str]] = None, materialized: Optional[Union[F, str]] = None, readonly: bool = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): super().__init__(18, scale, default, alias, materialized, readonly, db_column) - self.db_type = 'Decimal64(%d)' % scale + self.db_type = "Decimal64(%d)" % scale class Decimal128Field(DecimalField): - def __init__( self, scale: int, @@ -528,10 +534,10 @@ class Decimal128Field(DecimalField): alias: Optional[Union[F, str]] = None, materialized: Optional[Union[F, str]] = None, readonly: bool = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): super().__init__(38, scale, default, alias, materialized, readonly, db_column) - self.db_type = 'Decimal128(%d)' % scale + self.db_type = "Decimal128(%d)" % scale class BaseEnumField(Field): @@ -547,7 +553,7 @@ class BaseEnumField(Field): materialized: Optional[Union[F, str]] = None, readonly: bool = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): self.enum_cls = enum_cls if default is None: @@ -564,7 +570,7 @@ class BaseEnumField(Field): except Exception: return self.enum_cls(value) if isinstance(value, bytes): - decoded = value.decode('UTF-8') + decoded = value.decode("UTF-8") try: return self.enum_cls[decoded] except Exception: @@ -573,13 +579,13 @@ class BaseEnumField(Field): return self.enum_cls(value) except (KeyError, ValueError): pass - raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value)) + raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value)) def to_db_string(self, value, quote=True) -> str: return escape(value.name, quote) def get_db_type_args(self): - return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls] + return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls] @classmethod def create_ad_hoc_field(cls, db_type) -> BaseEnumField: @@ -590,17 +596,17 @@ class BaseEnumField(Field): members = {} for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type): members[match.group(1)] = int(match.group(2)) - enum_cls = Enum('AdHocEnum', members) - field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field + enum_cls = Enum("AdHocEnum", members) + field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field return field_class(enum_cls) class Enum8Field(BaseEnumField): - db_type = 'Enum8' + db_type = "Enum8" class Enum16Field(BaseEnumField): - db_type = 'Enum16' + db_type = "Enum16" class ArrayField(Field): @@ -614,12 +620,14 @@ class ArrayField(Field): materialized: Optional[Union[F, str]] = None, readonly: bool = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): - assert isinstance(inner_field, Field), \ - "The first argument of ArrayField must be a 
Field instance" - assert not isinstance(inner_field, ArrayField), \ - "Multidimensional array fields are not supported by the ORM" + assert isinstance( + inner_field, Field + ), "The first argument of ArrayField must be a Field instance" + assert not isinstance( + inner_field, ArrayField + ), "Multidimensional array fields are not supported by the ORM" self.inner_field = inner_field super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column) @@ -627,9 +635,9 @@ class ArrayField(Field): if isinstance(value, str): value = parse_array(value) elif isinstance(value, bytes): - value = parse_array(value.decode('UTF-8')) + value = parse_array(value.decode("UTF-8")) elif not isinstance(value, (list, tuple)): - raise ValueError('ArrayField expects list or tuple, not %s' % type(value)) + raise ValueError("ArrayField expects list or tuple, not %s" % type(value)) return [self.inner_field.to_python(v, timezone_in_use) for v in value] def validate(self, value): @@ -638,12 +646,12 @@ class ArrayField(Field): def to_db_string(self, value, quote=True) -> str: array = [self.inner_field.to_db_string(v, quote=True) for v in value] - return '[' + comma_join(array) + ']' + return "[" + comma_join(array) + "]" def get_sql(self, with_default_expression=True, db=None) -> str: - sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) + sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression and self.codec and db and db.has_codec_support: - sql += ' CODEC(%s)' % self.codec + sql += " CODEC(%s)" % self.codec return sql @@ -658,17 +666,19 @@ class TupleField(Field): materialized: Optional[Union[F, str]] = None, readonly: bool = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): self.names = {} self.inner_fields = [] for (name, field) in name_fields: if name in self.names: - raise ValueError('The Field name conflict') - assert isinstance(field, Field), \ - "The first argument of TupleField must be a Field instance" - assert not isinstance(field, (ArrayField, TupleField)), \ - "Multidimensional array fields are not supported by the ORM" + raise ValueError("The Field name conflict") + assert isinstance( + field, Field + ), "The first argument of TupleField must be a Field instance" + assert not isinstance( + field, (ArrayField, TupleField) + ), "Multidimensional array fields are not supported by the ORM" self.names[name] = field self.inner_fields.append(field) self.class_default = tuple(field.class_default for field in self.inner_fields) @@ -677,16 +687,19 @@ class TupleField(Field): def to_python(self, value, timezone_in_use) -> tuple: if isinstance(value, str): value = parse_array(value) - value = (self.inner_fields[i].to_python(v, timezone_in_use) - for i, v in enumerate(value)) + value = ( + self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value) + ) elif isinstance(value, bytes): - value = parse_array(value.decode('UTF-8')) - value = (self.inner_fields[i].to_python(v, timezone_in_use) - for i, v in enumerate(value)) + value = parse_array(value.decode("UTF-8")) + value = ( + self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value) + ) elif not isinstance(value, (list, tuple)): - raise ValueError('TupleField expects list or tuple, not %s' % type(value)) - return tuple(self.inner_fields[i].to_python(v, timezone_in_use) - for i, v in enumerate(value)) + raise ValueError("TupleField expects list or tuple, not 
%s" % type(value)) + return tuple( + self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value) + ) def validate(self, value): for i, v in enumerate(value): @@ -694,21 +707,22 @@ class TupleField(Field): def to_db_string(self, value, quote=True) -> str: array = [self.inner_fields[i].to_db_string(v, quote=True) for i, v in enumerate(value)] - return '(' + comma_join(array) + ')' + return "(" + comma_join(array) + ")" def get_sql(self, with_default_expression=True, db=None) -> str: - inner_sql = ', '.join('%s %s' % (name, field.get_sql(False)) - for name, field in self.names.items()) + inner_sql = ", ".join( + "%s %s" % (name, field.get_sql(False)) for name, field in self.names.items() + ) - sql = 'Tuple(%s)' % inner_sql + sql = "Tuple(%s)" % inner_sql if with_default_expression and self.codec and db and db.has_codec_support: - sql += ' CODEC(%s)' % self.codec + sql += " CODEC(%s)" % self.codec return sql class UUIDField(Field): class_default = UUID(int=0) - db_type = 'UUID' + db_type = "UUID" def to_python(self, value, timezone_in_use) -> UUID: if isinstance(value, UUID): @@ -722,7 +736,7 @@ class UUIDField(Field): elif isinstance(value, tuple): return UUID(fields=value) else: - raise ValueError('Invalid value for UUIDField: %r' % value) + raise ValueError("Invalid value for UUIDField: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -730,7 +744,7 @@ class UUIDField(Field): class IPv4Field(Field): class_default = 0 - db_type = 'IPv4' + db_type = "IPv4" def to_python(self, value, timezone_in_use) -> IPv4Address: if isinstance(value, IPv4Address): @@ -738,7 +752,7 @@ class IPv4Field(Field): elif isinstance(value, (bytes, str, int)): return IPv4Address(value) else: - raise ValueError('Invalid value for IPv4Address: %r' % value) + raise ValueError("Invalid value for IPv4Address: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -746,7 +760,7 @@ class IPv4Field(Field): class IPv6Field(Field): class_default = 0 - db_type = 'IPv6' + db_type = "IPv6" def to_python(self, value, timezone_in_use) -> IPv6Address: if isinstance(value, IPv6Address): @@ -754,7 +768,7 @@ class IPv6Field(Field): elif isinstance(value, (bytes, str, int)): return IPv6Address(value) else: - raise ValueError('Invalid value for IPv6Address: %r' % value) + raise ValueError("Invalid value for IPv6Address: %r" % value) def to_db_string(self, value, quote=True): return escape(str(value), quote) @@ -771,11 +785,13 @@ class NullableField(Field): materialized: Optional[Union[F, str]] = None, extra_null_values: Optional[Iterable] = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): - assert isinstance(inner_field, Field), \ - "The first argument of NullableField must be a Field instance." \ - " Not: {}".format(inner_field) + assert isinstance( + inner_field, Field + ), "The first argument of NullableField must be a Field instance." 
" Not: {}".format( + inner_field + ) self.inner_field = inner_field self._null_values = [None] if extra_null_values: @@ -785,7 +801,7 @@ class NullableField(Field): ) def to_python(self, value, timezone_in_use): - if value == '\\N' or value in self._null_values: + if value == "\\N" or value in self._null_values: return None return self.inner_field.to_python(value, timezone_in_use) @@ -794,18 +810,17 @@ class NullableField(Field): def to_db_string(self, value, quote=True): if value in self._null_values: - return '\\N' + return "\\N" return self.inner_field.to_db_string(value, quote=quote) def get_sql(self, with_default_expression=True, db=None): - sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) + sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression: sql += self._extra_params(db) return sql class LowCardinalityField(Field): - def __init__( self, inner_field: Field, @@ -814,16 +829,20 @@ class LowCardinalityField(Field): materialized: Optional[Union[F, str]] = None, readonly: Optional[bool] = None, codec: Optional[str] = None, - db_column: Optional[str] = None + db_column: Optional[str] = None, ): - assert isinstance(inner_field, Field), \ - "The first argument of LowCardinalityField must be a Field instance." \ - " Not: {}".format(inner_field) - assert not isinstance(inner_field, LowCardinalityField), \ - "LowCardinality inner fields are not supported by the ORM" - assert not isinstance(inner_field, ArrayField), \ - "Array field inside LowCardinality are not supported by the ORM." \ + assert isinstance( + inner_field, Field + ), "The first argument of LowCardinalityField must be a Field instance." " Not: {}".format( + inner_field + ) + assert not isinstance( + inner_field, LowCardinalityField + ), "LowCardinality inner fields are not supported by the ORM" + assert not isinstance(inner_field, ArrayField), ( + "Array field inside LowCardinality are not supported by the ORM." " Use Array(LowCardinality) instead" + ) self.inner_field = inner_field self.class_default = self.inner_field.class_default super().__init__(default, alias, materialized, readonly, codec, db_column) @@ -839,12 +858,12 @@ class LowCardinalityField(Field): def get_sql(self, with_default_expression=True, db=None): if db and db.has_low_cardinality_support: - sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False) + sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False) else: sql = self.inner_field.get_sql(with_default_expression=False) logger.warning( - 'LowCardinalityField not supported on clickhouse-server version < 19.0' - ' using {} as fallback'.format(self.inner_field.__class__.__name__) + "LowCardinalityField not supported on clickhouse-server version < 19.0" + " using {} as fallback".format(self.inner_field.__class__.__name__) ) if with_default_expression: sql += self._extra_params(db) diff --git a/src/clickhouse_orm/funcs.py b/src/clickhouse_orm/funcs.py index 0904221..1f7da39 100644 --- a/src/clickhouse_orm/funcs.py +++ b/src/clickhouse_orm/funcs.py @@ -10,11 +10,13 @@ def binary_operator(func): """ Decorates a function to mark it as a binary operator. """ + @wraps(func) def wrapper(*args, **kwargs): ret = func(*args, **kwargs) ret.is_binary_operator = True return ret + return wrapper @@ -24,10 +26,12 @@ def type_conversion(func): The metaclass automatically generates "OrZero" and "OrNull" combinators for the decorated function. 
""" + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) - wrapper.f_type = 'type_conversion' + + wrapper.f_type = "type_conversion" return wrapper @@ -37,10 +41,12 @@ def aggregate(func): The metaclass automatically generates combinators such as "OrDefault", "OrNull", "If" etc. for the decorated function. """ + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) - wrapper.f_type = 'aggregate' + + wrapper.f_type = "aggregate" return wrapper @@ -49,10 +55,12 @@ def with_utf8_support(func): Decorates a function to mark it as a string function that has a UTF8 variant. The metaclass automatically generates a "UTF8" combinator for the decorated function. """ + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) - wrapper.f_type = 'with_utf8_support' + + wrapper.f_type = "with_utf8_support" return wrapper @@ -61,6 +69,7 @@ def parametric(func): Decorates a function to convert it to a parametric function, such as `quantile(level)(expr)`. """ + @wraps(func) def wrapper(*parameters): @wraps(func) @@ -68,9 +77,11 @@ def parametric(func): f = func(*args, **kwargs) # Append the parameter to the function name parameters_str = comma_join(parameters, stringify=True) - f.name = '%s(%s)' % (f.name, parameters_str) + f.name = "%s(%s)" % (f.name, parameters_str) return f + return inner + wrapper.f_parametric = True return wrapper @@ -177,29 +188,29 @@ class FunctionOperatorsMixin: class FMeta(type): FUNCTION_COMBINATORS = { - 'type_conversion': [ - {'suffix': 'OrZero'}, - {'suffix': 'OrNull'}, + "type_conversion": [ + {"suffix": "OrZero"}, + {"suffix": "OrNull"}, ], - 'aggregate': [ - {'suffix': 'OrDefault'}, - {'suffix': 'OrNull'}, - {'suffix': 'If', 'args': ['cond']}, - {'suffix': 'OrDefaultIf', 'args': ['cond']}, - {'suffix': 'OrNullIf', 'args': ['cond']}, + "aggregate": [ + {"suffix": "OrDefault"}, + {"suffix": "OrNull"}, + {"suffix": "If", "args": ["cond"]}, + {"suffix": "OrDefaultIf", "args": ["cond"]}, + {"suffix": "OrNullIf", "args": ["cond"]}, + ], + "with_utf8_support": [ + {"suffix": "UTF8"}, ], - 'with_utf8_support': [ - {'suffix': 'UTF8'}, - ] } def __init__(cls, name, bases, dct): for name, obj in dct.items(): - if hasattr(obj, '__func__'): - f_type = getattr(obj.__func__, 'f_type', '') + if hasattr(obj, "__func__"): + f_type = getattr(obj.__func__, "f_type", "") for combinator in FMeta.FUNCTION_COMBINATORS.get(f_type, []): - new_name = name + combinator['suffix'] - FMeta._add_func(cls, obj.__func__, new_name, combinator.get('args')) + new_name = name + combinator["suffix"] + FMeta._add_func(cls, obj.__func__, new_name, combinator.get("args")) @staticmethod def _add_func(cls, base_func, new_name, extra_args): @@ -208,7 +219,7 @@ class FMeta(type): """ # Get the function's signature sig = signature(base_func) - new_sig = str(sig)[1 : -1] # omit the parentheses + new_sig = str(sig)[1:-1] # omit the parentheses args = comma_join(sig.parameters) # Add extra args if extra_args: @@ -221,11 +232,16 @@ class FMeta(type): # Get default values for args argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty) # Build the new function - new_code = compile('def {new_name}({new_sig}): return F("{new_name}", {args})'.format(**locals()), - __file__, 'exec') - new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs) + new_code = compile( + 'def {new_name}({new_sig}): return F("{new_name}", {args})'.format(**locals()), + __file__, + "exec", + ) + new_func = FunctionType( + 
code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs + ) # If base_func was parametric, new_func should be too - if getattr(base_func, 'f_parametric', False): + if getattr(base_func, "f_parametric", False): new_func = parametric(new_func) # Attach to class setattr(cls, new_name, new_func) @@ -236,6 +252,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): Represents a database function call and its arguments. It doubles as a query condition when the function returns a boolean result. """ + def __init__(self, name, *args): """ Initializer. @@ -257,116 +274,116 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): gcd(12, 300) """ if self.is_binary_operator: - prefix = '' - sep = ' ' + self.name + ' ' + prefix = "" + sep = " " + self.name + " " else: prefix = self.name - sep = ', ' + sep = ", " arg_strs = (arg_to_sql(arg) for arg in self.args if arg != NO_VALUE) - return prefix + '(' + sep.join(arg_strs) + ')' + return prefix + "(" + sep.join(arg_strs) + ")" # Arithmetic functions @staticmethod @binary_operator def plus(a, b): - return F('+', a, b) + return F("+", a, b) @staticmethod @binary_operator def minus(a, b): - return F('-', a, b) + return F("-", a, b) @staticmethod @binary_operator def multiply(a, b): - return F('*', a, b) + return F("*", a, b) @staticmethod @binary_operator def divide(a, b): - return F('/', a, b) + return F("/", a, b) @staticmethod def intDiv(a, b): - return F('intDiv', a, b) + return F("intDiv", a, b) @staticmethod def intDivOrZero(a, b): - return F('intDivOrZero', a, b) + return F("intDivOrZero", a, b) @staticmethod @binary_operator def modulo(a, b): - return F('%', a, b) + return F("%", a, b) @staticmethod def negate(a): - return F('negate', a) + return F("negate", a) @staticmethod def abs(a): - return F('abs', a) + return F("abs", a) @staticmethod def gcd(a, b): - return F('gcd', a, b) + return F("gcd", a, b) @staticmethod def lcm(a, b): - return F('lcm', a, b) + return F("lcm", a, b) # Comparison functions @staticmethod @binary_operator def equals(a, b): - return F('=', a, b) + return F("=", a, b) @staticmethod @binary_operator def notEquals(a, b): - return F('!=', a, b) + return F("!=", a, b) @staticmethod @binary_operator def less(a, b): - return F('<', a, b) + return F("<", a, b) @staticmethod @binary_operator def greater(a, b): - return F('>', a, b) + return F(">", a, b) @staticmethod @binary_operator def lessOrEquals(a, b): - return F('<=', a, b) + return F("<=", a, b) @staticmethod @binary_operator def greaterOrEquals(a, b): - return F('>=', a, b) + return F(">=", a, b) # Logical functions (should be used as python operators: & | ^ ~) @staticmethod @binary_operator def _and(a, b): - return F('AND', a, b) + return F("AND", a, b) @staticmethod @binary_operator def _or(a, b): - return F('OR', a, b) + return F("OR", a, b) @staticmethod def _xor(a, b): - return F('xor', a, b) + return F("xor", a, b) @staticmethod def _not(a): - return F('not', a) + return F("not", a) # in / not in @@ -375,1460 +392,1469 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): def _in(a, b): if is_iterable(b) and not isinstance(b, (tuple, QuerySet)): b = tuple(b) - return F('IN', a, b) + return F("IN", a, b) @staticmethod @binary_operator def _notIn(a, b): if is_iterable(b) and not isinstance(b, (tuple, QuerySet)): b = tuple(b) - return F('NOT IN', a, b) + return F("NOT IN", a, b) # Functions for working with dates and times @staticmethod def toYear(d): - return F('toYear', d) + return F("toYear", d) @staticmethod def toISOYear(d, 
timezone=NO_VALUE): - return F('toISOYear', d, timezone) + return F("toISOYear", d, timezone) @staticmethod def toQuarter(d, timezone=NO_VALUE): - return F('toQuarter', d, timezone) if timezone else F('toQuarter', d) + return F("toQuarter", d, timezone) if timezone else F("toQuarter", d) @staticmethod def toMonth(d): - return F('toMonth', d) + return F("toMonth", d) @staticmethod def toWeek(d, mode=0, timezone=NO_VALUE): - return F('toWeek', d, mode, timezone) + return F("toWeek", d, mode, timezone) @staticmethod def toISOWeek(d, timezone=NO_VALUE): - return F('toISOWeek', d, timezone) if timezone else F('toISOWeek', d) + return F("toISOWeek", d, timezone) if timezone else F("toISOWeek", d) @staticmethod def toDayOfYear(d): - return F('toDayOfYear', d) + return F("toDayOfYear", d) @staticmethod def toDayOfMonth(d): - return F('toDayOfMonth', d) + return F("toDayOfMonth", d) @staticmethod def toDayOfWeek(d): - return F('toDayOfWeek', d) + return F("toDayOfWeek", d) @staticmethod def toHour(d): - return F('toHour', d) + return F("toHour", d) @staticmethod def toMinute(d): - return F('toMinute', d) + return F("toMinute", d) @staticmethod def toSecond(d): - return F('toSecond', d) + return F("toSecond", d) @staticmethod def toMonday(d): - return F('toMonday', d) + return F("toMonday", d) @staticmethod def toStartOfMonth(d): - return F('toStartOfMonth', d) + return F("toStartOfMonth", d) @staticmethod def toStartOfQuarter(d): - return F('toStartOfQuarter', d) + return F("toStartOfQuarter", d) @staticmethod def toStartOfYear(d): - return F('toStartOfYear', d) + return F("toStartOfYear", d) @staticmethod def toStartOfISOYear(d): - return F('toStartOfISOYear', d) + return F("toStartOfISOYear", d) @staticmethod def toStartOfTenMinutes(d): - return F('toStartOfTenMinutes', d) + return F("toStartOfTenMinutes", d) @staticmethod def toStartOfWeek(d, mode=0): - return F('toStartOfWeek', d) + return F("toStartOfWeek", d) @staticmethod def toStartOfMinute(d): - return F('toStartOfMinute', d) + return F("toStartOfMinute", d) @staticmethod def toStartOfFiveMinute(d): - return F('toStartOfFiveMinute', d) + return F("toStartOfFiveMinute", d) @staticmethod def toStartOfFifteenMinutes(d): - return F('toStartOfFifteenMinutes', d) + return F("toStartOfFifteenMinutes", d) @staticmethod def toStartOfHour(d): - return F('toStartOfHour', d) + return F("toStartOfHour", d) @staticmethod def toStartOfDay(d): - return F('toStartOfDay', d) + return F("toStartOfDay", d) @staticmethod def toTime(d, timezone=NO_VALUE): - return F('toTime', d, timezone) + return F("toTime", d, timezone) @staticmethod def toTimeZone(dt, timezone): - return F('toTimeZone', dt, timezone) + return F("toTimeZone", dt, timezone) @staticmethod def toUnixTimestamp(dt, timezone=NO_VALUE): - return F('toUnixTimestamp', dt, timezone) + return F("toUnixTimestamp", dt, timezone) @staticmethod def toYYYYMM(dt, timezone=NO_VALUE): - return F('toYYYYMM', dt, timezone) if timezone else F('toYYYYMM', dt) + return F("toYYYYMM", dt, timezone) if timezone else F("toYYYYMM", dt) @staticmethod def toYYYYMMDD(dt, timezone=NO_VALUE): - return F('toYYYYMMDD', dt, timezone) if timezone else F('toYYYYMMDD', dt) + return F("toYYYYMMDD", dt, timezone) if timezone else F("toYYYYMMDD", dt) @staticmethod def toYYYYMMDDhhmmss(dt, timezone=NO_VALUE): - return F('toYYYYMMDDhhmmss', dt, timezone) if timezone else F('toYYYYMMDDhhmmss', dt) + return F("toYYYYMMDDhhmmss", dt, timezone) if timezone else F("toYYYYMMDDhhmmss", dt) @staticmethod def toRelativeYearNum(d, 
timezone=NO_VALUE): - return F('toRelativeYearNum', d, timezone) + return F("toRelativeYearNum", d, timezone) @staticmethod def toRelativeMonthNum(d, timezone=NO_VALUE): - return F('toRelativeMonthNum', d, timezone) + return F("toRelativeMonthNum", d, timezone) @staticmethod def toRelativeWeekNum(d, timezone=NO_VALUE): - return F('toRelativeWeekNum', d, timezone) + return F("toRelativeWeekNum", d, timezone) @staticmethod def toRelativeDayNum(d, timezone=NO_VALUE): - return F('toRelativeDayNum', d, timezone) + return F("toRelativeDayNum", d, timezone) @staticmethod def toRelativeHourNum(d, timezone=NO_VALUE): - return F('toRelativeHourNum', d, timezone) + return F("toRelativeHourNum", d, timezone) @staticmethod def toRelativeMinuteNum(d, timezone=NO_VALUE): - return F('toRelativeMinuteNum', d, timezone) + return F("toRelativeMinuteNum", d, timezone) @staticmethod def toRelativeSecondNum(d, timezone=NO_VALUE): - return F('toRelativeSecondNum', d, timezone) + return F("toRelativeSecondNum", d, timezone) @staticmethod def now(): - return F('now') + return F("now") @staticmethod def today(): - return F('today') + return F("today") @staticmethod def yesterday(): - return F('yesterday') + return F("yesterday") @staticmethod def timeSlot(d): - return F('timeSlot', d) + return F("timeSlot", d) @staticmethod def timeSlots(start_time, duration): - return F('timeSlots', start_time, F.toUInt32(duration)) + return F("timeSlots", start_time, F.toUInt32(duration)) @staticmethod def formatDateTime(d, format, timezone=NO_VALUE): - return F('formatDateTime', d, format, timezone) + return F("formatDateTime", d, format, timezone) @staticmethod def addDays(d, n, timezone=NO_VALUE): - return F('addDays', d, n, timezone) + return F("addDays", d, n, timezone) @staticmethod def addHours(d, n, timezone=NO_VALUE): - return F('addHours', d, n, timezone) + return F("addHours", d, n, timezone) @staticmethod def addMinutes(d, n, timezone=NO_VALUE): - return F('addMinutes', d, n, timezone) + return F("addMinutes", d, n, timezone) @staticmethod def addMonths(d, n, timezone=NO_VALUE): - return F('addMonths', d, n, timezone) + return F("addMonths", d, n, timezone) @staticmethod def addQuarters(d, n, timezone=NO_VALUE): - return F('addQuarters', d, n, timezone) + return F("addQuarters", d, n, timezone) @staticmethod def addSeconds(d, n, timezone=NO_VALUE): - return F('addSeconds', d, n, timezone) + return F("addSeconds", d, n, timezone) @staticmethod def addWeeks(d, n, timezone=NO_VALUE): - return F('addWeeks', d, n, timezone) + return F("addWeeks", d, n, timezone) @staticmethod def addYears(d, n, timezone=NO_VALUE): - return F('addYears', d, n, timezone) + return F("addYears", d, n, timezone) @staticmethod def subtractDays(d, n, timezone=NO_VALUE): - return F('subtractDays', d, n, timezone) + return F("subtractDays", d, n, timezone) @staticmethod def subtractHours(d, n, timezone=NO_VALUE): - return F('subtractHours', d, n, timezone) + return F("subtractHours", d, n, timezone) @staticmethod def subtractMinutes(d, n, timezone=NO_VALUE): - return F('subtractMinutes', d, n, timezone) + return F("subtractMinutes", d, n, timezone) @staticmethod def subtractMonths(d, n, timezone=NO_VALUE): - return F('subtractMonths', d, n, timezone) + return F("subtractMonths", d, n, timezone) @staticmethod def subtractQuarters(d, n, timezone=NO_VALUE): - return F('subtractQuarters', d, n, timezone) + return F("subtractQuarters", d, n, timezone) @staticmethod def subtractSeconds(d, n, timezone=NO_VALUE): - return F('subtractSeconds', d, n, 
timezone) + return F("subtractSeconds", d, n, timezone) @staticmethod def subtractWeeks(d, n, timezone=NO_VALUE): - return F('subtractWeeks', d, n, timezone) + return F("subtractWeeks", d, n, timezone) @staticmethod def subtractYears(d, n, timezone=NO_VALUE): - return F('subtractYears', d, n, timezone) + return F("subtractYears", d, n, timezone) @staticmethod def toIntervalSecond(number): - return F('toIntervalSecond', number) + return F("toIntervalSecond", number) @staticmethod def toIntervalMinute(number): - return F('toIntervalMinute', number) + return F("toIntervalMinute", number) @staticmethod def toIntervalHour(number): - return F('toIntervalHour', number) + return F("toIntervalHour", number) @staticmethod def toIntervalDay(number): - return F('toIntervalDay', number) + return F("toIntervalDay", number) @staticmethod def toIntervalWeek(number): - return F('toIntervalWeek', number) + return F("toIntervalWeek", number) @staticmethod def toIntervalMonth(number): - return F('toIntervalMonth', number) + return F("toIntervalMonth", number) @staticmethod def toIntervalQuarter(number): - return F('toIntervalQuarter', number) + return F("toIntervalQuarter", number) @staticmethod def toIntervalYear(number): - return F('toIntervalYear', number) - + return F("toIntervalYear", number) # Type conversion functions @staticmethod @type_conversion def toUInt8(x): - return F('toUInt8', x) + return F("toUInt8", x) @staticmethod @type_conversion def toUInt16(x): - return F('toUInt16', x) + return F("toUInt16", x) @staticmethod @type_conversion def toUInt32(x): - return F('toUInt32', x) + return F("toUInt32", x) @staticmethod @type_conversion def toUInt64(x): - return F('toUInt64', x) + return F("toUInt64", x) @staticmethod @type_conversion def toInt8(x): - return F('toInt8', x) + return F("toInt8", x) @staticmethod @type_conversion def toInt16(x): - return F('toInt16', x) + return F("toInt16", x) @staticmethod @type_conversion def toInt32(x): - return F('toInt32', x) + return F("toInt32", x) @staticmethod @type_conversion def toInt64(x): - return F('toInt64', x) + return F("toInt64", x) @staticmethod @type_conversion def toFloat32(x): - return F('toFloat32', x) + return F("toFloat32", x) @staticmethod @type_conversion def toFloat64(x): - return F('toFloat64', x) + return F("toFloat64", x) @staticmethod @type_conversion def toDecimal32(x, scale): - return F('toDecimal32', x, scale) + return F("toDecimal32", x, scale) @staticmethod @type_conversion def toDecimal64(x, scale): - return F('toDecimal64', x, scale) + return F("toDecimal64", x, scale) @staticmethod @type_conversion def toDecimal128(x, scale): - return F('toDecimal128', x, scale) + return F("toDecimal128", x, scale) @staticmethod @type_conversion def toDate(x): - return F('toDate', x) + return F("toDate", x) @staticmethod @type_conversion def toDateTime(x): - return F('toDateTime', x) + return F("toDateTime", x) @staticmethod @type_conversion def toDateTime64(x, precision, timezone=NO_VALUE): - return F('toDateTime64', x, precision, timezone) + return F("toDateTime64", x, precision, timezone) @staticmethod def toString(x): - return F('toString', x) + return F("toString", x) @staticmethod def toFixedString(s, length): - return F('toFixedString', s, length) + return F("toFixedString", s, length) @staticmethod def toStringCutToZero(s): - return F('toStringCutToZero', s) + return F("toStringCutToZero", s) @staticmethod def CAST(x, type): - return F('CAST', x, type) + return F("CAST", x, type) @staticmethod @type_conversion def 
parseDateTimeBestEffort(d, timezone=NO_VALUE): - return F('parseDateTimeBestEffort', d, timezone) + return F("parseDateTimeBestEffort", d, timezone) # Functions for working with strings @staticmethod def empty(s): - return F('empty', s) + return F("empty", s) @staticmethod def notEmpty(s): - return F('notEmpty', s) + return F("notEmpty", s) @staticmethod @with_utf8_support def length(s): - return F('length', s) + return F("length", s) @staticmethod @with_utf8_support def lower(s): - return F('lower', s) + return F("lower", s) @staticmethod @with_utf8_support def upper(s): - return F('upper', s) + return F("upper", s) @staticmethod @with_utf8_support def reverse(s): - return F('reverse', s) + return F("reverse", s) @staticmethod def concat(*args): - return F('concat', *args) + return F("concat", *args) @staticmethod @with_utf8_support def substring(s, offset, length): - return F('substring', s, offset, length) + return F("substring", s, offset, length) @staticmethod def appendTrailingCharIfAbsent(s, c): - return F('appendTrailingCharIfAbsent', s, c) + return F("appendTrailingCharIfAbsent", s, c) @staticmethod def convertCharset(s, from_charset, to_charset): - return F('convertCharset', s, from_charset, to_charset) + return F("convertCharset", s, from_charset, to_charset) @staticmethod def base64Encode(s): - return F('base64Encode', s) + return F("base64Encode", s) @staticmethod def base64Decode(s): - return F('base64Decode', s) + return F("base64Decode", s) @staticmethod def tryBase64Decode(s): - return F('tryBase64Decode', s) + return F("tryBase64Decode", s) @staticmethod def endsWith(s, suffix): - return F('endsWith', s, suffix) + return F("endsWith", s, suffix) @staticmethod def startsWith(s, prefix): - return F('startsWith', s, prefix) + return F("startsWith", s, prefix) @staticmethod def trimLeft(s): - return F('trimLeft', s) + return F("trimLeft", s) @staticmethod def trimRight(s): - return F('trimRight', s) + return F("trimRight", s) @staticmethod def trimBoth(s): - return F('trimBoth', s) + return F("trimBoth", s) @staticmethod def CRC32(s): - return F('CRC32', s) + return F("CRC32", s) # Functions for searching in strings @staticmethod @with_utf8_support def position(haystack, needle): - return F('position', haystack, needle) + return F("position", haystack, needle) @staticmethod @with_utf8_support def positionCaseInsensitive(haystack, needle): - return F('positionCaseInsensitive', haystack, needle) + return F("positionCaseInsensitive", haystack, needle) @staticmethod def like(haystack, pattern): - return F('like', haystack, pattern) + return F("like", haystack, pattern) @staticmethod def notLike(haystack, pattern): - return F('notLike', haystack, pattern) + return F("notLike", haystack, pattern) @staticmethod def match(haystack, pattern): - return F('match', haystack, pattern) + return F("match", haystack, pattern) @staticmethod def extract(haystack, pattern): - return F('extract', haystack, pattern) + return F("extract", haystack, pattern) @staticmethod def extractAll(haystack, pattern): - return F('extractAll', haystack, pattern) + return F("extractAll", haystack, pattern) @staticmethod @with_utf8_support def ngramDistance(haystack, needle): - return F('ngramDistance', haystack, needle) + return F("ngramDistance", haystack, needle) @staticmethod @with_utf8_support def ngramDistanceCaseInsensitive(haystack, needle): - return F('ngramDistanceCaseInsensitive', haystack, needle) + return F("ngramDistanceCaseInsensitive", haystack, needle) @staticmethod @with_utf8_support def 
ngramSearch(haystack, needle): - return F('ngramSearch', haystack, needle) + return F("ngramSearch", haystack, needle) @staticmethod @with_utf8_support def ngramSearchCaseInsensitive(haystack, needle): - return F('ngramSearchCaseInsensitive', haystack, needle) + return F("ngramSearchCaseInsensitive", haystack, needle) # Functions for replacing in strings @staticmethod def replace(haystack, pattern, replacement): - return F('replace', haystack, pattern, replacement) + return F("replace", haystack, pattern, replacement) @staticmethod def replaceAll(haystack, pattern, replacement): - return F('replaceAll', haystack, pattern, replacement) + return F("replaceAll", haystack, pattern, replacement) @staticmethod def replaceOne(haystack, pattern, replacement): - return F('replaceOne', haystack, pattern, replacement) + return F("replaceOne", haystack, pattern, replacement) @staticmethod def replaceRegexpAll(haystack, pattern, replacement): - return F('replaceRegexpAll', haystack, pattern, replacement) + return F("replaceRegexpAll", haystack, pattern, replacement) @staticmethod def replaceRegexpOne(haystack, pattern, replacement): - return F('replaceRegexpOne', haystack, pattern, replacement) + return F("replaceRegexpOne", haystack, pattern, replacement) @staticmethod def regexpQuoteMeta(x): - return F('regexpQuoteMeta', x) + return F("regexpQuoteMeta", x) # Mathematical functions @staticmethod def e(): - return F('e') + return F("e") @staticmethod def pi(): - return F('pi') + return F("pi") @staticmethod def exp(x): - return F('exp', x) + return F("exp", x) @staticmethod def log(x): - return F('log', x) + return F("log", x) + ln = log @staticmethod def exp2(x): - return F('exp2', x) + return F("exp2", x) @staticmethod def log2(x): - return F('log2', x) + return F("log2", x) @staticmethod def exp10(x): - return F('exp10', x) + return F("exp10", x) @staticmethod def log10(x): - return F('log10', x) + return F("log10", x) @staticmethod def sqrt(x): - return F('sqrt', x) + return F("sqrt", x) @staticmethod def cbrt(x): - return F('cbrt', x) + return F("cbrt", x) @staticmethod def erf(x): - return F('erf', x) + return F("erf", x) @staticmethod def erfc(x): - return F('erfc', x) + return F("erfc", x) @staticmethod def lgamma(x): - return F('lgamma', x) + return F("lgamma", x) @staticmethod def tgamma(x): - return F('tgamma', x) + return F("tgamma", x) @staticmethod def sin(x): - return F('sin', x) + return F("sin", x) @staticmethod def cos(x): - return F('cos', x) + return F("cos", x) @staticmethod def tan(x): - return F('tan', x) + return F("tan", x) @staticmethod def asin(x): - return F('asin', x) + return F("asin", x) @staticmethod def acos(x): - return F('acos', x) + return F("acos", x) @staticmethod def atan(x): - return F('atan', x) + return F("atan", x) @staticmethod def power(x, y): - return F('power', x, y) + return F("power", x, y) + pow = power @staticmethod def intExp10(x): - return F('intExp10', x) + return F("intExp10", x) @staticmethod def intExp2(x): - return F('intExp2', x) + return F("intExp2", x) # Rounding functions @staticmethod def floor(x, n=None): - return F('floor', x, n) if n else F('floor', x) + return F("floor", x, n) if n else F("floor", x) @staticmethod def ceiling(x, n=None): - return F('ceiling', x, n) if n else F('ceiling', x) + return F("ceiling", x, n) if n else F("ceiling", x) + ceil = ceiling @staticmethod def round(x, n=None): - return F('round', x, n) if n else F('round', x) + return F("round", x, n) if n else F("round", x) @staticmethod def roundAge(x): - return 
F('roundAge', x) + return F("roundAge", x) @staticmethod def roundDown(x, y): - return F('roundDown', x, y) + return F("roundDown", x, y) @staticmethod def roundDuration(x): - return F('roundDuration', x) + return F("roundDuration", x) @staticmethod def roundToExp2(x): - return F('roundToExp2', x) + return F("roundToExp2", x) # Functions for working with arrays @staticmethod def emptyArrayDate(): - return F('emptyArrayDate') + return F("emptyArrayDate") @staticmethod def emptyArrayDateTime(): - return F('emptyArrayDateTime') + return F("emptyArrayDateTime") @staticmethod def emptyArrayFloat32(): - return F('emptyArrayFloat32') + return F("emptyArrayFloat32") @staticmethod def emptyArrayFloat64(): - return F('emptyArrayFloat64') + return F("emptyArrayFloat64") @staticmethod def emptyArrayInt16(): - return F('emptyArrayInt16') + return F("emptyArrayInt16") @staticmethod def emptyArrayInt32(): - return F('emptyArrayInt32') + return F("emptyArrayInt32") @staticmethod def emptyArrayInt64(): - return F('emptyArrayInt64') + return F("emptyArrayInt64") @staticmethod def emptyArrayInt8(): - return F('emptyArrayInt8') + return F("emptyArrayInt8") @staticmethod def emptyArrayString(): - return F('emptyArrayString') + return F("emptyArrayString") @staticmethod def emptyArrayUInt16(): - return F('emptyArrayUInt16') + return F("emptyArrayUInt16") @staticmethod def emptyArrayUInt32(): - return F('emptyArrayUInt32') + return F("emptyArrayUInt32") @staticmethod def emptyArrayUInt64(): - return F('emptyArrayUInt64') + return F("emptyArrayUInt64") @staticmethod def emptyArrayUInt8(): - return F('emptyArrayUInt8') + return F("emptyArrayUInt8") @staticmethod def emptyArrayToSingle(x): - return F('emptyArrayToSingle', x) + return F("emptyArrayToSingle", x) @staticmethod def range(n): - return F('range', n) + return F("range", n) @staticmethod def array(*args): - return F('array', *args) + return F("array", *args) @staticmethod def arrayConcat(*args): - return F('arrayConcat', *args) + return F("arrayConcat", *args) @staticmethod def arrayElement(arr, n): - return F('arrayElement', arr, n) + return F("arrayElement", arr, n) @staticmethod def tupleElement(arr, n): - return F('tupleElement', arr, n) + return F("tupleElement", arr, n) @staticmethod def has(arr, x): - return F('has', arr, x) + return F("has", arr, x) @staticmethod def hasAll(arr, x): - return F('hasAll', arr, x) + return F("hasAll", arr, x) @staticmethod def hasAny(arr, x): - return F('hasAny', arr, x) + return F("hasAny", arr, x) @staticmethod def geohashEncode(x, y, precision=12): - return F('geohashEncode', x, y, precision) + return F("geohashEncode", x, y, precision) @staticmethod def indexOf(arr, x): - return F('indexOf', arr, x) + return F("indexOf", arr, x) @staticmethod def countEqual(arr, x): - return F('countEqual', arr, x) + return F("countEqual", arr, x) @staticmethod def arrayEnumerate(arr): - return F('arrayEnumerate', arr) + return F("arrayEnumerate", arr) @staticmethod def arrayEnumerateDense(*args): - return F('arrayEnumerateDense', *args) + return F("arrayEnumerateDense", *args) @staticmethod def arrayEnumerateDenseRanked(*args): - return F('arrayEnumerateDenseRanked', *args) + return F("arrayEnumerateDenseRanked", *args) @staticmethod def arrayEnumerateUniq(*args): - return F('arrayEnumerateUniq', *args) + return F("arrayEnumerateUniq", *args) @staticmethod def arrayEnumerateUniqRanked(*args): - return F('arrayEnumerateUniqRanked', *args) + return F("arrayEnumerateUniqRanked", *args) @staticmethod def arrayPopBack(arr): - return 
F('arrayPopBack', arr) + return F("arrayPopBack", arr) @staticmethod def arrayPopFront(arr): - return F('arrayPopFront', arr) + return F("arrayPopFront", arr) @staticmethod def arrayPushBack(arr, x): - return F('arrayPushBack', arr, x) + return F("arrayPushBack", arr, x) @staticmethod def arrayPushFront(arr, x): - return F('arrayPushFront', arr, x) + return F("arrayPushFront", arr, x) @staticmethod def arrayResize(array, size, extender=None): - return F('arrayResize', array, size, extender) if extender is not None else F('arrayResize', array, size) + return ( + F("arrayResize", array, size, extender) + if extender is not None + else F("arrayResize", array, size) + ) @staticmethod def arraySlice(array, offset, length=None): - return F('arraySlice', array, offset, length) if length is not None else F('arraySlice', array, offset) + return ( + F("arraySlice", array, offset, length) + if length is not None + else F("arraySlice", array, offset) + ) @staticmethod def arrayUniq(*args): - return F('arrayUniq', *args) + return F("arrayUniq", *args) @staticmethod def arrayJoin(arr): - return F('arrayJoin', arr) + return F("arrayJoin", arr) @staticmethod def arrayDifference(arr): - return F('arrayDifference', arr) + return F("arrayDifference", arr) @staticmethod def arrayDistinct(x): - return F('arrayDistinct', x) + return F("arrayDistinct", x) @staticmethod def arrayIntersect(*args): - return F('arrayIntersect', *args) + return F("arrayIntersect", *args) @staticmethod def arrayReduce(agg_func_name, *args): - return F('arrayReduce', agg_func_name, *args) + return F("arrayReduce", agg_func_name, *args) @staticmethod def arrayReverse(arr): - return F('arrayReverse', arr) + return F("arrayReverse", arr) # Functions for splitting and merging strings and arrays @staticmethod def splitByChar(sep, s): - return F('splitByChar', sep, s) + return F("splitByChar", sep, s) @staticmethod def splitByString(sep, s): - return F('splitByString', sep, s) + return F("splitByString", sep, s) @staticmethod def arrayStringConcat(arr, sep=None): - return F('arrayStringConcat', arr, sep) if sep else F('arrayStringConcat', arr) + return F("arrayStringConcat", arr, sep) if sep else F("arrayStringConcat", arr) @staticmethod def alphaTokens(s): - return F('alphaTokens', s) + return F("alphaTokens", s) # Bit functions @staticmethod def bitAnd(x, y): - return F('bitAnd', x, y) + return F("bitAnd", x, y) @staticmethod def bitNot(x): - return F('bitNot', x) + return F("bitNot", x) @staticmethod def bitOr(x, y): - return F('bitOr', x, y) + return F("bitOr", x, y) @staticmethod def bitRotateLeft(x, y): - return F('bitRotateLeft', x, y) + return F("bitRotateLeft", x, y) @staticmethod def bitRotateRight(x, y): - return F('bitRotateRight', x, y) + return F("bitRotateRight", x, y) @staticmethod def bitShiftLeft(x, y): - return F('bitShiftLeft', x, y) + return F("bitShiftLeft", x, y) @staticmethod def bitShiftRight(x, y): - return F('bitShiftRight', x, y) + return F("bitShiftRight", x, y) @staticmethod def bitTest(x, y): - return F('bitTest', x, y) + return F("bitTest", x, y) @staticmethod def bitTestAll(x, *args): - return F('bitTestAll', x, *args) + return F("bitTestAll", x, *args) @staticmethod def bitTestAny(x, *args): - return F('bitTestAny', x, *args) + return F("bitTestAny", x, *args) @staticmethod def bitXor(x, y): - return F('bitXor', x, y) + return F("bitXor", x, y) # Bitmap functions @staticmethod def bitmapAnd(x, y): - return F('bitmapAnd', x, y) + return F("bitmapAnd", x, y) @staticmethod def bitmapAndCardinality(x, y): - 
return F('bitmapAndCardinality', x, y) + return F("bitmapAndCardinality", x, y) @staticmethod def bitmapAndnot(x, y): - return F('bitmapAndnot', x, y) + return F("bitmapAndnot", x, y) @staticmethod def bitmapAndnotCardinality(x, y): - return F('bitmapAndnotCardinality', x, y) + return F("bitmapAndnotCardinality", x, y) @staticmethod def bitmapBuild(x): - return F('bitmapBuild', x) + return F("bitmapBuild", x) @staticmethod def bitmapCardinality(x): - return F('bitmapCardinality', x) + return F("bitmapCardinality", x) @staticmethod def bitmapContains(haystack, needle): - return F('bitmapContains', haystack, needle) + return F("bitmapContains", haystack, needle) @staticmethod def bitmapHasAll(x, y): - return F('bitmapHasAll', x, y) + return F("bitmapHasAll", x, y) @staticmethod def bitmapHasAny(x, y): - return F('bitmapHasAny', x, y) + return F("bitmapHasAny", x, y) @staticmethod def bitmapOr(x, y): - return F('bitmapOr', x, y) + return F("bitmapOr", x, y) @staticmethod def bitmapOrCardinality(x, y): - return F('bitmapOrCardinality', x, y) + return F("bitmapOrCardinality", x, y) @staticmethod def bitmapToArray(x): - return F('bitmapToArray', x) + return F("bitmapToArray", x) @staticmethod def bitmapXor(x, y): - return F('bitmapXor', x, y) + return F("bitmapXor", x, y) @staticmethod def bitmapXorCardinality(x, y): - return F('bitmapXorCardinality', x, y) + return F("bitmapXorCardinality", x, y) # Hash functions @staticmethod def halfMD5(*args): - return F('halfMD5', *args) + return F("halfMD5", *args) @staticmethod def MD5(s): - return F('MD5', s) + return F("MD5", s) @staticmethod def sipHash128(*args): - return F('sipHash128', *args) + return F("sipHash128", *args) @staticmethod def sipHash64(*args): - return F('sipHash64', *args) + return F("sipHash64", *args) @staticmethod def cityHash64(*args): - return F('cityHash64', *args) + return F("cityHash64", *args) @staticmethod def intHash32(x): - return F('intHash32', x) + return F("intHash32", x) @staticmethod def intHash64(x): - return F('intHash64', x) + return F("intHash64", x) @staticmethod def SHA1(s): - return F('SHA1', s) + return F("SHA1", s) @staticmethod def SHA224(s): - return F('SHA224', s) + return F("SHA224", s) @staticmethod def SHA256(s): - return F('SHA256', s) + return F("SHA256", s) @staticmethod def URLHash(url, n=None): - return F('URLHash', url, n) if n is not None else F('URLHash', url) + return F("URLHash", url, n) if n is not None else F("URLHash", url) @staticmethod def farmHash64(*args): - return F('farmHash64',*args) + return F("farmHash64", *args) @staticmethod def javaHash(s): - return F('javaHash', s) + return F("javaHash", s) @staticmethod def hiveHash(s): - return F('hiveHash', s) + return F("hiveHash", s) @staticmethod def metroHash64(*args): - return F('metroHash64', *args) + return F("metroHash64", *args) @staticmethod def jumpConsistentHash(x, buckets): - return F('jumpConsistentHash', x, buckets) + return F("jumpConsistentHash", x, buckets) @staticmethod def murmurHash2_32(*args): - return F('murmurHash2_32', *args) + return F("murmurHash2_32", *args) @staticmethod def murmurHash2_64(*args): - return F('murmurHash2_64', *args) + return F("murmurHash2_64", *args) @staticmethod def murmurHash3_32(*args): - return F('murmurHash3_32', *args) + return F("murmurHash3_32", *args) @staticmethod def murmurHash3_64(*args): - return F('murmurHash3_64', *args) + return F("murmurHash3_64", *args) @staticmethod def murmurHash3_128(s): - return F('murmurHash3_128', s) + return F("murmurHash3_128", s) @staticmethod def 
xxHash32(*args): - return F('xxHash32', *args) + return F("xxHash32", *args) @staticmethod def xxHash64(*args): - return F('xxHash64', *args) + return F("xxHash64", *args) # Functions for generating pseudo-random numbers @staticmethod def rand(dummy=None): - return F('rand') if dummy is None else F('rand', dummy) + return F("rand") if dummy is None else F("rand", dummy) @staticmethod def rand64(dummy=None): - return F('rand64') if dummy is None else F('rand64', dummy) + return F("rand64") if dummy is None else F("rand64", dummy) @staticmethod def randConstant(dummy=None): - return F('randConstant') if dummy is None else F('randConstant', dummy) + return F("randConstant") if dummy is None else F("randConstant", dummy) # Encoding functions @staticmethod def hex(x): - return F('hex', x) + return F("hex", x) @staticmethod def unhex(x): - return F('unhex', x) + return F("unhex", x) @staticmethod def bitmaskToArray(x): - return F('bitmaskToArray', x) + return F("bitmaskToArray", x) @staticmethod def bitmaskToList(x): - return F('bitmaskToList', x) + return F("bitmaskToList", x) # Functions for working with UUID @staticmethod def generateUUIDv4(): - return F('generateUUIDv4') + return F("generateUUIDv4") @staticmethod def toUUID(s): - return F('toUUID', s) + return F("toUUID", s) @staticmethod def UUIDNumToString(s): - return F('UUIDNumToString', s) + return F("UUIDNumToString", s) @staticmethod def UUIDStringToNum(s): - return F('UUIDStringToNum', s) + return F("UUIDStringToNum", s) # Functions for working with IP addresses @staticmethod def IPv4CIDRToRange(ipv4, cidr): - return F('IPv4CIDRToRange', ipv4, cidr) + return F("IPv4CIDRToRange", ipv4, cidr) @staticmethod def IPv4NumToString(num): - return F('IPv4NumToString', num) + return F("IPv4NumToString", num) @staticmethod def IPv4NumToStringClassC(num): - return F('IPv4NumToStringClassC', num) + return F("IPv4NumToStringClassC", num) @staticmethod def IPv4StringToNum(s): - return F('IPv4StringToNum', s) + return F("IPv4StringToNum", s) @staticmethod def IPv4ToIPv6(ipv4): - return F('IPv4ToIPv6', ipv4) + return F("IPv4ToIPv6", ipv4) @staticmethod def IPv6CIDRToRange(ipv6, cidr): - return F('IPv6CIDRToRange', ipv6, cidr) + return F("IPv6CIDRToRange", ipv6, cidr) @staticmethod def IPv6NumToString(num): - return F('IPv6NumToString', num) + return F("IPv6NumToString", num) @staticmethod def IPv6StringToNum(s): - return F('IPv6StringToNum', s) + return F("IPv6StringToNum", s) @staticmethod def toIPv4(ipv4): - return F('toIPv4', ipv4) + return F("toIPv4", ipv4) @staticmethod def toIPv6(ipv6): - return F('toIPv6', ipv6) + return F("toIPv6", ipv6) # Aggregate functions @staticmethod @aggregate def any(x): - return F('any', x) + return F("any", x) @staticmethod @aggregate def anyHeavy(x): - return F('anyHeavy', x) + return F("anyHeavy", x) @staticmethod @aggregate def anyLast(x): - return F('anyLast', x) + return F("anyLast", x) @staticmethod @aggregate def argMax(x, y): - return F('argMax', x, y) + return F("argMax", x, y) @staticmethod @aggregate def argMin(x, y): - return F('argMin', x, y) + return F("argMin", x, y) @staticmethod @aggregate def avg(x): - return F('avg', x) + return F("avg", x) @staticmethod @aggregate def corr(x, y): - return F('corr', x, y) + return F("corr", x, y) @staticmethod @aggregate def count(): - return F('count') + return F("count") @staticmethod @aggregate def covarPop(x, y): - return F('covarPop', x, y) + return F("covarPop", x, y) @staticmethod @aggregate def covarSamp(x, y): - return F('covarSamp', x, y) + return 
F("covarSamp", x, y) @staticmethod @aggregate def kurtPop(x): - return F('kurtPop', x) + return F("kurtPop", x) @staticmethod @aggregate def kurtSamp(x): - return F('kurtSamp', x) + return F("kurtSamp", x) @staticmethod @aggregate def min(x): - return F('min', x) + return F("min", x) @staticmethod @aggregate def max(x): - return F('max', x) + return F("max", x) @staticmethod @aggregate def skewPop(x): - return F('skewPop', x) + return F("skewPop", x) @staticmethod @aggregate def skewSamp(x): - return F('skewSamp', x) + return F("skewSamp", x) @staticmethod @aggregate def sum(x): - return F('sum', x) + return F("sum", x) @staticmethod @aggregate def uniq(*args): - return F('uniq', *args) + return F("uniq", *args) @staticmethod @aggregate def uniqExact(*args): - return F('uniqExact', *args) + return F("uniqExact", *args) @staticmethod @aggregate def uniqHLL12(*args): - return F('uniqHLL12', *args) + return F("uniqHLL12", *args) @staticmethod @aggregate def varPop(x): - return F('varPop', x) + return F("varPop", x) @staticmethod @aggregate def varSamp(x): - return F('varSamp', x) + return F("varSamp", x) @staticmethod @aggregate def stddevPop(expr): - return F('stddevPop', expr) + return F("stddevPop", expr) @staticmethod @aggregate def stddevSamp(expr): - return F('stddevSamp', expr) + return F("stddevSamp", expr) @staticmethod @aggregate @parametric def quantile(expr): - return F('quantile', expr) + return F("quantile", expr) @staticmethod @aggregate @parametric def quantileDeterministic(expr, determinator): - return F('quantileDeterministic', expr, determinator) + return F("quantileDeterministic", expr, determinator) @staticmethod @aggregate @parametric def quantileExact(expr): - return F('quantileExact', expr) + return F("quantileExact", expr) @staticmethod @aggregate @parametric def quantileExactWeighted(expr, weight): - return F('quantileExactWeighted', expr, weight) + return F("quantileExactWeighted", expr, weight) @staticmethod @aggregate @parametric def quantileTiming(expr): - return F('quantileTiming', expr) + return F("quantileTiming", expr) @staticmethod @aggregate @parametric def quantileTimingWeighted(expr, weight): - return F('quantileTimingWeighted', expr, weight) + return F("quantileTimingWeighted", expr, weight) @staticmethod @aggregate @parametric def quantileTDigest(expr): - return F('quantileTDigest', expr) + return F("quantileTDigest", expr) @staticmethod @aggregate @parametric def quantileTDigestWeighted(expr, weight): - return F('quantileTDigestWeighted', expr, weight) + return F("quantileTDigestWeighted", expr, weight) @staticmethod @aggregate @parametric def quantiles(expr): - return F('quantiles', expr) + return F("quantiles", expr) @staticmethod @aggregate @parametric def quantilesDeterministic(expr, determinator): - return F('quantilesDeterministic', expr, determinator) + return F("quantilesDeterministic", expr, determinator) @staticmethod @aggregate @parametric def quantilesExact(expr): - return F('quantilesExact', expr) + return F("quantilesExact", expr) @staticmethod @aggregate @parametric def quantilesExactWeighted(expr, weight): - return F('quantilesExactWeighted', expr, weight) + return F("quantilesExactWeighted", expr, weight) @staticmethod @aggregate @parametric def quantilesTiming(expr): - return F('quantilesTiming', expr) + return F("quantilesTiming", expr) @staticmethod @aggregate @parametric def quantilesTimingWeighted(expr, weight): - return F('quantilesTimingWeighted', expr, weight) + return F("quantilesTimingWeighted", expr, weight) @staticmethod 
@aggregate @parametric def quantilesTDigest(expr): - return F('quantilesTDigest', expr) + return F("quantilesTDigest", expr) @staticmethod @aggregate @parametric def quantilesTDigestWeighted(expr, weight): - return F('quantilesTDigestWeighted', expr, weight) + return F("quantilesTDigestWeighted", expr, weight) @staticmethod @aggregate @parametric def topK(expr): - return F('topK', expr) + return F("topK", expr) @staticmethod @aggregate @parametric def topKWeighted(expr, weight): - return F('topKWeighted', expr, weight) + return F("topKWeighted", expr, weight) # Null handling functions @staticmethod def ifNull(x, y): - return F('ifNull', x, y) + return F("ifNull", x, y) @staticmethod def nullIf(x, y): - return F('nullIf', x, y) + return F("nullIf", x, y) @staticmethod def isNotNull(x): - return F('isNotNull', x) + return F("isNotNull", x) @staticmethod def isNull(x): - return F('isNull', x) + return F("isNull", x) @staticmethod def coalesce(*args): - return F('coalesce', *args) + return F("coalesce", *args) # Misc functions @staticmethod def ifNotFinite(x, y): - return F('ifNotFinite', x, y) + return F("ifNotFinite", x, y) @staticmethod def isFinite(x): - return F('isFinite', x) + return F("isFinite", x) @staticmethod def isInfinite(x): - return F('isInfinite', x) + return F("isInfinite", x) @staticmethod def isNaN(x): - return F('isNaN', x) + return F("isNaN", x) @staticmethod def least(x, y): - return F('least', x, y) + return F("least", x, y) @staticmethod def greatest(x, y): - return F('greatest', x, y) + return F("greatest", x, y) # Dictionary functions @staticmethod def dictGet(dict_name, attr_name, id_expr): - return F('dictGet', dict_name, attr_name, id_expr) + return F("dictGet", dict_name, attr_name, id_expr) @staticmethod def dictGetOrDefault(dict_name, attr_name, id_expr, default): - return F('dictGetOrDefault', dict_name, attr_name, id_expr, default) + return F("dictGetOrDefault", dict_name, attr_name, id_expr, default) @staticmethod def dictHas(dict_name, id_expr): - return F('dictHas', dict_name, id_expr) + return F("dictHas", dict_name, id_expr) @staticmethod def dictGetHierarchy(dict_name, id_expr): - return F('dictGetHierarchy', dict_name, id_expr) + return F("dictGetHierarchy", dict_name, id_expr) @staticmethod def dictIsIn(dict_name, child_id_expr, ancestor_id_expr): - return F('dictIsIn', dict_name, child_id_expr, ancestor_id_expr) + return F("dictIsIn", dict_name, child_id_expr, ancestor_id_expr) # Expose only relevant classes in import * -__all__ = ['F'] - +__all__ = ["F"] diff --git a/src/clickhouse_orm/migrations.py b/src/clickhouse_orm/migrations.py index efcc19f..279bc47 100644 --- a/src/clickhouse_orm/migrations.py +++ b/src/clickhouse_orm/migrations.py @@ -5,7 +5,7 @@ from .utils import get_subclass_names import logging -logger = logging.getLogger('migrations') +logger = logging.getLogger("migrations") class Operation: @@ -14,7 +14,7 @@ class Operation: """ def apply(self, database): - raise NotImplementedError() # pragma: no cover + raise NotImplementedError() # pragma: no cover class ModelOperation(Operation): @@ -30,9 +30,9 @@ class ModelOperation(Operation): self.table_name = model_class.table_name() def _alter_table(self, database, cmd): - ''' + """ Utility for running ALTER TABLE commands. 
- ''' + """ cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd) logger.debug(cmd) database.raw(cmd) @@ -44,7 +44,7 @@ class CreateTable(ModelOperation): """ def apply(self, database): - logger.info(' Create table %s', self.table_name) + logger.info(" Create table %s", self.table_name) if issubclass(self.model_class, BufferModel): database.create_table(self.model_class.engine.main_model) database.create_table(self.model_class) @@ -65,7 +65,7 @@ class AlterTable(ModelOperation): return [(row.name, row.type) for row in database.select(query)] def apply(self, database): - logger.info(' Alter table %s', self.table_name) + logger.info(" Alter table %s", self.table_name) # Note that MATERIALIZED and ALIAS fields are always at the end of the DESC, # ADD COLUMN ... AFTER doesn't affect it @@ -74,8 +74,8 @@ class AlterTable(ModelOperation): # Identify fields that were deleted from the model deleted_fields = set(table_fields.keys()) - set(self.model_class.fields()) for name in deleted_fields: - logger.info(' Drop column %s', name) - self._alter_table(database, 'DROP COLUMN %s' % name) + logger.info(" Drop column %s", name) + self._alter_table(database, "DROP COLUMN %s" % name) del table_fields[name] # Identify fields that were added to the model @@ -83,13 +83,13 @@ class AlterTable(ModelOperation): for name, field in self.model_class.fields().items(): is_regular_field = not (field.materialized or field.alias) if name not in table_fields: - logger.info(' Add column %s', name) - cmd = 'ADD COLUMN %s %s' % (name, field.get_sql(db=database)) + logger.info(" Add column %s", name) + cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database)) if is_regular_field: if prev_name: - cmd += ' AFTER %s' % prev_name + cmd += " AFTER %s" % prev_name else: - cmd += ' FIRST' + cmd += " FIRST" self._alter_table(database, cmd) if is_regular_field: @@ -101,16 +101,24 @@ class AlterTable(ModelOperation): # The order of class attributes can be changed any time, so we can't count on it # Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save # attribute position. 
Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47 - model_fields = {name: field.get_sql(with_default_expression=False, db=database) - for name, field in self.model_class.fields().items()} + model_fields = { + name: field.get_sql(with_default_expression=False, db=database) + for name, field in self.model_class.fields().items() + } for field_name, field_sql in self._get_table_fields(database): # All fields must have been created and dropped by this moment - assert field_name in model_fields, 'Model fields and table columns in disagreement' + assert field_name in model_fields, "Model fields and table columns in disagreement" if field_sql != model_fields[field_name]: - logger.info(' Change type of column %s from %s to %s', field_name, field_sql, - model_fields[field_name]) - self._alter_table(database, 'MODIFY COLUMN %s %s' % (field_name, model_fields[field_name])) + logger.info( + " Change type of column %s from %s to %s", + field_name, + field_sql, + model_fields[field_name], + ) + self._alter_table( + database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]) + ) class AlterTableWithBuffer(ModelOperation): @@ -135,7 +143,7 @@ class DropTable(ModelOperation): """ def apply(self, database): - logger.info(' Drop table %s', self.table_name) + logger.info(" Drop table %s", self.table_name) database.drop_table(self.model_class) @@ -148,28 +156,29 @@ class AlterConstraints(ModelOperation): """ def apply(self, database): - logger.info(' Alter constraints for %s', self.table_name) + logger.info(" Alter constraints for %s", self.table_name) existing = self._get_constraint_names(database) # Go over constraints in the model for constraint in self.model_class._constraints.values(): # Check if it's a new constraint if constraint.name not in existing: - logger.info(' Add constraint %s', constraint.name) - self._alter_table(database, 'ADD %s' % constraint.create_table_sql()) + logger.info(" Add constraint %s", constraint.name) + self._alter_table(database, "ADD %s" % constraint.create_table_sql()) else: existing.remove(constraint.name) # Remaining constraints in `existing` are obsolete for name in existing: - logger.info(' Drop constraint %s', name) - self._alter_table(database, 'DROP CONSTRAINT `%s`' % name) + logger.info(" Drop constraint %s", name) + self._alter_table(database, "DROP CONSTRAINT `%s`" % name) def _get_constraint_names(self, database): """ Returns a set containing the names of existing constraints in the table. 
""" import re - table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name) - matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def) + + table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name) + matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def) return set(matches) @@ -191,33 +200,34 @@ class AlterIndexes(ModelOperation): self.reindex = reindex def apply(self, database): - logger.info(' Alter indexes for %s', self.table_name) + logger.info(" Alter indexes for %s", self.table_name) existing = self._get_index_names(database) logger.info(existing) # Go over indexes in the model for index in self.model_class._indexes.values(): # Check if it's a new index if index.name not in existing: - logger.info(' Add index %s', index.name) - self._alter_table(database, 'ADD %s' % index.create_table_sql()) + logger.info(" Add index %s", index.name) + self._alter_table(database, "ADD %s" % index.create_table_sql()) else: existing.remove(index.name) # Remaining indexes in `existing` are obsolete for name in existing: - logger.info(' Drop index %s', name) - self._alter_table(database, 'DROP INDEX `%s`' % name) + logger.info(" Drop index %s", name) + self._alter_table(database, "DROP INDEX `%s`" % name) # Reindex if self.reindex: - logger.info(' Build indexes on table') - database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name) + logger.info(" Build indexes on table") + database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name) def _get_index_names(self, database): """ Returns a set containing the names of existing indexes in the table. """ import re - table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name) - matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def) + + table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name) + matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def) return set(matches) @@ -225,16 +235,17 @@ class RunPython(Operation): """ A migration operation that executes a Python function. """ + def __init__(self, func): - ''' + """ Initializer. The given Python function will be called with a single argument - the Database instance to apply the migration to. - ''' + """ assert callable(func), "'func' argument must be function" self._func = func def apply(self, database): - logger.info(' Executing python operation %s', self._func.__name__) + logger.info(" Executing python operation %s", self._func.__name__) self._func(database) @@ -244,17 +255,17 @@ class RunSQL(Operation): """ def __init__(self, sql): - ''' + """ Initializer. The given sql argument must be a valid SQL statement or list of statements. 
- ''' + """ if isinstance(sql, str): sql = [sql] assert isinstance(sql, list), "'sql' argument must be string or list of strings" self._sql = sql def apply(self, database): - logger.info(' Executing raw SQL operations') + logger.info(" Executing raw SQL operations") for item in self._sql: database.raw(item) @@ -268,11 +279,11 @@ class MigrationHistory(Model): module_name = StringField() applied = DateField() - engine = MergeTree('applied', ('package_name', 'module_name')) + engine = MergeTree("applied", ("package_name", "module_name")) @classmethod def table_name(cls): - return 'infi_clickhouse_orm_migrations' + return "infi_clickhouse_orm_migrations" # Expose only relevant classes in import * diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py index daf4a54..27b1a93 100644 --- a/src/clickhouse_orm/models.py +++ b/src/clickhouse_orm/models.py @@ -17,7 +17,7 @@ from .engines import Merge, Distributed, Memory if TYPE_CHECKING: from clickhouse_orm.database import Database -logger = getLogger('clickhouse_orm') +logger = getLogger("clickhouse_orm") class Constraint: @@ -38,7 +38,7 @@ class Constraint: """ Returns the SQL statement for defining this constraint during table creation. """ - return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr)) + return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr)) class Index: @@ -66,8 +66,11 @@ class Index: """ Returns the SQL statement for defining this index during table creation. """ - return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % ( - self.name, arg_to_sql(self.expr), self.type, self.granularity + return "INDEX `%s` %s TYPE %s GRANULARITY %d" % ( + self.name, + arg_to_sql(self.expr), + self.type, + self.granularity, ) @staticmethod @@ -76,7 +79,7 @@ class Index: An index that stores extremes of the specified expression (if the expression is tuple, then it stores extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key. """ - return 'minmax' + return "minmax" @staticmethod def set(max_rows: int) -> str: @@ -85,11 +88,12 @@ class Index: or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable on a block of data. """ - return 'set(%d)' % max_rows + return "set(%d)" % max_rows @staticmethod - def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int, - number_of_hash_functions: int, random_seed: int) -> str: + def ngrambf_v1( + n: int, size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int + ) -> str: """ An index that stores a Bloom filter containing all ngrams from a block of data. Works only with strings. Can be used for optimization of equals, like and in expressions. @@ -100,13 +104,17 @@ class Index: - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. """ - return 'ngrambf_v1(%d, %d, %d, %d)' % ( - n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed + return "ngrambf_v1(%d, %d, %d, %d)" % ( + n, + size_of_bloom_filter_in_bytes, + number_of_hash_functions, + random_seed, ) @staticmethod - def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, - random_seed: int) -> str: + def tokenbf_v1( + size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int + ) -> str: """ An index that stores a Bloom filter containing string tokens. Tokens are sequences separated by non-alphanumeric characters. 
@@ -116,8 +124,10 @@ class Index: - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. """ - return 'tokenbf_v1(%d, %d, %d)' % ( - size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed + return "tokenbf_v1(%d, %d, %d)" % ( + size_of_bloom_filter_in_bytes, + number_of_hash_functions, + random_seed, ) @staticmethod @@ -128,7 +138,7 @@ class Index: - `false_positive` - the probability (between 0 and 1) of receiving a false positive response from the filter """ - return 'bloom_filter(%f)' % false_positive + return "bloom_filter(%f)" % false_positive class ModelBase(type): @@ -183,23 +193,23 @@ class ModelBase(type): _indexes=indexes, _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), _defaults=defaults, - _has_funcs_as_defaults=has_funcs_as_defaults + _has_funcs_as_defaults=has_funcs_as_defaults, ) model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs) # Let each field, constraint and index know its parent and its own name for n, obj in chain(fields, constraints.items(), indexes.items()): - setattr(obj, 'parent', model) - setattr(obj, 'name', n) + setattr(obj, "parent", model) + setattr(obj, "name", n) return model @classmethod - def create_ad_hoc_model(cls, fields, model_name='AdHocModel'): + def create_ad_hoc_model(cls, fields, model_name="AdHocModel"): # fields is a list of tuples (name, db_type) # Check if model exists in cache fields = list(fields) - cache_key = model_name + ' ' + str(fields) + cache_key = model_name + " " + str(fields) if cache_key in cls.ad_hoc_model_cache: return cls.ad_hoc_model_cache[cache_key] # Create an ad hoc model class @@ -217,28 +227,25 @@ class ModelBase(type): import clickhouse_orm.contrib.geo.fields as geo_fields # Enums - if db_type.startswith('Enum'): + if db_type.startswith("Enum"): return orm_fields.BaseEnumField.create_ad_hoc_field(db_type) # DateTime with timezone - if db_type.startswith('DateTime('): + if db_type.startswith("DateTime("): timezone = db_type[9:-1] - return orm_fields.DateTimeField( - timezone=timezone[1:-1] if timezone else None - ) + return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None) # DateTime64 - if db_type.startswith('DateTime64('): - precision, *timezone = [s.strip() for s in db_type[11:-1].split(',')] + if db_type.startswith("DateTime64("): + precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")] return orm_fields.DateTime64Field( - precision=int(precision), - timezone=timezone[0][1:-1] if timezone else None + precision=int(precision), timezone=timezone[0][1:-1] if timezone else None ) # Arrays - if db_type.startswith('Array'): + if db_type.startswith("Array"): inner_field = cls.create_ad_hoc_field(db_type[6:-1]) return orm_fields.ArrayField(inner_field) # Tuples - if db_type.startswith('Tuple'): - types = [s.strip().split(' ') for s in db_type[6:-1].split(',')] + if db_type.startswith("Tuple"): + types = [s.strip().split(" ") for s in db_type[6:-1].split(",")] name_fields = [] for i, tp in enumerate(types): if len(tp) == 2: @@ -247,27 +254,27 @@ class ModelBase(type): name_fields.append((str(i), cls.create_ad_hoc_field(tp[0]))) return orm_fields.TupleField(name_fields=name_fields) # FixedString - if db_type.startswith('FixedString'): + if db_type.startswith("FixedString"): length = int(db_type[12:-1]) return orm_fields.FixedStringField(length) # Decimal / Decimal32 / Decimal64 / Decimal128 - if db_type.startswith('Decimal'): - p = 
db_type.index('(') - args = [int(n.strip()) for n in db_type[p + 1 : -1].split(',')] - field_class = getattr(orm_fields, db_type[:p] + 'Field') + if db_type.startswith("Decimal"): + p = db_type.index("(") + args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")] + field_class = getattr(orm_fields, db_type[:p] + "Field") return field_class(*args) # Nullable - if db_type.startswith('Nullable'): - inner_field = cls.create_ad_hoc_field(db_type[9 : -1]) + if db_type.startswith("Nullable"): + inner_field = cls.create_ad_hoc_field(db_type[9:-1]) return orm_fields.NullableField(inner_field) # LowCardinality - if db_type.startswith('LowCardinality'): - inner_field = cls.create_ad_hoc_field(db_type[15 : -1]) + if db_type.startswith("LowCardinality"): + inner_field = cls.create_ad_hoc_field(db_type[15:-1]) return orm_fields.LowCardinalityField(inner_field) # Simple fields - name = db_type + 'Field' + name = db_type + "Field" if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)): - raise NotImplementedError('No field class for %s' % db_type) + raise NotImplementedError("No field class for %s" % db_type) field_class = getattr(orm_fields, name, None) or getattr(geo_fields, name, None) return field_class() @@ -282,6 +289,7 @@ class Model(metaclass=ModelBase): cpu_percent = Float32Field() engine = Memory() """ + _has_funcs_as_defaults: bool _constraints: dict[str, Constraint] _indexes: dict[str, Index] @@ -318,7 +326,7 @@ class Model(metaclass=ModelBase): setattr(self, name, value) else: raise AttributeError( - '%s does not have a field called %s' % (self.__class__.__name__, name) + "%s does not have a field called %s" % (self.__class__.__name__, name) ) def __setattr__(self, name, value): @@ -383,29 +391,29 @@ class Model(metaclass=ModelBase): """ Returns the SQL statement for creating a table for this model. """ - parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] + parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())] # Fields items = [] for name, field in cls.fields().items(): - items.append(' %s %s' % (name, field.get_sql(db=db))) + items.append(" %s %s" % (name, field.get_sql(db=db))) # Constraints for c in cls._constraints.values(): - items.append(' %s' % c.create_table_sql()) + items.append(" %s" % c.create_table_sql()) # Indexes for i in cls._indexes.values(): - items.append(' %s' % i.create_table_sql()) - parts.append(',\n'.join(items)) + items.append(" %s" % i.create_table_sql()) + parts.append(",\n".join(items)) # Engine - parts.append(')') - parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) - return '\n'.join(parts) + parts.append(")") + parts.append("ENGINE = " + cls.engine.create_table_sql(db)) + return "\n".join(parts) @classmethod def drop_table_sql(cls, db: Database) -> str: """ Returns the SQL command for deleting this model's table. 
""" - return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name()) + return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name()) @classmethod def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None): @@ -422,7 +430,7 @@ class Model(metaclass=ModelBase): kwargs = {} for name in field_names: field = getattr(cls, name) - field_timezone = getattr(field, 'timezone', None) or timezone_in_use + field_timezone = getattr(field, "timezone", None) or timezone_in_use kwargs[name] = field.to_python(next(values), field_timezone) obj = cls(**kwargs) @@ -439,7 +447,9 @@ class Model(metaclass=ModelBase): """ data = self.__dict__ fields = self.fields(writable=not include_readonly) - return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items()) + return "\t".join( + field.to_db_string(data[name], quote=False) for name, field in fields.items() + ) def to_tskv(self, include_readonly=True): """ @@ -453,16 +463,16 @@ class Model(metaclass=ModelBase): parts = [] for name, field in fields.items(): if data[name] != NO_VALUE: - parts.append(name + '=' + field.to_db_string(data[name], quote=False)) - return '\t'.join(parts) + parts.append(name + "=" + field.to_db_string(data[name], quote=False)) + return "\t".join(parts) def to_db_string(self) -> bytes: """ Returns the instance as a bytestring ready to be inserted into the database. """ s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False) - s += '\n' - return s.encode('utf-8') + s += "\n" + return s.encode("utf-8") def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]: """ @@ -519,19 +529,18 @@ class Model(metaclass=ModelBase): class BufferModel(Model): - @classmethod def create_table_sql(cls, db: Database) -> str: """ Returns the SQL statement for creating a table for this model. """ parts = [ - 'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % ( - db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name()) + "CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`" + % (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name()) ] engine_str = cls.engine.create_table_sql(db) parts.append(engine_str) - return ' '.join(parts) + return " ".join(parts) class MergeModel(Model): @@ -540,6 +549,7 @@ class MergeModel(Model): Predefines virtual _table column an controls that rows can't be inserted to this table type https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge """ + readonly = True # Virtual fields can't be inserted into database @@ -551,15 +561,16 @@ class MergeModel(Model): Returns the SQL statement for creating a table for this model. 
""" assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" - parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] + parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())] cols = [] for name, field in cls.fields().items(): - if name != '_table': - cols.append(' %s %s' % (name, field.get_sql(db=db))) - parts.append(',\n'.join(cols)) - parts.append(')') - parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) - return '\n'.join(parts) + if name != "_table": + cols.append(" %s %s" % (name, field.get_sql(db=db))) + parts.append(",\n".join(cols)) + parts.append(")") + parts.append("ENGINE = " + cls.engine.create_table_sql(db)) + return "\n".join(parts) + # TODO: base class for models that require specific engine @@ -574,8 +585,9 @@ class DistributedModel(Model): Sets the `Database` that this model instance belongs to. This is done automatically when the instance is read from the database or written to it. """ - assert isinstance(self.engine, Distributed),\ - "engine must be an instance of engines.Distributed" + assert isinstance( + self.engine, Distributed + ), "engine must be an instance of engines.Distributed" super().set_database(db) @classmethod @@ -616,15 +628,20 @@ class DistributedModel(Model): return # find out all the superclasses of the Model that store any data - storage_models = [b for b in cls.__bases__ if issubclass(b, Model) - and not issubclass(b, DistributedModel)] + storage_models = [ + b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel) + ] if not storage_models: - raise TypeError("When defining Distributed engine without the table_name " - "ensure that your model has a parent model") + raise TypeError( + "When defining Distributed engine without the table_name " + "ensure that your model has a parent model" + ) if len(storage_models) > 1: - raise TypeError("When defining Distributed engine without the table_name " - "ensure that your model has exactly one non-distributed superclass") + raise TypeError( + "When defining Distributed engine without the table_name " + "ensure that your model has exactly one non-distributed superclass" + ) # enable correct SQL for engine cls.engine.table = storage_models[0] @@ -637,10 +654,12 @@ class DistributedModel(Model): assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" parts = [ - 'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format( - db.db_name, cls.table_name(), cls.engine.table_name), - 'ENGINE = ' + cls.engine.create_table_sql(db)] - return '\n'.join(parts) + "CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format( + db.db_name, cls.table_name(), cls.engine.table_name + ), + "ENGINE = " + cls.engine.create_table_sql(db), + ] + return "\n".join(parts) class TemporaryModel(Model): @@ -657,30 +676,31 @@ class TemporaryModel(Model): https://clickhouse.com/docs/en/sql-reference/statements/create/table/#temporary-tables """ + _temporary = True @classmethod def create_table_sql(cls, db: Database) -> str: assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance" - parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()] + parts = ["CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (" % cls.table_name()] # Fields items = [] for name, field in cls.fields().items(): - items.append(' %s %s' % (name, field.get_sql(db=db))) + items.append(" %s %s" % (name, field.get_sql(db=db))) # Constraints for c in cls._constraints.values(): - 
items.append(' %s' % c.create_table_sql()) + items.append(" %s" % c.create_table_sql()) # Indexes for i in cls._indexes.values(): - items.append(' %s' % i.create_table_sql()) - parts.append(',\n'.join(items)) + items.append(" %s" % i.create_table_sql()) + parts.append(",\n".join(items)) # Engine - parts.append(')') - parts.append('ENGINE = Memory') - return '\n'.join(parts) + parts.append(")") + parts.append("ENGINE = Memory") + return "\n".join(parts) # Expose only relevant classes in import * -MODEL = TypeVar('MODEL', bound=Model) +MODEL = TypeVar("MODEL", bound=Model) __all__ = get_subclass_names(locals(), (Model, Constraint, Index)) diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py index 5912500..8fc30ad 100644 --- a/src/clickhouse_orm/query.py +++ b/src/clickhouse_orm/query.py @@ -11,7 +11,7 @@ from typing import ( Generic, TypeVar, AsyncIterator, - Iterator + Iterator, ) import pytz @@ -24,7 +24,7 @@ if TYPE_CHECKING: from clickhouse_orm.models import Model from clickhouse_orm.database import Database, Page -MODEL = TypeVar('MODEL', bound='Model') +MODEL = TypeVar("MODEL", bound="Model") class Operator: @@ -59,9 +59,9 @@ class SimpleOperator(Operator): def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) - if value == '\\N' and self._sql_for_null is not None: - return ' '.join([field_name, self._sql_for_null]) - return ' '.join([field.name, self._sql_operator, value]) + if value == "\\N" and self._sql_for_null is not None: + return " ".join([field_name, self._sql_for_null]) + return " ".join([field.name, self._sql_operator, value]) class InOperator(Operator): @@ -81,7 +81,7 @@ class InOperator(Operator): pass else: value = comma_join([self._value_to_sql(field, v) for v in value]) - return '%s IN (%s)' % (field.name, value) + return "%s IN (%s)" % (field.name, value) class GlobalInOperator(Operator): @@ -95,7 +95,7 @@ class GlobalInOperator(Operator): pass else: value = comma_join([self._value_to_sql(field, v) for v in value]) - return '%s GLOBAL IN (%s)' % (field.name, value) + return "%s GLOBAL IN (%s)" % (field.name, value) class LikeOperator(Operator): @@ -111,11 +111,11 @@ class LikeOperator(Operator): def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value, quote=False) - value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') + value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_") pattern = self._pattern.format(value) if self._case_sensitive: - return '%s LIKE \'%s\'' % (field.name, pattern) - return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern) + return "%s LIKE '%s'" % (field.name, pattern) + return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field.name, pattern) class IExactOperator(Operator): @@ -126,7 +126,7 @@ class IExactOperator(Operator): def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) - return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value) + return "lowerUTF8(%s) = lowerUTF8(%s)" % (field.name, value) class NotOperator(Operator): @@ -139,7 +139,7 @@ class NotOperator(Operator): def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: # Negate the base operator - return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value) + return 
"NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value) class BetweenOperator(Operator): @@ -154,16 +154,22 @@ class BetweenOperator(Operator): def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) - value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len( - str(value[0])) > 0 else None - value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len( - str(value[1])) > 0 else None + value0 = ( + self._value_to_sql(field, value[0]) + if value[0] is not None or len(str(value[0])) > 0 + else None + ) + value1 = ( + self._value_to_sql(field, value[1]) + if value[1] is not None or len(str(value[1])) > 0 + else None + ) if value0 and value1: - return '%s BETWEEN %s AND %s' % (field.name, value0, value1) + return "%s BETWEEN %s AND %s" % (field.name, value0, value1) if value0 and not value1: - return ' '.join([field.name, '>=', value0]) + return " ".join([field.name, ">=", value0]) if value1 and not value0: - return ' '.join([field.name, '<=', value1]) + return " ".join([field.name, "<=", value1]) # Define the set of builtin operators @@ -175,24 +181,24 @@ def register_operator(name: str, sql: Operator): _operators[name] = sql -register_operator('eq', SimpleOperator('=', 'IS NULL')) -register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) -register_operator('gt', SimpleOperator('>')) -register_operator('gte', SimpleOperator('>=')) -register_operator('lt', SimpleOperator('<')) -register_operator('lte', SimpleOperator('<=')) -register_operator('between', BetweenOperator()) -register_operator('in', InOperator()) -register_operator('gin', GlobalInOperator()) -register_operator('not_in', NotOperator(InOperator())) -register_operator('not_gin', NotOperator(GlobalInOperator())) -register_operator('contains', LikeOperator('%{}%')) -register_operator('startswith', LikeOperator('{}%')) -register_operator('endswith', LikeOperator('%{}')) -register_operator('icontains', LikeOperator('%{}%', False)) -register_operator('istartswith', LikeOperator('{}%', False)) -register_operator('iendswith', LikeOperator('%{}', False)) -register_operator('iexact', IExactOperator()) +register_operator("eq", SimpleOperator("=", "IS NULL")) +register_operator("ne", SimpleOperator("!=", "IS NOT NULL")) +register_operator("gt", SimpleOperator(">")) +register_operator("gte", SimpleOperator(">=")) +register_operator("lt", SimpleOperator("<")) +register_operator("lte", SimpleOperator("<=")) +register_operator("between", BetweenOperator()) +register_operator("in", InOperator()) +register_operator("gin", GlobalInOperator()) +register_operator("not_in", NotOperator(InOperator())) +register_operator("not_gin", NotOperator(GlobalInOperator())) +register_operator("contains", LikeOperator("%{}%")) +register_operator("startswith", LikeOperator("{}%")) +register_operator("endswith", LikeOperator("%{}")) +register_operator("icontains", LikeOperator("%{}%", False)) +register_operator("istartswith", LikeOperator("{}%", False)) +register_operator("iendswith", LikeOperator("%{}", False)) +register_operator("iexact", IExactOperator()) class Cond: @@ -214,8 +220,8 @@ class FieldCond(Cond): self._operator = _operators.get(operator) if self._operator is None: # The field name contains __ like my__field - self._field_name = field_name + '__' + operator - self._operator = _operators['eq'] + self._field_name = field_name + "__" + operator + self._operator = _operators["eq"] self._value = value def to_sql(self, model_cls: type[Model]) -> 
str: @@ -228,12 +234,13 @@ class FieldCond(Cond): class Q: - AND_MODE = 'AND' - OR_MODE = 'OR' + AND_MODE = "AND" + OR_MODE = "OR" def __init__(self, *filter_funcs, **filter_fields): - self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in - filter_fields.items()] + self._conds = list(filter_funcs) + [ + self._build_cond(k, v) for k, v in filter_fields.items() + ] self._children = [] self._negate = False self._mode = self.AND_MODE @@ -263,10 +270,10 @@ class Q: return q def _build_cond(self, key, value): - if '__' in key: - field_name, operator = key.rsplit('__', 1) + if "__" in key: + field_name, operator = key.rsplit("__", 1) else: - field_name, operator = key, 'eq' + field_name, operator = key, "eq" return FieldCond(field_name, operator, value) def to_sql(self, model_cls: type[Model]) -> str: @@ -280,16 +287,16 @@ class Q: if not condition_sql: # Empty Q() object returns everything - sql = '1' + sql = "1" elif len(condition_sql) == 1: # Skip not needed brackets over single condition sql = condition_sql[0] else: # Each condition must be enclosed in brackets, or order of operations may be wrong - sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql) + sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql) if self._negate: - sql = 'NOT (%s)' % sql + sql = "NOT (%s)" % sql return sql @@ -400,16 +407,16 @@ class QuerySet(Generic[MODEL]): def __getitem__(self, s): if isinstance(s, int): # Single index - assert s >= 0, 'negative indexes are not supported' + assert s >= 0, "negative indexes are not supported" queryset = self._clone() queryset._limits = (s, 1) return next(iter(queryset)) # Slice - assert s.step in (None, 1), 'step is not supported in slices' + assert s.step in (None, 1), "step is not supported in slices" start = s.start or 0 - stop = s.stop or 2 ** 63 - 1 - assert start >= 0 and stop >= 0, 'negative indexes are not supported' - assert start <= stop, 'start of slice cannot be smaller than its end' + stop = s.stop or 2**63 - 1 + assert start >= 0 and stop >= 0, "negative indexes are not supported" + assert start <= stop, "start of slice cannot be smaller than its end" queryset = self._clone() queryset._limits = (start, stop - start) return queryset @@ -425,7 +432,7 @@ class QuerySet(Generic[MODEL]): offset_limit = (0, offset_limit) offset = offset_limit[0] limit = offset_limit[1] - assert offset >= 0 and limit >= 0, 'negative limits are not supported' + assert offset >= 0 and limit >= 0, "negative limits are not supported" queryset = self._clone() queryset._limit_by = (offset, limit) queryset._limit_by_fields = fields_or_expr @@ -435,44 +442,44 @@ class QuerySet(Generic[MODEL]): """ Returns the selected fields or expressions as a SQL string. """ - fields = '*' + fields = "*" if self._fields: - fields = comma_join('`%s`' % field for field in self._fields) + fields = comma_join("`%s`" % field for field in self._fields) return fields def as_sql(self) -> str: """ Returns the whole query as a SQL string. """ - distinct = 'DISTINCT ' if self._distinct else '' - final = ' FINAL' if self._final else '' - table_name = '`%s`' % self._model_cls.table_name() + distinct = "DISTINCT " if self._distinct else "" + final = " FINAL" if self._final else "" + table_name = "`%s`" % self._model_cls.table_name() if self._model_cls.is_system_model(): - table_name = '`system`.' + table_name + table_name = "`system`." 
+ table_name params = (distinct, self.select_fields_as_sql(), table_name, final) - sql = 'SELECT %s%s\nFROM %s%s' % params + sql = "SELECT %s%s\nFROM %s%s" % params if self._prewhere_q and not self._prewhere_q.is_empty: - sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True) + sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True) if self._where_q and not self._where_q.is_empty: - sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) + sql += "\nWHERE " + self.conditions_as_sql(prewhere=False) if self._grouping_fields: - sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields) + sql += "\nGROUP BY %s" % comma_join("%s" % field for field in self._grouping_fields) if self._grouping_with_totals: - sql += ' WITH TOTALS' + sql += " WITH TOTALS" if self._order_by: - sql += '\nORDER BY ' + self.order_by_as_sql() + sql += "\nORDER BY " + self.order_by_as_sql() if self._limit_by: - sql += '\nLIMIT %d, %d' % self._limit_by - sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields) + sql += "\nLIMIT %d, %d" % self._limit_by + sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields) if self._limits: - sql += '\nLIMIT %d, %d' % self._limits + sql += "\nLIMIT %d, %d" % self._limits return sql @@ -480,10 +487,12 @@ class QuerySet(Generic[MODEL]): """ Returns the contents of the query's `ORDER BY` clause as a string. """ - return comma_join([ - '%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field) - for field in self._order_by - ]) + return comma_join( + [ + "%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field) + for field in self._order_by + ] + ) def conditions_as_sql(self, prewhere=False) -> str: """ @@ -498,7 +507,7 @@ class QuerySet(Generic[MODEL]): """ if self._distinct or self._limits: # Use a subquery, since a simple count won't be accurate - sql = 'SELECT count() FROM (%s)' % self.as_sql() + sql = "SELECT count() FROM (%s)" % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 @@ -527,8 +536,8 @@ class QuerySet(Generic[MODEL]): def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]": from clickhouse_orm.funcs import F - inverse = kwargs.pop('_inverse', False) - prewhere = kwargs.pop('prewhere', False) + inverse = kwargs.pop("_inverse", False) + prewhere = kwargs.pop("prewhere", False) queryset = self._clone() @@ -588,14 +597,14 @@ class QuerySet(Generic[MODEL]): if page_num == -1: page_num = pages_total elif page_num < 1: - raise ValueError('Invalid page number: %d' % page_num) + raise ValueError("Invalid page number: %d" % page_num) offset = (page_num - 1) * page_size return Page( - objects=list(self[offset: offset + page_size]), + objects=list(self[offset : offset + page_size]), number_of_objects=count, pages_total=pages_total, number=page_num, - page_size=page_size + page_size=page_size, ) def distinct(self) -> "QuerySet[MODEL]": @@ -616,8 +625,8 @@ class QuerySet(Generic[MODEL]): if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): raise TypeError( - 'final() method can be used only with the CollapsingMergeTree' - ' and ReplacingMergeTree engines' + "final() method can be used only with the CollapsingMergeTree" + " and ReplacingMergeTree engines" ) queryset = self._clone() @@ -631,7 +640,7 @@ class QuerySet(Generic[MODEL]): """ self._verify_mutation_allowed() conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) - sql = 'ALTER TABLE $db.`%s` DELETE 
WHERE %s' % (self._model_cls.table_name(), conditions) + sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions) self._database.raw(sql) return self @@ -641,12 +650,14 @@ class QuerySet(Generic[MODEL]): Keyword arguments specify the field names and expressions to use for the update. Note that ClickHouse performs updates in the background, so they are not immediate. """ - assert kwargs, 'No fields specified for update' + assert kwargs, "No fields specified for update" self._verify_mutation_allowed() - fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) + fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) - sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % ( - self._model_cls.table_name(), fields, conditions + sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % ( + self._model_cls.table_name(), + fields, + conditions, ) self._database.raw(sql) return self @@ -655,10 +666,10 @@ class QuerySet(Generic[MODEL]): """ Checks that the queryset's state allows mutations. Raises an AssertionError if not. """ - assert not self._limits, 'Mutations are not allowed after slicing the queryset' - assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' - assert not self._distinct, 'Mutations are not allowed after calling distinct()' - assert not self._final, 'Mutations are not allowed after calling final()' + assert not self._limits, "Mutations are not allowed after slicing the queryset" + assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)" + assert not self._distinct, "Mutations are not allowed after calling distinct()" + assert not self._final, "Mutations are not allowed after calling final()" def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]": """ @@ -687,7 +698,7 @@ class AggregateQuerySet(QuerySet[MODEL]): self, base_queryset: QuerySet, grouping_fields: tuple[Any], - calculated_fields: dict[str, str] + calculated_fields: dict[str, str], ): """ Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`. @@ -705,7 +716,7 @@ class AggregateQuerySet(QuerySet[MODEL]): At least one calculated field is required. """ super().__init__(base_queryset._model_cls, base_queryset._database) - assert calculated_fields, 'No calculated fields specified for aggregation' + assert calculated_fields, "No calculated fields specified for aggregation" self._fields = grouping_fields self._grouping_fields = grouping_fields self._calculated_fields = calculated_fields @@ -734,8 +745,9 @@ class AggregateQuerySet(QuerySet[MODEL]): created with. """ for name in args: - assert name in self._fields or name in self._calculated_fields, \ - 'Cannot group by `%s` since it is not included in the query' % name + assert name in self._fields or name in self._calculated_fields, ( + "Cannot group by `%s` since it is not included in the query" % name + ) queryset = copy(self) queryset._grouping_fields = args return queryset @@ -750,14 +762,16 @@ class AggregateQuerySet(QuerySet[MODEL]): """ This method is not supported on `AggregateQuerySet`. """ - raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') + raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet") def select_fields_as_sql(self) -> str: """ Returns the selected fields or expressions as a SQL string. 
""" - return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in - self._calculated_fields.items()]) + return comma_join( + [str(f) for f in self._fields] + + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()] + ) def __iter__(self) -> Iterator[Model]: """ @@ -778,7 +792,7 @@ class AggregateQuerySet(QuerySet[MODEL]): """ Returns the number of rows after aggregation. """ - sql = 'SELECT count() FROM (%s)' % self.as_sql() + sql = "SELECT count() FROM (%s)" % self.as_sql() raw = self._database.raw(sql) if isinstance(raw, CoroutineType): return raw @@ -795,7 +809,7 @@ class AggregateQuerySet(QuerySet[MODEL]): return queryset def _verify_mutation_allowed(self): - raise AssertionError('Cannot mutate an AggregateQuerySet') + raise AssertionError("Cannot mutate an AggregateQuerySet") # Expose only relevant classes in import * diff --git a/src/clickhouse_orm/session.py b/src/clickhouse_orm/session.py index ef5be08..feb063c 100644 --- a/src/clickhouse_orm/session.py +++ b/src/clickhouse_orm/session.py @@ -2,8 +2,8 @@ import uuid from typing import Optional from contextvars import ContextVar, Token -ctx_session_id: ContextVar[str] = ContextVar('ck.session_id') -ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout') +ctx_session_id: ContextVar[str] = ContextVar("ck.session_id") +ctx_session_timeout: ContextVar[float] = ContextVar("ck.session_timeout") class SessionContext: diff --git a/src/clickhouse_orm/system_models.py b/src/clickhouse_orm/system_models.py index 69b67fa..d17635f 100644 --- a/src/clickhouse_orm/system_models.py +++ b/src/clickhouse_orm/system_models.py @@ -16,12 +16,15 @@ class SystemPart(Model): This model operates only fields, described in the reference. Other fields are ignored. https://clickhouse.tech/docs/en/system_tables/system.parts/ """ - OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'}) + + OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"}) _readonly = True _system = True - database = StringField() # Name of the database where the table that this part belongs to is located. + database = ( + StringField() + ) # Name of the database where the table that this part belongs to is located. table = StringField() # Name of the table that this part belongs to. engine = StringField() # Name of the table engine, without parameters. partition = StringField() # Name of the partition, in the format YYYYMM. @@ -43,7 +46,9 @@ class SystemPart(Model): # Time the directory with the part was modified. Usually corresponds to the part's creation time. modification_time = DateTimeField() - remove_time = DateTimeField() # For inactive parts only - the time when the part became inactive. + remove_time = ( + DateTimeField() + ) # For inactive parts only - the time when the part became inactive. # The number of places where the part is used. A value greater than 2 indicates # that this part participates in queries or merges. 
@@ -51,12 +56,13 @@ class SystemPart(Model): @classmethod def table_name(cls): - return 'parts' + return "parts" """ Next methods return SQL for some operations, which can be done with partitions https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts """ + def _partition_operation_sql(self, operation, settings=None, from_part=None): """ Performs some operation over partition @@ -68,9 +74,16 @@ class SystemPart(Model): Returns: Operation execution result """ operation = operation.upper() - assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(self.OPERATIONS) + assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join( + self.OPERATIONS + ) - sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (self._database.db_name, self.table, operation, self.partition) + sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % ( + self._database.db_name, + self.table, + operation, + self.partition, + ) if from_part is not None: sql += " FROM %s" % from_part self._database.raw(sql, settings=settings, stream=False) @@ -83,7 +96,7 @@ class SystemPart(Model): Returns: SQL Query """ - return self._partition_operation_sql('DETACH', settings=settings) + return self._partition_operation_sql("DETACH", settings=settings) def drop(self, settings=None): """ @@ -93,7 +106,7 @@ class SystemPart(Model): Returns: SQL Query """ - return self._partition_operation_sql('DROP', settings=settings) + return self._partition_operation_sql("DROP", settings=settings) def attach(self, settings=None): """ @@ -103,7 +116,7 @@ class SystemPart(Model): Returns: SQL Query """ - return self._partition_operation_sql('ATTACH', settings=settings) + return self._partition_operation_sql("ATTACH", settings=settings) def freeze(self, settings=None): """ @@ -113,7 +126,7 @@ class SystemPart(Model): Returns: SQL Query """ - return self._partition_operation_sql('FREEZE', settings=settings) + return self._partition_operation_sql("FREEZE", settings=settings) def fetch(self, zookeeper_path, settings=None): """ @@ -124,7 +137,7 @@ class SystemPart(Model): Returns: SQL Query """ - return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path) + return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path) @classmethod def get(cls, database, conditions=""): @@ -140,9 +153,12 @@ class SystemPart(Model): assert isinstance(conditions, str), "conditions must be a string" if conditions: conditions += " AND" - field_names = ','.join(cls.fields()) - return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" % - (field_names, cls.table_name(), conditions, database.db_name), model_class=cls) + field_names = ",".join(cls.fields()) + return database.select( + "SELECT %s FROM `system`.%s WHERE %s database='%s'" + % (field_names, cls.table_name(), conditions, database.db_name), + model_class=cls, + ) @classmethod def get_active(cls, database, conditions=""): @@ -155,8 +171,8 @@ class SystemPart(Model): Returns: A list of SystemPart objects """ if conditions: - conditions += ' AND ' - conditions += 'active' + conditions += " AND " + conditions += "active" return SystemPart.get(database, conditions=conditions) diff --git a/src/clickhouse_orm/utils.py b/src/clickhouse_orm/utils.py index 8f97eb8..441ee6b 100644 --- a/src/clickhouse_orm/utils.py +++ b/src/clickhouse_orm/utils.py @@ -10,10 +10,10 @@ SPECIAL_CHARS = { "\t": "\\t", "\0": "\\0", "\\": "\\\\", - "'": "\\'" + "'": "\\'", } -SPECIAL_CHARS_REGEX = re.compile("[" + 
''.join(SPECIAL_CHARS.values()) + "]") +SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]") POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)") RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]") @@ -36,11 +36,11 @@ def escape(value, quote=True): def unescape(value): - return codecs.escape_decode(value)[0].decode('utf-8') + return codecs.escape_decode(value)[0].decode("utf-8") def string_or_func(obj): - return obj.to_sql() if hasattr(obj, 'to_sql') else obj + return obj.to_sql() if hasattr(obj, "to_sql") else obj def arg_to_sql(arg): @@ -50,6 +50,7 @@ None, numbers, timezones, arrays/iterables. """ from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet + if isinstance(arg, F): return arg.to_sql() if isinstance(arg, Field): @@ -67,22 +68,22 @@ if isinstance(arg, tzinfo): return StringField().to_db_string(arg.tzname(None)) if arg is None: - return 'NULL' + return "NULL" if isinstance(arg, QuerySet): return "(%s)" % arg if isinstance(arg, tuple): - return '(' + comma_join(arg_to_sql(x) for x in arg) + ')' + return "(" + comma_join(arg_to_sql(x) for x in arg) + ")" if is_iterable(arg): - return '[' + comma_join(arg_to_sql(x) for x in arg) + ']' + return "[" + comma_join(arg_to_sql(x) for x in arg) + "]" return str(arg) def parse_tsv(line): if isinstance(line, bytes): line = line.decode() - if line and line[-1] == '\n': + if line and line[-1] == "\n": line = line[:-1] - return [unescape(value) for value in line.split(str('\t'))] + return [unescape(value) for value in line.split(str("\t"))] def parse_array(array_string): @@ -92,17 +93,17 @@ "(1,2,3)" ==> [1, 2, 3] """ # Sanity check - if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])': + if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])": raise ValueError('Invalid array string: "%s"' % array_string) # Drop opening brace array_string = array_string[1:] # Go over the string, lopping off each value at the beginning until nothing is left values = [] while True: - if array_string in '])': + if array_string in "])": # End of array return values - elif array_string[0] in ', ': + elif array_string[0] in ", ": # In between values array_string = array_string[1:] elif array_string[0] == "'": @@ -110,13 +111,13 @@ match = re.search(r"[^\\]'", array_string) if match is None: raise ValueError('Missing closing quote: "%s"' % array_string) - values.append(array_string[1: match.start() + 1]) - array_string = array_string[match.end():] + values.append(array_string[1 : match.start() + 1]) + array_string = array_string[match.end() :] else: # Start of non-quoted value, find its end match = re.search(r",|\]|\)", array_string) - values.append(array_string[0: match.start()]) - array_string = array_string[match.end() - 1:] + values.append(array_string[0 : match.start()]) + array_string = array_string[match.end() - 1 :] def import_submodules(package_name): @@ -124,9 +125,10 @@ Import all submodules of a module. """ import importlib, pkgutil + package = importlib.import_module(package_name) return { - name: importlib.import_module(package_name + '.' + name) + name: importlib.import_module(package_name + "." 
+ name) for _, name, _ in pkgutil.iter_modules(package.__path__) } @@ -136,9 +138,9 @@ def comma_join(items, stringify=False): Joins an iterable of strings with commas. """ if stringify: - return ', '.join(str(item) for item in items) + return ", ".join(str(item) for item in items) else: - return ', '.join(items) + return ", ".join(items) def is_iterable(obj): @@ -154,6 +156,7 @@ def is_iterable(obj): def get_subclass_names(locals, base_class): from inspect import isclass + return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)] @@ -164,7 +167,7 @@ class NoValue: """ def __repr__(self): - return 'NO_VALUE' + return "NO_VALUE" NO_VALUE = NoValue()
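Finally, a few worked examples of the utils helpers reformatted above; each result follows directly from the definitions shown in this hunk.

from clickhouse_orm.utils import arg_to_sql, comma_join, parse_tsv

comma_join(["a", "b", "c"])            # -> 'a, b, c'
comma_join([1, 2, 3], stringify=True)  # -> '1, 2, 3'
arg_to_sql(None)                       # -> 'NULL'
arg_to_sql((1, 2))                     # -> '(1, 2)'  (tuples become parenthesized lists)
arg_to_sql([1, 2])                     # -> '[1, 2]'  (other iterables become array literals)
parse_tsv(b"a\tb\n")                   # -> ['a', 'b']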