mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-01 10:49:53 +03:00
Migrate code style to Black
This commit is contained in:
parent
df2d778919
commit
d22683f28c
526
docs/ref.md
526
docs/ref.md
|
@ -1,526 +0,0 @@
|
|||
Class Reference
|
||||
===============
|
||||
|
||||
infi.clickhouse_orm.database
|
||||
----------------------------
|
||||
|
||||
### Database
|
||||
|
||||
#### Database(db_name, db_url="http://localhost:8123/", username=None, password=None, readonly=False)
|
||||
|
||||
Initializes a database instance. Unless it's readonly, the database will be
|
||||
created on the ClickHouse server if it does not already exist.
|
||||
|
||||
- `db_name`: name of the database to connect to.
|
||||
- `db_url`: URL of the ClickHouse server.
|
||||
- `username`: optional connection credentials.
|
||||
- `password`: optional connection credentials.
|
||||
- `readonly`: use a read-only connection.
|
||||
|
||||
|
||||
#### count(model_class, conditions=None)
|
||||
|
||||
Counts the number of records in the model's table.
|
||||
|
||||
- `model_class`: the model to count.
|
||||
- `conditions`: optional SQL conditions (contents of the WHERE clause).
|
||||
|
||||
|
||||
#### create_database()
|
||||
|
||||
Creates the database on the ClickHouse server if it does not already exist.
|
||||
|
||||
|
||||
#### create_table(model_class)
|
||||
|
||||
Creates a table for the given model class, if it does not exist already.
|
||||
|
||||
|
||||
#### drop_database()
|
||||
|
||||
Deletes the database on the ClickHouse server.
|
||||
|
||||
|
||||
#### drop_table(model_class)
|
||||
|
||||
Drops the database table of the given model class, if it exists.
|
||||
|
||||
|
||||
#### insert(model_instances, batch_size=1000)
|
||||
|
||||
Insert records into the database.
|
||||
|
||||
- `model_instances`: any iterable containing instances of a single model class.
|
||||
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
|
||||
|
||||
|
||||
#### migrate(migrations_package_name, up_to=9999)
|
||||
|
||||
Executes schema migrations.
|
||||
|
||||
- `migrations_package_name` - fully qualified name of the Python package
|
||||
containing the migrations.
|
||||
- `up_to` - number of the last migration to apply.
|
||||
|
||||
|
||||
#### paginate(model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None)
|
||||
|
||||
Selects records and returns a single page of model instances.
|
||||
|
||||
- `model_class`: the model class matching the query's table,
|
||||
or `None` for getting back instances of an ad-hoc model.
|
||||
- `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
|
||||
- `page_num`: the page number (1-based), or -1 to get the last page.
|
||||
- `page_size`: number of records to return per page.
|
||||
- `conditions`: optional SQL conditions (contents of the WHERE clause).
|
||||
- `settings`: query settings to send as HTTP GET parameters
|
||||
|
||||
The result is a namedtuple containing `objects` (list), `number_of_objects`,
|
||||
`pages_total`, `number` (of the current page), and `page_size`.
|
||||
|
||||
|
||||
#### raw(query, settings=None, stream=False)
|
||||
|
||||
Performs a query and returns its output as text.
|
||||
|
||||
- `query`: the SQL query to execute.
|
||||
- `settings`: query settings to send as HTTP GET parameters
|
||||
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
|
||||
|
||||
|
||||
#### select(query, model_class=None, settings=None)
|
||||
|
||||
Performs a query and returns a generator of model instances.
|
||||
|
||||
- `query`: the SQL query to execute.
|
||||
- `model_class`: the model class matching the query's table,
|
||||
or `None` for getting back instances of an ad-hoc model.
|
||||
- `settings`: query settings to send as HTTP GET parameters
|
||||
|
||||
|
||||
### DatabaseException
|
||||
|
||||
Extends Exception
|
||||
|
||||
Raised when a database operation fails.
|
||||
|
||||
infi.clickhouse_orm.models
|
||||
--------------------------
|
||||
|
||||
### Model
|
||||
|
||||
A base class for ORM models.
|
||||
|
||||
#### Model(**kwargs)
|
||||
|
||||
Creates a model instance, using keyword arguments as field values.
|
||||
Since values are immediately converted to their Pythonic type,
|
||||
invalid values will cause a `ValueError` to be raised.
|
||||
Unrecognized field names will cause an `AttributeError`.
|
||||
|
||||
|
||||
#### Model.create_table_sql(db)
|
||||
|
||||
Returns the SQL command for creating a table for this model.
|
||||
|
||||
|
||||
#### Model.drop_table_sql(db)
|
||||
|
||||
Returns the SQL command for deleting this model's table.
|
||||
|
||||
|
||||
#### Model.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
|
||||
|
||||
Create a model instance from a tab-separated line. The line may or may not include a newline.
|
||||
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
|
||||
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
|
||||
|
||||
- `line`: the TSV-formatted data.
|
||||
- `field_names`: names of the model fields in the data.
|
||||
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
|
||||
- `database`: if given, sets the database that this instance belongs to.
|
||||
|
||||
|
||||
#### get_database()
|
||||
|
||||
Gets the `Database` that this model instance belongs to.
|
||||
Returns `None` unless the instance was read from the database or written to it.
|
||||
|
||||
|
||||
#### get_field(name)
|
||||
|
||||
Gets a `Field` instance given its name, or `None` if not found.
|
||||
|
||||
|
||||
#### Model.objects_in(database)
|
||||
|
||||
Returns a `QuerySet` for selecting instances of this model class.
|
||||
|
||||
|
||||
#### set_database(db)
|
||||
|
||||
Sets the `Database` that this model instance belongs to.
|
||||
This is done automatically when the instance is read from the database or written to it.
|
||||
|
||||
|
||||
#### Model.table_name()
|
||||
|
||||
Returns the model's database table name. By default this is the
|
||||
class name converted to lowercase. Override this if you want to use
|
||||
a different table name.
|
||||
|
||||
|
||||
#### to_dict(include_readonly=True, field_names=None)
|
||||
|
||||
Returns the instance's column values as a dict.
|
||||
|
||||
- `include_readonly`: if false, returns only fields that can be inserted into database.
|
||||
- `field_names`: an iterable of field names to return (optional)
|
||||
|
||||
|
||||
#### to_tsv(include_readonly=True)
|
||||
|
||||
Returns the instance's column values as a tab-separated line. A newline is not included.
|
||||
|
||||
- `include_readonly`: if false, returns only fields that can be inserted into database.
|
||||
|
||||
|
||||
### BufferModel
|
||||
|
||||
Extends Model
|
||||
|
||||
#### BufferModel(**kwargs)
|
||||
|
||||
Creates a model instance, using keyword arguments as field values.
|
||||
Since values are immediately converted to their Pythonic type,
|
||||
invalid values will cause a `ValueError` to be raised.
|
||||
Unrecognized field names will cause an `AttributeError`.
|
||||
|
||||
|
||||
#### BufferModel.create_table_sql(db)
|
||||
|
||||
Returns the SQL command for creating a table for this model.
|
||||
|
||||
|
||||
#### BufferModel.drop_table_sql(db)
|
||||
|
||||
Returns the SQL command for deleting this model's table.
|
||||
|
||||
|
||||
#### BufferModel.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
|
||||
|
||||
Create a model instance from a tab-separated line. The line may or may not include a newline.
|
||||
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
|
||||
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
|
||||
|
||||
- `line`: the TSV-formatted data.
|
||||
- `field_names`: names of the model fields in the data.
|
||||
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
|
||||
- `database`: if given, sets the database that this instance belongs to.
|
||||
|
||||
|
||||
#### get_database()
|
||||
|
||||
Gets the `Database` that this model instance belongs to.
|
||||
Returns `None` unless the instance was read from the database or written to it.
|
||||
|
||||
|
||||
#### get_field(name)
|
||||
|
||||
Gets a `Field` instance given its name, or `None` if not found.
|
||||
|
||||
|
||||
#### BufferModel.objects_in(database)
|
||||
|
||||
Returns a `QuerySet` for selecting instances of this model class.
|
||||
|
||||
|
||||
#### set_database(db)
|
||||
|
||||
Sets the `Database` that this model instance belongs to.
|
||||
This is done automatically when the instance is read from the database or written to it.
|
||||
|
||||
|
||||
#### BufferModel.table_name()
|
||||
|
||||
Returns the model's database table name. By default this is the
|
||||
class name converted to lowercase. Override this if you want to use
|
||||
a different table name.
|
||||
|
||||
|
||||
#### to_dict(include_readonly=True, field_names=None)
|
||||
|
||||
Returns the instance's column values as a dict.
|
||||
|
||||
- `include_readonly`: if false, returns only fields that can be inserted into database.
|
||||
- `field_names`: an iterable of field names to return (optional)
|
||||
|
||||
|
||||
#### to_tsv(include_readonly=True)
|
||||
|
||||
Returns the instance's column values as a tab-separated line. A newline is not included.
|
||||
|
||||
- `include_readonly`: if false, returns only fields that can be inserted into database.
|
||||
|
||||
|
||||
infi.clickhouse_orm.fields
|
||||
--------------------------
|
||||
|
||||
### Field
|
||||
|
||||
Abstract base class for all field types.
|
||||
|
||||
#### Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### StringField
|
||||
|
||||
Extends Field
|
||||
|
||||
#### StringField(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### DateField
|
||||
|
||||
Extends Field
|
||||
|
||||
#### DateField(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### DateTimeField
|
||||
|
||||
Extends Field
|
||||
|
||||
#### DateTimeField(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### BaseIntField
|
||||
|
||||
Extends Field
|
||||
|
||||
Abstract base class for all integer-type fields.
|
||||
|
||||
#### BaseIntField(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### BaseFloatField
|
||||
|
||||
Extends Field
|
||||
|
||||
Abstract base class for all float-type fields.
|
||||
|
||||
#### BaseFloatField(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### BaseEnumField
|
||||
|
||||
Extends Field
|
||||
|
||||
Abstract base class for all enum-type fields.
|
||||
|
||||
#### BaseEnumField(enum_cls, default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### ArrayField
|
||||
|
||||
Extends Field
|
||||
|
||||
#### ArrayField(inner_field, default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### FixedStringField
|
||||
|
||||
Extends StringField
|
||||
|
||||
#### FixedStringField(length, default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### UInt8Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### UInt8Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### UInt16Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### UInt16Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### UInt32Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### UInt32Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### UInt64Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### UInt64Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Int8Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### Int8Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Int16Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### Int16Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Int32Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### Int32Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Int64Field
|
||||
|
||||
Extends BaseIntField
|
||||
|
||||
#### Int64Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Float32Field
|
||||
|
||||
Extends BaseFloatField
|
||||
|
||||
#### Float32Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Float64Field
|
||||
|
||||
Extends BaseFloatField
|
||||
|
||||
#### Float64Field(default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Enum8Field
|
||||
|
||||
Extends BaseEnumField
|
||||
|
||||
#### Enum8Field(enum_cls, default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
### Enum16Field
|
||||
|
||||
Extends BaseEnumField
|
||||
|
||||
#### Enum16Field(enum_cls, default=None, alias=None, materialized=None)
|
||||
|
||||
|
||||
infi.clickhouse_orm.engines
|
||||
---------------------------
|
||||
|
||||
### Engine
|
||||
|
||||
### TinyLog
|
||||
|
||||
Extends Engine
|
||||
|
||||
### Log
|
||||
|
||||
Extends Engine
|
||||
|
||||
### Memory
|
||||
|
||||
Extends Engine
|
||||
|
||||
### MergeTree
|
||||
|
||||
Extends Engine
|
||||
|
||||
#### MergeTree(date_col, key_cols, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
|
||||
|
||||
|
||||
### Buffer
|
||||
|
||||
Extends Engine
|
||||
|
||||
Here we define Buffer engine
|
||||
Read more here https://clickhouse.tech/reference_en.html#Buffer
|
||||
|
||||
#### Buffer(main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000, min_bytes=10000000, max_bytes=100000000)
|
||||
|
||||
|
||||
### CollapsingMergeTree
|
||||
|
||||
Extends MergeTree
|
||||
|
||||
#### CollapsingMergeTree(date_col, key_cols, sign_col, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
|
||||
|
||||
|
||||
### SummingMergeTree
|
||||
|
||||
Extends MergeTree
|
||||
|
||||
#### SummingMergeTree(date_col, key_cols, summing_cols=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
|
||||
|
||||
|
||||
### ReplacingMergeTree
|
||||
|
||||
Extends MergeTree
|
||||
|
||||
#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
|
||||
|
||||
|
||||
infi.clickhouse_orm.query
|
||||
-------------------------
|
||||
|
||||
### QuerySet
|
||||
|
||||
#### QuerySet(model_cls, database)
|
||||
|
||||
|
||||
#### conditions_as_sql(prewhere=True)
|
||||
|
||||
Return the contents of the queryset's WHERE or `PREWHERE` clause.
|
||||
|
||||
|
||||
#### count()
|
||||
|
||||
Returns the number of matching model instances.
|
||||
|
||||
|
||||
#### exclude(**kwargs)
|
||||
|
||||
Returns a new QuerySet instance that excludes all rows matching the conditions.
|
||||
|
||||
|
||||
#### filter(**kwargs)
|
||||
|
||||
Returns a new QuerySet instance that includes only rows matching the conditions.
|
||||
|
||||
|
||||
#### only(*field_names)
|
||||
|
||||
Limit the query to return only the specified field names.
|
||||
Useful when there are large fields that are not needed,
|
||||
or for creating a subquery to use with an IN operator.
|
||||
|
||||
|
||||
#### order_by(*field_names)
|
||||
|
||||
Returns a new QuerySet instance with the ordering changed.
|
||||
|
||||
|
||||
#### order_by_as_sql()
|
||||
|
||||
Return the contents of the queryset's ORDER BY clause.
|
||||
|
||||
|
||||
#### query()
|
||||
|
||||
Return the the queryset as SQL.
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ from clickhouse_orm.database import Database, ServerError, DatabaseException, lo
|
|||
|
||||
# pylint: disable=C0116
|
||||
|
||||
|
||||
class AioDatabase(Database):
|
||||
_client_class = httpx.AsyncClient
|
||||
|
||||
|
@ -25,7 +26,7 @@ class AioDatabase(Database):
|
|||
if self._readonly:
|
||||
if not self.db_exists:
|
||||
raise DatabaseException(
|
||||
'Database does not exist, and cannot be created under readonly connection'
|
||||
"Database does not exist, and cannot be created under readonly connection"
|
||||
)
|
||||
self.connection_readonly = await self._is_connection_readonly()
|
||||
self.readonly = True
|
||||
|
@ -44,10 +45,7 @@ class AioDatabase(Database):
|
|||
await self.request_session.aclose()
|
||||
|
||||
async def _send(
|
||||
self,
|
||||
data: str | bytes | AsyncGenerator,
|
||||
settings: dict = None,
|
||||
stream: bool = False
|
||||
self, data: str | bytes | AsyncGenerator, settings: dict = None, stream: bool = False
|
||||
):
|
||||
r = await super()._send(data, settings, stream)
|
||||
if r.status_code != 200:
|
||||
|
@ -55,11 +53,7 @@ class AioDatabase(Database):
|
|||
raise ServerError(r.text)
|
||||
return r
|
||||
|
||||
async def count(
|
||||
self,
|
||||
model_class: type[MODEL],
|
||||
conditions=None
|
||||
) -> int:
|
||||
async def count(self, model_class: type[MODEL], conditions=None) -> int:
|
||||
"""
|
||||
Counts the number of records in the model's table.
|
||||
|
||||
|
@ -70,14 +64,14 @@ class AioDatabase(Database):
|
|||
|
||||
if not self._init:
|
||||
raise DatabaseException(
|
||||
'The AioDatabase object must execute the init method before it can be used'
|
||||
"The AioDatabase object must execute the init method before it can be used"
|
||||
)
|
||||
|
||||
query = 'SELECT count() FROM $table'
|
||||
query = "SELECT count() FROM $table"
|
||||
if conditions:
|
||||
if isinstance(conditions, Q):
|
||||
conditions = conditions.to_sql(model_class)
|
||||
query += ' WHERE ' + str(conditions)
|
||||
query += " WHERE " + str(conditions)
|
||||
query = self._substitute(query, model_class)
|
||||
r = await self._send(query)
|
||||
return int(r.text) if r.text else 0
|
||||
|
@ -86,14 +80,14 @@ class AioDatabase(Database):
|
|||
"""
|
||||
Creates the database on the ClickHouse server if it does not already exist.
|
||||
"""
|
||||
await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
|
||||
await self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
|
||||
self.db_exists = True
|
||||
|
||||
async def drop_database(self):
|
||||
"""
|
||||
Deletes the database on the ClickHouse server.
|
||||
"""
|
||||
await self._send('DROP DATABASE `%s`' % self.db_name)
|
||||
await self._send("DROP DATABASE `%s`" % self.db_name)
|
||||
self.db_exists = False
|
||||
|
||||
async def create_table(self, model_class: type[MODEL]) -> None:
|
||||
|
@ -102,7 +96,7 @@ class AioDatabase(Database):
|
|||
"""
|
||||
if not self._init:
|
||||
raise DatabaseException(
|
||||
'The AioDatabase object must execute the init method before it can be used'
|
||||
"The AioDatabase object must execute the init method before it can be used"
|
||||
)
|
||||
if model_class.is_system_model():
|
||||
raise DatabaseException("You can't create system table")
|
||||
|
@ -110,7 +104,7 @@ class AioDatabase(Database):
|
|||
raise DatabaseException(
|
||||
"Creating a temporary table must be within the lifetime of a session "
|
||||
)
|
||||
if getattr(model_class, 'engine') is None:
|
||||
if getattr(model_class, "engine") is None:
|
||||
raise DatabaseException(f"%s class must define an engine" % model_class.__name__)
|
||||
await self._send(model_class.create_table_sql(self))
|
||||
|
||||
|
@ -121,7 +115,7 @@ class AioDatabase(Database):
|
|||
"""
|
||||
if not self._init:
|
||||
raise DatabaseException(
|
||||
'The AioDatabase object must execute the init method before it can be used'
|
||||
"The AioDatabase object must execute the init method before it can be used"
|
||||
)
|
||||
|
||||
await self._send(model_class.create_temporary_table_sql(self, table_name))
|
||||
|
@ -132,7 +126,7 @@ class AioDatabase(Database):
|
|||
"""
|
||||
if not self._init:
|
||||
raise DatabaseException(
|
||||
'The AioDatabase object must execute the init method before it can be used'
|
||||
"The AioDatabase object must execute the init method before it can be used"
|
||||
)
|
||||
|
||||
if model_class.is_system_model():
|
||||
|
@ -146,18 +140,14 @@ class AioDatabase(Database):
|
|||
"""
|
||||
if not self._init:
|
||||
raise DatabaseException(
|
||||
'The AioDatabase object must execute the init method before it can be used'
|
||||
"The AioDatabase object must execute the init method before it can be used"
|
||||
)
|
||||
|
||||
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
|
||||
r = await self._send(sql % (self.db_name, model_class.table_name()))
|
||||
return r.text.strip() == '1'
|
||||
return r.text.strip() == "1"
|
||||
|
||||
async def get_model_for_table(
|
||||
self,
|
||||
table_name: str,
|
||||
system_table: bool = False
|
||||
):
|
||||
async def get_model_for_table(self, table_name: str, system_table: bool = False):
|
||||
"""
|
||||
Generates a model class from an existing table in the database.
|
||||
This can be used for querying tables which don't have a corresponding model class,
|
||||
|
@ -166,7 +156,7 @@ class AioDatabase(Database):
|
|||
- `table_name`: the table to create a model for
|
||||
- `system_table`: whether the table is a system table, or belongs to the current database
|
||||
"""
|
||||
db_name = 'system' if system_table else self.db_name
|
||||
db_name = "system" if system_table else self.db_name
|
||||
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
|
||||
lines = await self._send(sql)
|
||||
fields = [parse_tsv(line)[:2] async for line in lines.aiter_lines()]
|
||||
|
@ -192,14 +182,13 @@ class AioDatabase(Database):
|
|||
if first_instance.is_read_only() or first_instance.is_system_model():
|
||||
raise DatabaseException("You can't insert into read only and system tables")
|
||||
|
||||
fields_list = ','.join(
|
||||
['`%s`' % name for name in first_instance.fields(writable=True)])
|
||||
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
|
||||
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
|
||||
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
|
||||
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
|
||||
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
|
||||
|
||||
async def gen():
|
||||
buf = BytesIO()
|
||||
buf.write(self._substitute(query, model_class).encode('utf-8'))
|
||||
buf.write(self._substitute(query, model_class).encode("utf-8"))
|
||||
first_instance.set_database(self)
|
||||
buf.write(first_instance.to_db_string())
|
||||
# Collect lines in batches of batch_size
|
||||
|
@ -217,13 +206,11 @@ class AioDatabase(Database):
|
|||
# Return any remaining lines in partial batch
|
||||
if lines:
|
||||
yield buf.getvalue()
|
||||
|
||||
await self._send(gen())
|
||||
|
||||
async def select(
|
||||
self,
|
||||
query: str,
|
||||
model_class: Optional[type[MODEL]] = None,
|
||||
settings: Optional[dict] = None
|
||||
self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
|
||||
) -> AsyncGenerator[MODEL, None]:
|
||||
"""
|
||||
Performs a query and returns a generator of model instances.
|
||||
|
@ -233,7 +220,7 @@ class AioDatabase(Database):
|
|||
or `None` for getting back instances of an ad-hoc model.
|
||||
- `settings`: query settings to send as HTTP GET parameters
|
||||
"""
|
||||
query += ' FORMAT TabSeparatedWithNamesAndTypes'
|
||||
query += " FORMAT TabSeparatedWithNamesAndTypes"
|
||||
query = self._substitute(query, model_class)
|
||||
r = await self._send(query, settings, True)
|
||||
try:
|
||||
|
@ -245,7 +232,8 @@ class AioDatabase(Database):
|
|||
elif not field_types:
|
||||
field_types = parse_tsv(line)
|
||||
model_class = model_class or ModelBase.create_ad_hoc_model(
|
||||
zip(field_names, field_types))
|
||||
zip(field_names, field_types)
|
||||
)
|
||||
elif line.strip():
|
||||
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
|
||||
except StopIteration:
|
||||
|
@ -271,7 +259,7 @@ class AioDatabase(Database):
|
|||
page_num: int = 1,
|
||||
page_size: int = 100,
|
||||
conditions=None,
|
||||
settings: Optional[dict] = None
|
||||
settings: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Selects records and returns a single page of model instances.
|
||||
|
@ -294,22 +282,22 @@ class AioDatabase(Database):
|
|||
if page_num == -1:
|
||||
page_num = max(pages_total, 1)
|
||||
elif page_num < 1:
|
||||
raise ValueError('Invalid page number: %d' % page_num)
|
||||
raise ValueError("Invalid page number: %d" % page_num)
|
||||
offset = (page_num - 1) * page_size
|
||||
query = 'SELECT * FROM $table'
|
||||
query = "SELECT * FROM $table"
|
||||
if conditions:
|
||||
if isinstance(conditions, Q):
|
||||
conditions = conditions.to_sql(model_class)
|
||||
query += ' WHERE ' + str(conditions)
|
||||
query += ' ORDER BY %s' % order_by
|
||||
query += ' LIMIT %d, %d' % (offset, page_size)
|
||||
query += " WHERE " + str(conditions)
|
||||
query += " ORDER BY %s" % order_by
|
||||
query += " LIMIT %d, %d" % (offset, page_size)
|
||||
query = self._substitute(query, model_class)
|
||||
return Page(
|
||||
objects=[r async for r in self.select(query, model_class, settings)] if count else [],
|
||||
number_of_objects=count,
|
||||
pages_total=pages_total,
|
||||
number=page_num,
|
||||
page_size=page_size
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
async def migrate(self, migrations_package_name, up_to=9999):
|
||||
|
@ -322,19 +310,23 @@ class AioDatabase(Database):
|
|||
"""
|
||||
from ..migrations import MigrationHistory
|
||||
|
||||
logger = logging.getLogger('migrations')
|
||||
logger = logging.getLogger("migrations")
|
||||
applied_migrations = await self._get_applied_migrations(migrations_package_name)
|
||||
modules = import_submodules(migrations_package_name)
|
||||
unapplied_migrations = set(modules.keys()) - applied_migrations
|
||||
for name in sorted(unapplied_migrations):
|
||||
logger.info('Applying migration %s...', name)
|
||||
logger.info("Applying migration %s...", name)
|
||||
for operation in modules[name].operations:
|
||||
operation.apply(self)
|
||||
await self.insert([MigrationHistory(
|
||||
package_name=migrations_package_name,
|
||||
module_name=name,
|
||||
applied=datetime.date.today()
|
||||
)])
|
||||
await self.insert(
|
||||
[
|
||||
MigrationHistory(
|
||||
package_name=migrations_package_name,
|
||||
module_name=name,
|
||||
applied=datetime.date.today(),
|
||||
)
|
||||
]
|
||||
)
|
||||
if int(name[:4]) >= up_to:
|
||||
break
|
||||
|
||||
|
@ -342,28 +334,28 @@ class AioDatabase(Database):
|
|||
r = await self._send(
|
||||
"SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
|
||||
)
|
||||
return r.text.strip() == '1'
|
||||
return r.text.strip() == "1"
|
||||
|
||||
async def _is_connection_readonly(self):
|
||||
r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
|
||||
return r.text.strip() != '0'
|
||||
return r.text.strip() != "0"
|
||||
|
||||
async def _get_server_timezone(self):
|
||||
try:
|
||||
r = await self._send('SELECT timezone()')
|
||||
r = await self._send("SELECT timezone()")
|
||||
return pytz.timezone(r.text.strip())
|
||||
except ServerError as err:
|
||||
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
|
||||
logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
|
||||
return pytz.utc
|
||||
|
||||
async def _get_server_version(self, as_tuple=True):
|
||||
try:
|
||||
r = await self._send('SELECT version();')
|
||||
r = await self._send("SELECT version();")
|
||||
ver = r.text
|
||||
except ServerError as err:
|
||||
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
|
||||
ver = '1.1.0'
|
||||
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
|
||||
logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
|
||||
ver = "1.1.0"
|
||||
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
|
||||
|
||||
async def _get_applied_migrations(self, migrations_package_name):
|
||||
from ..migrations import MigrationHistory
|
||||
|
|
|
@ -11,10 +11,10 @@ class Point:
|
|||
self.y = float(y)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Point x={self.x} y={self.y}>'
|
||||
return f"<Point x={self.x} y={self.y}>"
|
||||
|
||||
def to_db_string(self):
|
||||
return f'({self.x},{self.y})'
|
||||
return f"({self.x},{self.y})"
|
||||
|
||||
|
||||
class Ring:
|
||||
|
@ -29,16 +29,16 @@ class Ring:
|
|||
return len(self.array)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Ring {self.to_db_string()}>'
|
||||
return f"<Ring {self.to_db_string()}>"
|
||||
|
||||
def to_db_string(self):
|
||||
return f'[{",".join(pt.to_db_string() for pt in self.array)}]'
|
||||
|
||||
|
||||
def parse_point(array_string: str) -> Point:
|
||||
if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')':
|
||||
if len(array_string) < 2 or array_string[0] != "(" or array_string[-1] != ")":
|
||||
raise ValueError('Invalid point string: "%s"' % array_string)
|
||||
x, y = array_string.strip('()').split(',')
|
||||
x, y = array_string.strip("()").split(",")
|
||||
return Point(x, y)
|
||||
|
||||
|
||||
|
@ -47,14 +47,14 @@ def parse_ring(array_string: str) -> Ring:
|
|||
raise ValueError('Invalid ring string: "%s"' % array_string)
|
||||
ring = []
|
||||
for point in POINT_REGEX.finditer(array_string):
|
||||
x, y = point.group('x'), point.group('y')
|
||||
x, y = point.group("x"), point.group("y")
|
||||
ring.append(Point(x, y))
|
||||
return Ring(ring)
|
||||
|
||||
|
||||
class PointField(Field):
|
||||
class_default = Point(0, 0)
|
||||
db_type = 'Point'
|
||||
db_type = "Point"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -63,7 +63,7 @@ class PointField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
super().__init__(default, alias, materialized, readonly, codec, db_column)
|
||||
self.inner_field = Float64Field()
|
||||
|
@ -73,10 +73,10 @@ class PointField(Field):
|
|||
value = parse_point(value)
|
||||
elif isinstance(value, (tuple, list)):
|
||||
if len(value) != 2:
|
||||
raise ValueError('PointField takes 2 value, but %s were given' % len(value))
|
||||
raise ValueError("PointField takes 2 value, but %s were given" % len(value))
|
||||
value = Point(value[0], value[1])
|
||||
if not isinstance(value, Point):
|
||||
raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
|
||||
raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
|
||||
return value
|
||||
|
||||
def validate(self, value):
|
||||
|
@ -91,7 +91,7 @@ class PointField(Field):
|
|||
|
||||
class RingField(Field):
|
||||
class_default = [Point(0, 0)]
|
||||
db_type = 'Ring'
|
||||
db_type = "Ring"
|
||||
|
||||
def to_python(self, value, timezone_in_use):
|
||||
if isinstance(value, str):
|
||||
|
@ -100,11 +100,11 @@ class RingField(Field):
|
|||
ring = []
|
||||
for point in value:
|
||||
if len(point) != 2:
|
||||
raise ValueError('Point takes 2 value, but %s were given' % len(value))
|
||||
raise ValueError("Point takes 2 value, but %s were given" % len(value))
|
||||
ring.append(Point(point[0], point[1]))
|
||||
value = Ring(ring)
|
||||
if not isinstance(value, Ring):
|
||||
raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
|
||||
raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
|
||||
return value
|
||||
|
||||
def to_db_string(self, value, quote=True):
|
||||
|
|
|
@ -16,8 +16,8 @@ from .utils import parse_tsv, import_submodules
|
|||
from .session import ctx_session_id, ctx_session_timeout
|
||||
|
||||
|
||||
logger = logging.getLogger('clickhouse_orm')
|
||||
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
||||
logger = logging.getLogger("clickhouse_orm")
|
||||
Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")
|
||||
|
||||
|
||||
class DatabaseException(Exception):
|
||||
|
@ -30,6 +30,7 @@ class ServerError(DatabaseException):
|
|||
"""
|
||||
Raised when a server returns an error.
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
self.code = None
|
||||
processed = self.get_error_code_msg(message)
|
||||
|
@ -43,21 +44,30 @@ class ServerError(DatabaseException):
|
|||
|
||||
ERROR_PATTERNS = (
|
||||
# ClickHouse prior to v19.3.3
|
||||
re.compile(r'''
|
||||
re.compile(
|
||||
r"""
|
||||
Code:\ (?P<code>\d+),
|
||||
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
|
||||
\ e.what\(\)\ =\ (?P<type2>[^ \n]+)
|
||||
''', re.VERBOSE | re.DOTALL),
|
||||
""",
|
||||
re.VERBOSE | re.DOTALL,
|
||||
),
|
||||
# ClickHouse v19.3.3+
|
||||
re.compile(r'''
|
||||
re.compile(
|
||||
r"""
|
||||
Code:\ (?P<code>\d+),
|
||||
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
|
||||
''', re.VERBOSE | re.DOTALL),
|
||||
""",
|
||||
re.VERBOSE | re.DOTALL,
|
||||
),
|
||||
# ClickHouse v21+
|
||||
re.compile(r'''
|
||||
re.compile(
|
||||
r"""
|
||||
Code:\ (?P<code>\d+).
|
||||
\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
|
||||
''', re.VERBOSE | re.DOTALL),
|
||||
""",
|
||||
re.VERBOSE | re.DOTALL,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -72,7 +82,7 @@ class ServerError(DatabaseException):
|
|||
match = pattern.match(full_error_message)
|
||||
if match:
|
||||
# assert match.group('type1') == match.group('type2')
|
||||
return int(match.group('code')), match.group('msg').strip()
|
||||
return int(match.group("code")), match.group("msg").strip()
|
||||
|
||||
return 0, full_error_message
|
||||
|
||||
|
@ -86,11 +96,21 @@ class Database:
|
|||
Database instances connect to a specific ClickHouse database for running queries,
|
||||
inserting data and other operations.
|
||||
"""
|
||||
|
||||
_client_class = httpx.Client
|
||||
|
||||
def __init__(self, db_name, db_url='http://localhost:8123/',
|
||||
username=None, password=None, readonly=False, auto_create=True,
|
||||
timeout=60, verify_ssl_cert=True, log_statements=False):
|
||||
def __init__(
|
||||
self,
|
||||
db_name,
|
||||
db_url="http://localhost:8123/",
|
||||
username=None,
|
||||
password=None,
|
||||
readonly=False,
|
||||
auto_create=True,
|
||||
timeout=60,
|
||||
verify_ssl_cert=True,
|
||||
log_statements=False,
|
||||
):
|
||||
"""
|
||||
Initializes a database instance. Unless it's readonly, the database will be
|
||||
created on the ClickHouse server if it does not already exist.
|
||||
|
@ -114,7 +134,7 @@ class Database:
|
|||
self.timeout = timeout
|
||||
self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout)
|
||||
if username:
|
||||
self.request_session.auth = (username, password or '')
|
||||
self.request_session.auth = (username, password or "")
|
||||
self.log_statements = log_statements
|
||||
self.settings = {}
|
||||
self.db_exists = False # this is required before running _is_existing_database
|
||||
|
@ -134,7 +154,7 @@ class Database:
|
|||
if self._readonly:
|
||||
if not self.db_exists:
|
||||
raise DatabaseException(
|
||||
'Database does not exist, and cannot be created under readonly connection'
|
||||
"Database does not exist, and cannot be created under readonly connection"
|
||||
)
|
||||
self.connection_readonly = self._is_connection_readonly()
|
||||
self.readonly = True
|
||||
|
@ -155,14 +175,14 @@ class Database:
|
|||
"""
|
||||
Creates the database on the ClickHouse server if it does not already exist.
|
||||
"""
|
||||
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
|
||||
self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
|
||||
self.db_exists = True
|
||||
|
||||
def drop_database(self):
|
||||
"""
|
||||
Deletes the database on the ClickHouse server.
|
||||
"""
|
||||
self._send('DROP DATABASE `%s`' % self.db_name)
|
||||
self._send("DROP DATABASE `%s`" % self.db_name)
|
||||
self.db_exists = False
|
||||
|
||||
def create_table(self, model_class: type[MODEL]) -> None:
|
||||
|
@ -171,7 +191,7 @@ class Database:
|
|||
"""
|
||||
if model_class.is_system_model():
|
||||
raise DatabaseException("You can't create system table")
|
||||
if getattr(model_class, 'engine') is None:
|
||||
if getattr(model_class, "engine") is None:
|
||||
raise DatabaseException("%s class must define an engine" % model_class.__name__)
|
||||
self._send(model_class.create_table_sql(self))
|
||||
|
||||
|
@ -190,13 +210,9 @@ class Database:
|
|||
"""
|
||||
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
|
||||
r = self._send(sql % (self.db_name, model_class.table_name()))
|
||||
return r.text.strip() == '1'
|
||||
return r.text.strip() == "1"
|
||||
|
||||
def get_model_for_table(
|
||||
self,
|
||||
table_name: str,
|
||||
system_table: bool = False
|
||||
):
|
||||
def get_model_for_table(self, table_name: str, system_table: bool = False):
|
||||
"""
|
||||
Generates a model class from an existing table in the database.
|
||||
This can be used for querying tables which don't have a corresponding model class,
|
||||
|
@ -205,7 +221,7 @@ class Database:
|
|||
- `table_name`: the table to create a model for
|
||||
- `system_table`: whether the table is a system table, or belongs to the current database
|
||||
"""
|
||||
db_name = 'system' if system_table else self.db_name
|
||||
db_name = "system" if system_table else self.db_name
|
||||
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
|
||||
lines = self._send(sql).iter_lines()
|
||||
fields = [parse_tsv(line)[:2] for line in lines]
|
||||
|
@ -222,7 +238,7 @@ class Database:
|
|||
The name must be string, and the value is converted to string in case
|
||||
it isn't. To remove a setting, pass `None` as the value.
|
||||
"""
|
||||
assert isinstance(name, str), 'Setting name must be a string'
|
||||
assert isinstance(name, str), "Setting name must be a string"
|
||||
if value is None:
|
||||
self.settings.pop(name, None)
|
||||
else:
|
||||
|
@ -246,14 +262,13 @@ class Database:
|
|||
if first_instance.is_read_only() or first_instance.is_system_model():
|
||||
raise DatabaseException("You can't insert into read only and system tables")
|
||||
|
||||
fields_list = ','.join(
|
||||
['`%s`' % name for name in first_instance.fields(writable=True)])
|
||||
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
|
||||
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
|
||||
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
|
||||
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
|
||||
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
|
||||
|
||||
def gen():
|
||||
buf = BytesIO()
|
||||
buf.write(self._substitute(query, model_class).encode('utf-8'))
|
||||
buf.write(self._substitute(query, model_class).encode("utf-8"))
|
||||
first_instance.set_database(self)
|
||||
buf.write(first_instance.to_db_string())
|
||||
# Collect lines in batches of batch_size
|
||||
|
@ -271,12 +286,11 @@ class Database:
|
|||
# Return any remaining lines in partial batch
|
||||
if lines:
|
||||
yield buf.getvalue()
|
||||
|
||||
self._send(gen())
|
||||
|
||||
def count(
|
||||
self,
|
||||
model_class: Optional[type[MODEL]],
|
||||
conditions: Optional[Union[str, 'Q']] = None
|
||||
self, model_class: Optional[type[MODEL]], conditions: Optional[Union[str, "Q"]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Counts the number of records in the model's table.
|
||||
|
@ -286,20 +300,17 @@ class Database:
|
|||
"""
|
||||
from clickhouse_orm.query import Q
|
||||
|
||||
query = 'SELECT count() FROM $table'
|
||||
query = "SELECT count() FROM $table"
|
||||
if conditions:
|
||||
if isinstance(conditions, Q):
|
||||
conditions = conditions.to_sql(model_class)
|
||||
query += ' WHERE ' + str(conditions)
|
||||
query += " WHERE " + str(conditions)
|
||||
query = self._substitute(query, model_class)
|
||||
r = self._send(query)
|
||||
return int(r.text) if r.text else 0
|
||||
|
||||
def select(
|
||||
self,
|
||||
query: str,
|
||||
model_class: Optional[type[MODEL]] = None,
|
||||
settings: Optional[dict] = None
|
||||
self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
|
||||
) -> Generator[MODEL, None, None]:
|
||||
"""
|
||||
Performs a query and returns a generator of model instances.
|
||||
|
@ -309,7 +320,7 @@ class Database:
|
|||
or `None` for getting back instances of an ad-hoc model.
|
||||
- `settings`: query settings to send as HTTP GET parameters
|
||||
"""
|
||||
query += ' FORMAT TabSeparatedWithNamesAndTypes'
|
||||
query += " FORMAT TabSeparatedWithNamesAndTypes"
|
||||
query = self._substitute(query, model_class)
|
||||
r = self._send(query, settings, True)
|
||||
try:
|
||||
|
@ -345,7 +356,7 @@ class Database:
|
|||
page_num: int = 1,
|
||||
page_size: int = 100,
|
||||
conditions=None,
|
||||
settings: Optional[dict] = None
|
||||
settings: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Selects records and returns a single page of model instances.
|
||||
|
@ -362,27 +373,28 @@ class Database:
|
|||
`pages_total`, `number` (of the current page), and `page_size`.
|
||||
"""
|
||||
from clickhouse_orm.query import Q
|
||||
|
||||
count = self.count(model_class, conditions)
|
||||
pages_total = int(ceil(count / float(page_size)))
|
||||
if page_num == -1:
|
||||
page_num = max(pages_total, 1)
|
||||
elif page_num < 1:
|
||||
raise ValueError('Invalid page number: %d' % page_num)
|
||||
raise ValueError("Invalid page number: %d" % page_num)
|
||||
offset = (page_num - 1) * page_size
|
||||
query = 'SELECT * FROM $table'
|
||||
query = "SELECT * FROM $table"
|
||||
if conditions:
|
||||
if isinstance(conditions, Q):
|
||||
conditions = conditions.to_sql(model_class)
|
||||
query += ' WHERE ' + str(conditions)
|
||||
query += ' ORDER BY %s' % order_by
|
||||
query += ' LIMIT %d, %d' % (offset, page_size)
|
||||
query += " WHERE " + str(conditions)
|
||||
query += " ORDER BY %s" % order_by
|
||||
query += " LIMIT %d, %d" % (offset, page_size)
|
||||
query = self._substitute(query, model_class)
|
||||
return Page(
|
||||
objects=list(self.select(query, model_class, settings)) if count else [],
|
||||
number_of_objects=count,
|
||||
pages_total=pages_total,
|
||||
number=page_num,
|
||||
page_size=page_size
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def migrate(self, migrations_package_name, up_to=9999):
|
||||
|
@ -395,19 +407,23 @@ class Database:
|
|||
"""
|
||||
from .migrations import MigrationHistory # pylint: disable=C0415
|
||||
|
||||
logger = logging.getLogger('migrations')
|
||||
logger = logging.getLogger("migrations")
|
||||
applied_migrations = self._get_applied_migrations(migrations_package_name)
|
||||
modules = import_submodules(migrations_package_name)
|
||||
unapplied_migrations = set(modules.keys()) - applied_migrations
|
||||
for name in sorted(unapplied_migrations):
|
||||
logger.info('Applying migration %s...', name)
|
||||
logger.info("Applying migration %s...", name)
|
||||
for operation in modules[name].operations:
|
||||
operation.apply(self)
|
||||
self.insert([MigrationHistory(
|
||||
package_name=migrations_package_name,
|
||||
module_name=name,
|
||||
applied=datetime.date.today())
|
||||
])
|
||||
self.insert(
|
||||
[
|
||||
MigrationHistory(
|
||||
package_name=migrations_package_name,
|
||||
module_name=name,
|
||||
applied=datetime.date.today(),
|
||||
)
|
||||
]
|
||||
)
|
||||
if int(name[:4]) >= up_to:
|
||||
break
|
||||
|
||||
|
@ -432,19 +448,14 @@ class Database:
|
|||
query = self._substitute(query, MigrationHistory)
|
||||
return set(obj.module_name for obj in self.select(query))
|
||||
|
||||
def _send(
|
||||
self,
|
||||
data: str | bytes | Generator,
|
||||
settings: dict = None,
|
||||
stream: bool = False
|
||||
):
|
||||
def _send(self, data: str | bytes | Generator, settings: dict = None, stream: bool = False):
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
data = data.encode("utf-8")
|
||||
if self.log_statements:
|
||||
logger.info(data)
|
||||
params = self._build_params(settings)
|
||||
request = self.request_session.build_request(
|
||||
method='POST', url=self.db_url, content=data, params=params
|
||||
method="POST", url=self.db_url, content=data, params=params
|
||||
)
|
||||
r = self.request_session.send(request, stream=stream)
|
||||
if isinstance(r, httpx.Response) and r.status_code != 200:
|
||||
|
@ -457,52 +468,52 @@ class Database:
|
|||
params.update(self.settings)
|
||||
params.update(self._context_params)
|
||||
if self.db_exists:
|
||||
params['database'] = self.db_name
|
||||
params["database"] = self.db_name
|
||||
# Send the readonly flag, unless the connection is already readonly (to prevent db error)
|
||||
if self.readonly and not self.connection_readonly:
|
||||
params['readonly'] = '1'
|
||||
params["readonly"] = "1"
|
||||
return params
|
||||
|
||||
def _substitute(self, query, model_class=None):
|
||||
"""
|
||||
Replaces $db and $table placeholders in the query.
|
||||
"""
|
||||
if '$' in query:
|
||||
if "$" in query:
|
||||
mapping = dict(db="`%s`" % self.db_name)
|
||||
if model_class:
|
||||
if model_class.is_system_model():
|
||||
mapping['table'] = "`system`.`%s`" % model_class.table_name()
|
||||
mapping["table"] = "`system`.`%s`" % model_class.table_name()
|
||||
elif model_class.is_temporary_model():
|
||||
mapping['table'] = "`%s`" % model_class.table_name()
|
||||
mapping["table"] = "`%s`" % model_class.table_name()
|
||||
else:
|
||||
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
|
||||
mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
|
||||
query = Template(query).safe_substitute(mapping)
|
||||
return query
|
||||
|
||||
def _get_server_timezone(self):
|
||||
try:
|
||||
r = self._send('SELECT timezone()')
|
||||
r = self._send("SELECT timezone()")
|
||||
return pytz.timezone(r.text.strip())
|
||||
except ServerError as err:
|
||||
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
|
||||
logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
|
||||
return pytz.utc
|
||||
|
||||
def _get_server_version(self, as_tuple=True):
|
||||
try:
|
||||
r = self._send('SELECT version();')
|
||||
r = self._send("SELECT version();")
|
||||
ver = r.text
|
||||
except ServerError as err:
|
||||
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
|
||||
ver = '1.1.0'
|
||||
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
|
||||
logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
|
||||
ver = "1.1.0"
|
||||
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
|
||||
|
||||
def _is_existing_database(self):
|
||||
r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
|
||||
return r.text.strip() == '1'
|
||||
return r.text.strip() == "1"
|
||||
|
||||
def _is_connection_readonly(self):
|
||||
r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
|
||||
return r.text.strip() != '0'
|
||||
return r.text.strip() != "0"
|
||||
|
||||
|
||||
# Expose only relevant classes in import *
|
||||
|
|
|
@ -11,35 +11,30 @@ if TYPE_CHECKING:
|
|||
from clickhouse_orm.models import Model
|
||||
from clickhouse_orm.funcs import F
|
||||
|
||||
logger = logging.getLogger('clickhouse_orm')
|
||||
logger = logging.getLogger("clickhouse_orm")
|
||||
|
||||
|
||||
class Engine:
|
||||
|
||||
def create_table_sql(self, db: Database) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class TinyLog(Engine):
|
||||
|
||||
def create_table_sql(self, db):
|
||||
return 'TinyLog'
|
||||
return "TinyLog"
|
||||
|
||||
|
||||
class Log(Engine):
|
||||
|
||||
def create_table_sql(self, db):
|
||||
return 'Log'
|
||||
return "Log"
|
||||
|
||||
|
||||
class Memory(Engine):
|
||||
|
||||
def create_table_sql(self, db):
|
||||
return 'Memory'
|
||||
return "Memory"
|
||||
|
||||
|
||||
class MergeTree(Engine):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_col: Optional[str] = None,
|
||||
|
@ -49,22 +44,27 @@ class MergeTree(Engine):
|
|||
replica_table_path: Optional[str] = None,
|
||||
replica_name: Optional[str] = None,
|
||||
partition_key: Optional[Union[list, tuple]] = None,
|
||||
primary_key: Optional[Union[list, tuple]] = None
|
||||
primary_key: Optional[Union[list, tuple]] = None,
|
||||
):
|
||||
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
|
||||
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
|
||||
assert primary_key is None or type(primary_key) in (list, tuple), \
|
||||
'primary_key must be a list or tuple'
|
||||
assert partition_key is None or type(partition_key) in (list, tuple),\
|
||||
'partition_key must be tuple or list if present'
|
||||
assert (replica_table_path is None) == (replica_name is None), \
|
||||
'both replica_table_path and replica_name must be specified'
|
||||
assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
|
||||
assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
|
||||
assert primary_key is None or type(primary_key) in (
|
||||
list,
|
||||
tuple,
|
||||
), "primary_key must be a list or tuple"
|
||||
assert partition_key is None or type(partition_key) in (
|
||||
list,
|
||||
tuple,
|
||||
), "partition_key must be tuple or list if present"
|
||||
assert (replica_table_path is None) == (
|
||||
replica_name is None
|
||||
), "both replica_table_path and replica_name must be specified"
|
||||
|
||||
# These values conflict with each other (old and new syntax of table engines.
|
||||
# So let's control only one of them is given.
|
||||
assert date_col or partition_key, "You must set either date_col or partition_key"
|
||||
self.date_col = date_col
|
||||
self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,)
|
||||
self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
|
||||
self.primary_key = primary_key
|
||||
|
||||
self.order_by = order_by
|
||||
|
@ -76,28 +76,33 @@ class MergeTree(Engine):
|
|||
# I changed field name for new reality and syntax
|
||||
@property
|
||||
def key_cols(self):
|
||||
logger.warning('`key_cols` attribute is deprecated and may be removed in future. '
|
||||
'Use `order_by` attribute instead')
|
||||
logger.warning(
|
||||
"`key_cols` attribute is deprecated and may be removed in future. "
|
||||
"Use `order_by` attribute instead"
|
||||
)
|
||||
return self.order_by
|
||||
|
||||
@key_cols.setter
|
||||
def key_cols(self, value):
|
||||
logger.warning('`key_cols` attribute is deprecated and may be removed in future. '
|
||||
'Use `order_by` attribute instead')
|
||||
logger.warning(
|
||||
"`key_cols` attribute is deprecated and may be removed in future. "
|
||||
"Use `order_by` attribute instead"
|
||||
)
|
||||
self.order_by = value
|
||||
|
||||
def create_table_sql(self, db: Database) -> str:
|
||||
name = self.__class__.__name__
|
||||
if self.replica_name:
|
||||
name = 'Replicated' + name
|
||||
name = "Replicated" + name
|
||||
|
||||
# In ClickHouse 1.1.54310 custom partitioning key was introduced
|
||||
# https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
|
||||
# Let's check version and use new syntax if available
|
||||
if db.server_version >= (1, 1, 54310):
|
||||
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \
|
||||
% (comma_join(self.partition_key, stringify=True),
|
||||
comma_join(self.order_by, stringify=True))
|
||||
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
|
||||
comma_join(self.partition_key, stringify=True),
|
||||
comma_join(self.order_by, stringify=True),
|
||||
)
|
||||
|
||||
if self.primary_key:
|
||||
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
|
||||
|
@ -110,16 +115,17 @@ class MergeTree(Engine):
|
|||
elif not self.date_col:
|
||||
# Can't import it globally due to circular import
|
||||
from clickhouse_orm.database import DatabaseException
|
||||
|
||||
raise DatabaseException(
|
||||
"Custom partitioning is not supported before ClickHouse 1.1.54310. "
|
||||
"Please update your server or use date_col syntax."
|
||||
"https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
|
||||
)
|
||||
else:
|
||||
partition_sql = ''
|
||||
partition_sql = ""
|
||||
|
||||
params = self._build_sql_params(db)
|
||||
return '%s(%s) %s' % (name, comma_join(params), partition_sql)
|
||||
return "%s(%s) %s" % (name, comma_join(params), partition_sql)
|
||||
|
||||
def _build_sql_params(self, db: Database) -> list[str]:
|
||||
params = []
|
||||
|
@ -134,22 +140,34 @@ class MergeTree(Engine):
|
|||
params.append(self.date_col)
|
||||
if self.sampling_expr:
|
||||
params.append(self.sampling_expr)
|
||||
params.append('(%s)' % comma_join(self.order_by, stringify=True))
|
||||
params.append("(%s)" % comma_join(self.order_by, stringify=True))
|
||||
params.append(str(self.index_granularity))
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class CollapsingMergeTree(MergeTree):
|
||||
|
||||
def __init__(
|
||||
self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None,
|
||||
index_granularity=8192, replica_table_path=None, replica_name=None,
|
||||
partition_key=None, primary_key=None
|
||||
self,
|
||||
date_col=None,
|
||||
order_by=(),
|
||||
sign_col="sign",
|
||||
sampling_expr=None,
|
||||
index_granularity=8192,
|
||||
replica_table_path=None,
|
||||
replica_name=None,
|
||||
partition_key=None,
|
||||
primary_key=None,
|
||||
):
|
||||
super(CollapsingMergeTree, self).__init__(
|
||||
date_col, order_by, sampling_expr, index_granularity,
|
||||
replica_table_path, replica_name, partition_key, primary_key
|
||||
date_col,
|
||||
order_by,
|
||||
sampling_expr,
|
||||
index_granularity,
|
||||
replica_table_path,
|
||||
replica_name,
|
||||
partition_key,
|
||||
primary_key,
|
||||
)
|
||||
self.sign_col = sign_col
|
||||
|
||||
|
@ -160,37 +178,63 @@ class CollapsingMergeTree(MergeTree):
|
|||
|
||||
|
||||
class SummingMergeTree(MergeTree):
|
||||
|
||||
def __init__(
|
||||
self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None,
|
||||
index_granularity=8192, replica_table_path=None, replica_name=None,
|
||||
partition_key=None, primary_key=None
|
||||
self,
|
||||
date_col=None,
|
||||
order_by=(),
|
||||
summing_cols=None,
|
||||
sampling_expr=None,
|
||||
index_granularity=8192,
|
||||
replica_table_path=None,
|
||||
replica_name=None,
|
||||
partition_key=None,
|
||||
primary_key=None,
|
||||
):
|
||||
super(SummingMergeTree, self).__init__(
|
||||
date_col, order_by, sampling_expr, index_granularity,
|
||||
replica_table_path, replica_name, partition_key, primary_key
|
||||
date_col,
|
||||
order_by,
|
||||
sampling_expr,
|
||||
index_granularity,
|
||||
replica_table_path,
|
||||
replica_name,
|
||||
partition_key,
|
||||
primary_key,
|
||||
)
|
||||
assert type is None or type(summing_cols) in (list, tuple), \
|
||||
'summing_cols must be a list or tuple'
|
||||
assert type is None or type(summing_cols) in (
|
||||
list,
|
||||
tuple,
|
||||
), "summing_cols must be a list or tuple"
|
||||
self.summing_cols = summing_cols
|
||||
|
||||
def _build_sql_params(self, db: Database) -> list[str]:
|
||||
params = super(SummingMergeTree, self)._build_sql_params(db)
|
||||
if self.summing_cols:
|
||||
params.append('(%s)' % comma_join(self.summing_cols))
|
||||
params.append("(%s)" % comma_join(self.summing_cols))
|
||||
return params
|
||||
|
||||
|
||||
class ReplacingMergeTree(MergeTree):
|
||||
|
||||
def __init__(
|
||||
self, date_col=None, order_by=(), ver_col=None, sampling_expr=None,
|
||||
index_granularity=8192, replica_table_path=None, replica_name=None,
|
||||
partition_key=None, primary_key=None
|
||||
self,
|
||||
date_col=None,
|
||||
order_by=(),
|
||||
ver_col=None,
|
||||
sampling_expr=None,
|
||||
index_granularity=8192,
|
||||
replica_table_path=None,
|
||||
replica_name=None,
|
||||
partition_key=None,
|
||||
primary_key=None,
|
||||
):
|
||||
super(ReplacingMergeTree, self).__init__(
|
||||
date_col, order_by, sampling_expr, index_granularity,
|
||||
replica_table_path, replica_name, partition_key, primary_key
|
||||
date_col,
|
||||
order_by,
|
||||
sampling_expr,
|
||||
index_granularity,
|
||||
replica_table_path,
|
||||
replica_name,
|
||||
partition_key,
|
||||
primary_key,
|
||||
)
|
||||
self.ver_col = ver_col
|
||||
|
||||
|
@ -217,7 +261,7 @@ class Buffer(Engine):
|
|||
min_rows: int = 10000,
|
||||
max_rows: int = 1000000,
|
||||
min_bytes: int = 10000000,
|
||||
max_bytes: int = 100000000
|
||||
max_bytes: int = 100000000,
|
||||
):
|
||||
self.main_model = main_model
|
||||
self.num_layers = num_layers
|
||||
|
@ -231,11 +275,17 @@ class Buffer(Engine):
|
|||
def create_table_sql(self, db: Database) -> str:
|
||||
# Overriden create_table_sql example:
|
||||
# sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
|
||||
sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
|
||||
db.db_name, self.main_model.table_name(), self.num_layers,
|
||||
self.min_time, self.max_time, self.min_rows,
|
||||
self.max_rows, self.min_bytes, self.max_bytes
|
||||
)
|
||||
sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
|
||||
db.db_name,
|
||||
self.main_model.table_name(),
|
||||
self.num_layers,
|
||||
self.min_time,
|
||||
self.max_time,
|
||||
self.min_rows,
|
||||
self.max_rows,
|
||||
self.min_bytes,
|
||||
self.max_bytes,
|
||||
)
|
||||
return sql
|
||||
|
||||
|
||||
|
@ -265,6 +315,7 @@ class Distributed(Engine):
|
|||
See full documentation here
|
||||
https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
|
||||
"""
|
||||
|
||||
def __init__(self, cluster, table=None, sharding_key=None):
|
||||
"""
|
||||
- `cluster`: what cluster to access data from
|
||||
|
@ -292,12 +343,15 @@ class Distributed(Engine):
|
|||
def create_table_sql(self, db: Database) -> str:
|
||||
name = self.__class__.__name__
|
||||
params = self._build_sql_params(db)
|
||||
return '%s(%s)' % (name, ', '.join(params))
|
||||
return "%s(%s)" % (name, ", ".join(params))
|
||||
|
||||
def _build_sql_params(self, db: Database) -> list[str]:
|
||||
if self.table_name is None:
|
||||
raise ValueError("Cannot create {} engine: specify an underlying table".format(
|
||||
self.__class__.__name__))
|
||||
raise ValueError(
|
||||
"Cannot create {} engine: specify an underlying table".format(
|
||||
self.__class__.__name__
|
||||
)
|
||||
)
|
||||
|
||||
params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
|
||||
if self.sharding_key:
|
||||
|
|
|
@ -21,13 +21,14 @@ if TYPE_CHECKING:
|
|||
from clickhouse_orm.models import Model
|
||||
from clickhouse_orm.database import Database
|
||||
|
||||
logger = getLogger('clickhouse_orm')
|
||||
logger = getLogger("clickhouse_orm")
|
||||
|
||||
|
||||
class Field(FunctionOperatorsMixin):
|
||||
"""
|
||||
Abstract base class for all field types.
|
||||
"""
|
||||
|
||||
name: str = None # this is set by the parent model
|
||||
parent: type["Model"] = None # this is set by the parent model
|
||||
creation_counter: int = 0 # used for keeping the model fields ordered
|
||||
|
@ -41,21 +42,29 @@ class Field(FunctionOperatorsMixin):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
assert [default, alias, materialized].count(None) >= 2, \
|
||||
"Only one of default, alias and materialized parameters can be given"
|
||||
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
|
||||
"Alias parameter must be a string or function object, if given"
|
||||
assert (materialized is None or isinstance(materialized, F) or
|
||||
isinstance(materialized, str) and materialized != ""), \
|
||||
"Materialized parameter must be a string or function object, if given"
|
||||
assert readonly is None or type(
|
||||
readonly) is bool, "readonly parameter must be bool if given"
|
||||
assert codec is None or isinstance(codec, str) and codec != "", \
|
||||
"Codec field must be string, if given"
|
||||
assert db_column is None or isinstance(db_column, str) and db_column != "", \
|
||||
"db_column field must be string, if given"
|
||||
assert [default, alias, materialized].count(
|
||||
None
|
||||
) >= 2, "Only one of default, alias and materialized parameters can be given"
|
||||
assert (
|
||||
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
|
||||
), "Alias parameter must be a string or function object, if given"
|
||||
assert (
|
||||
materialized is None
|
||||
or isinstance(materialized, F)
|
||||
or isinstance(materialized, str)
|
||||
and materialized != ""
|
||||
), "Materialized parameter must be a string or function object, if given"
|
||||
assert (
|
||||
readonly is None or type(readonly) is bool
|
||||
), "readonly parameter must be bool if given"
|
||||
assert (
|
||||
codec is None or isinstance(codec, str) and codec != ""
|
||||
), "Codec field must be string, if given"
|
||||
assert (
|
||||
db_column is None or isinstance(db_column, str) and db_column != ""
|
||||
), "db_column field must be string, if given"
|
||||
|
||||
self.creation_counter = Field.creation_counter
|
||||
Field.creation_counter += 1
|
||||
|
@ -70,7 +79,7 @@ class Field(FunctionOperatorsMixin):
|
|||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s>' % self.__class__.__name__
|
||||
return "<%s>" % self.__class__.__name__
|
||||
|
||||
def to_python(self, value, timezone_in_use):
|
||||
"""
|
||||
|
@ -92,9 +101,10 @@ class Field(FunctionOperatorsMixin):
|
|||
Utility method to check that the given value is between min_value and max_value.
|
||||
"""
|
||||
if value < min_value or value > max_value:
|
||||
raise ValueError('%s out of range - %s is not between %s and %s' % (
|
||||
self.__class__.__name__, value, min_value, max_value
|
||||
))
|
||||
raise ValueError(
|
||||
"%s out of range - %s is not between %s and %s"
|
||||
% (self.__class__.__name__, value, min_value, max_value)
|
||||
)
|
||||
|
||||
def to_db_string(self, value, quote=True):
|
||||
"""
|
||||
|
@ -114,7 +124,7 @@ class Field(FunctionOperatorsMixin):
|
|||
sql = self.db_type
|
||||
args = self.get_db_type_args()
|
||||
if args:
|
||||
sql += '(%s)' % comma_join(args)
|
||||
sql += "(%s)" % comma_join(args)
|
||||
if with_default_expression:
|
||||
sql += self._extra_params(db)
|
||||
return sql
|
||||
|
@ -124,18 +134,18 @@ class Field(FunctionOperatorsMixin):
|
|||
return []
|
||||
|
||||
def _extra_params(self, db: Database) -> str:
|
||||
sql = ''
|
||||
sql = ""
|
||||
if self.alias:
|
||||
sql += ' ALIAS %s' % string_or_func(self.alias)
|
||||
sql += " ALIAS %s" % string_or_func(self.alias)
|
||||
elif self.materialized:
|
||||
sql += ' MATERIALIZED %s' % string_or_func(self.materialized)
|
||||
sql += " MATERIALIZED %s" % string_or_func(self.materialized)
|
||||
elif isinstance(self.default, F):
|
||||
sql += ' DEFAULT %s' % self.default.to_sql()
|
||||
sql += " DEFAULT %s" % self.default.to_sql()
|
||||
elif self.default:
|
||||
default = self.to_db_string(self.default)
|
||||
sql += ' DEFAULT %s' % default
|
||||
sql += " DEFAULT %s" % default
|
||||
if self.codec and db and db.has_codec_support and not self.alias:
|
||||
sql += ' CODEC(%s)' % self.codec
|
||||
sql += " CODEC(%s)" % self.codec
|
||||
return sql
|
||||
|
||||
def isinstance(self, types) -> bool:
|
||||
|
@ -149,28 +159,27 @@ class Field(FunctionOperatorsMixin):
|
|||
"""
|
||||
if isinstance(self, types):
|
||||
return True
|
||||
inner_field = getattr(self, 'inner_field', None)
|
||||
inner_field = getattr(self, "inner_field", None)
|
||||
while inner_field:
|
||||
if isinstance(inner_field, types):
|
||||
return True
|
||||
inner_field = getattr(inner_field, 'inner_field', None)
|
||||
inner_field = getattr(inner_field, "inner_field", None)
|
||||
return False
|
||||
|
||||
|
||||
class StringField(Field):
|
||||
class_default = ''
|
||||
db_type = 'String'
|
||||
class_default = ""
|
||||
db_type = "String"
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, bytes):
|
||||
return value.decode('UTF-8')
|
||||
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||
return value.decode("UTF-8")
|
||||
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
|
||||
|
||||
|
||||
class FixedStringField(StringField):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
length: int,
|
||||
|
@ -178,22 +187,22 @@ class FixedStringField(StringField):
|
|||
alias: Optional[Union[F, str]] = None,
|
||||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: Optional[bool] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
self._length = length
|
||||
self.db_type = 'FixedString(%d)' % length
|
||||
self.db_type = "FixedString(%d)" % length
|
||||
super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> str:
|
||||
value = super(FixedStringField, self).to_python(value, timezone_in_use)
|
||||
return value.rstrip('\0')
|
||||
return value.rstrip("\0")
|
||||
|
||||
def validate(self, value):
|
||||
if isinstance(value, str):
|
||||
value = value.encode('UTF-8')
|
||||
value = value.encode("UTF-8")
|
||||
if len(value) > self._length:
|
||||
raise ValueError(
|
||||
f'Value of {len(value)} bytes is too long for FixedStringField({self._length})'
|
||||
f"Value of {len(value)} bytes is too long for FixedStringField({self._length})"
|
||||
)
|
||||
|
||||
|
||||
|
@ -201,7 +210,7 @@ class DateField(Field):
|
|||
min_value = datetime.date(1970, 1, 1)
|
||||
max_value = datetime.date(2105, 12, 31)
|
||||
class_default = min_value
|
||||
db_type = 'Date'
|
||||
db_type = "Date"
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> datetime.date:
|
||||
if isinstance(value, datetime.datetime):
|
||||
|
@ -211,10 +220,10 @@ class DateField(Field):
|
|||
if isinstance(value, int):
|
||||
return DateField.class_default + datetime.timedelta(days=value)
|
||||
if isinstance(value, str):
|
||||
if value == '0000-00-00':
|
||||
if value == "0000-00-00":
|
||||
return DateField.min_value
|
||||
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
|
||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||
|
||||
def validate(self, value):
|
||||
self._range_check(value, DateField.min_value, DateField.max_value)
|
||||
|
@ -225,7 +234,7 @@ class DateField(Field):
|
|||
|
||||
class DateTimeField(Field):
|
||||
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
||||
db_type = 'DateTime'
|
||||
db_type = "DateTime"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -235,7 +244,7 @@ class DateTimeField(Field):
|
|||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None,
|
||||
timezone: Optional[Union[BaseTzInfo, str]] = None
|
||||
timezone: Optional[Union[BaseTzInfo, str]] = None,
|
||||
):
|
||||
super().__init__(default, alias, materialized, readonly, codec, db_column)
|
||||
# assert not timezone, 'Temporarily field timezone is not supported'
|
||||
|
@ -257,7 +266,7 @@ class DateTimeField(Field):
|
|||
if isinstance(value, int):
|
||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||
if isinstance(value, str):
|
||||
if value == '0000-00-00 00:00:00':
|
||||
if value == "0000-00-00 00:00:00":
|
||||
return self.class_default
|
||||
if len(value) == 10:
|
||||
try:
|
||||
|
@ -275,14 +284,14 @@ class DateTimeField(Field):
|
|||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = timezone_in_use.localize(dt)
|
||||
return dt
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
return escape('%010d' % timegm(value.utctimetuple()), quote)
|
||||
return escape("%010d" % timegm(value.utctimetuple()), quote)
|
||||
|
||||
|
||||
class DateTime64Field(DateTimeField):
|
||||
db_type = 'DateTime64'
|
||||
db_type = "DateTime64"
|
||||
|
||||
"""
|
||||
|
||||
|
@ -303,10 +312,10 @@ class DateTime64Field(DateTimeField):
|
|||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None,
|
||||
timezone: Optional[Union[BaseTzInfo, str]] = None,
|
||||
precision: int = 6
|
||||
precision: int = 6,
|
||||
):
|
||||
super().__init__(default, alias, materialized, readonly, codec, db_column, timezone)
|
||||
assert precision is None or isinstance(precision, int), 'Precision must be int type'
|
||||
assert precision is None or isinstance(precision, int), "Precision must be int type"
|
||||
self.precision = precision
|
||||
|
||||
def get_db_type_args(self):
|
||||
|
@ -322,11 +331,10 @@ class DateTime64Field(DateTimeField):
|
|||
Returns string in 0000000000.000000 format, where remainder digits count is equal to precision
|
||||
"""
|
||||
return escape(
|
||||
'{timestamp:0{width}.{precision}f}'.format(
|
||||
timestamp=value.timestamp(),
|
||||
width=11 + self.precision,
|
||||
precision=self.precision),
|
||||
quote
|
||||
"{timestamp:0{width}.{precision}f}".format(
|
||||
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
|
||||
),
|
||||
quote,
|
||||
)
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> datetime.datetime:
|
||||
|
@ -336,8 +344,8 @@ class DateTime64Field(DateTimeField):
|
|||
if isinstance(value, (int, float)):
|
||||
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||
if isinstance(value, str):
|
||||
left_part = value.split('.')[0]
|
||||
if left_part == '0000-00-00 00:00:00':
|
||||
left_part = value.split(".")[0]
|
||||
if left_part == "0000-00-00 00:00:00":
|
||||
return self.class_default
|
||||
if len(left_part) == 10:
|
||||
try:
|
||||
|
@ -357,7 +365,7 @@ class BaseIntField(Field):
|
|||
try:
|
||||
return int(value)
|
||||
except:
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
# There's no need to call escape since numbers do not contain
|
||||
|
@ -370,50 +378,50 @@ class BaseIntField(Field):
|
|||
|
||||
class UInt8Field(BaseIntField):
|
||||
min_value = 0
|
||||
max_value = 2 ** 8 - 1
|
||||
db_type = 'UInt8'
|
||||
max_value = 2**8 - 1
|
||||
db_type = "UInt8"
|
||||
|
||||
|
||||
class UInt16Field(BaseIntField):
|
||||
min_value = 0
|
||||
max_value = 2 ** 16 - 1
|
||||
db_type = 'UInt16'
|
||||
max_value = 2**16 - 1
|
||||
db_type = "UInt16"
|
||||
|
||||
|
||||
class UInt32Field(BaseIntField):
|
||||
min_value = 0
|
||||
max_value = 2 ** 32 - 1
|
||||
db_type = 'UInt32'
|
||||
max_value = 2**32 - 1
|
||||
db_type = "UInt32"
|
||||
|
||||
|
||||
class UInt64Field(BaseIntField):
|
||||
min_value = 0
|
||||
max_value = 2 ** 64 - 1
|
||||
db_type = 'UInt64'
|
||||
max_value = 2**64 - 1
|
||||
db_type = "UInt64"
|
||||
|
||||
|
||||
class Int8Field(BaseIntField):
|
||||
min_value = -2 ** 7
|
||||
max_value = 2 ** 7 - 1
|
||||
db_type = 'Int8'
|
||||
min_value = -(2**7)
|
||||
max_value = 2**7 - 1
|
||||
db_type = "Int8"
|
||||
|
||||
|
||||
class Int16Field(BaseIntField):
|
||||
min_value = -2 ** 15
|
||||
max_value = 2 ** 15 - 1
|
||||
db_type = 'Int16'
|
||||
min_value = -(2**15)
|
||||
max_value = 2**15 - 1
|
||||
db_type = "Int16"
|
||||
|
||||
|
||||
class Int32Field(BaseIntField):
|
||||
min_value = -2 ** 31
|
||||
max_value = 2 ** 31 - 1
|
||||
db_type = 'Int32'
|
||||
min_value = -(2**31)
|
||||
max_value = 2**31 - 1
|
||||
db_type = "Int32"
|
||||
|
||||
|
||||
class Int64Field(BaseIntField):
|
||||
min_value = -2 ** 63
|
||||
max_value = 2 ** 63 - 1
|
||||
db_type = 'Int64'
|
||||
min_value = -(2**63)
|
||||
max_value = 2**63 - 1
|
||||
db_type = "Int64"
|
||||
|
||||
|
||||
class BaseFloatField(Field):
|
||||
|
@ -425,7 +433,7 @@ class BaseFloatField(Field):
|
|||
try:
|
||||
return float(value)
|
||||
except:
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
# There's no need to call escape since numbers do not contain
|
||||
|
@ -434,11 +442,11 @@ class BaseFloatField(Field):
|
|||
|
||||
|
||||
class Float32Field(BaseFloatField):
|
||||
db_type = 'Float32'
|
||||
db_type = "Float32"
|
||||
|
||||
|
||||
class Float64Field(BaseFloatField):
|
||||
db_type = 'Float64'
|
||||
db_type = "Float64"
|
||||
|
||||
|
||||
class DecimalField(Field):
|
||||
|
@ -454,13 +462,13 @@ class DecimalField(Field):
|
|||
alias: Optional[Union[F, str]] = None,
|
||||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
|
||||
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
|
||||
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
|
||||
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
|
||||
self.precision = precision
|
||||
self.scale = scale
|
||||
self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale)
|
||||
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
|
||||
with localcontext() as ctx:
|
||||
ctx.prec = 38
|
||||
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
|
||||
|
@ -473,9 +481,9 @@ class DecimalField(Field):
|
|||
try:
|
||||
value = Decimal(value)
|
||||
except:
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
|
||||
if not value.is_finite():
|
||||
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
|
||||
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
|
||||
return self._round(value)
|
||||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
|
@ -498,14 +506,13 @@ class Decimal32Field(DecimalField):
|
|||
alias: Optional[Union[F, str]] = None,
|
||||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
super().__init__(9, scale, default, alias, materialized, readonly, db_column)
|
||||
self.db_type = 'Decimal32(%d)' % scale
|
||||
self.db_type = "Decimal32(%d)" % scale
|
||||
|
||||
|
||||
class Decimal64Field(DecimalField):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
|
@ -513,14 +520,13 @@ class Decimal64Field(DecimalField):
|
|||
alias: Optional[Union[F, str]] = None,
|
||||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
super().__init__(18, scale, default, alias, materialized, readonly, db_column)
|
||||
self.db_type = 'Decimal64(%d)' % scale
|
||||
self.db_type = "Decimal64(%d)" % scale
|
||||
|
||||
|
||||
class Decimal128Field(DecimalField):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
|
@ -528,10 +534,10 @@ class Decimal128Field(DecimalField):
|
|||
alias: Optional[Union[F, str]] = None,
|
||||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
super().__init__(38, scale, default, alias, materialized, readonly, db_column)
|
||||
self.db_type = 'Decimal128(%d)' % scale
|
||||
self.db_type = "Decimal128(%d)" % scale
|
||||
|
||||
|
||||
class BaseEnumField(Field):
|
||||
|
@ -547,7 +553,7 @@ class BaseEnumField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
self.enum_cls = enum_cls
|
||||
if default is None:
|
||||
|
@ -564,7 +570,7 @@ class BaseEnumField(Field):
|
|||
except Exception:
|
||||
return self.enum_cls(value)
|
||||
if isinstance(value, bytes):
|
||||
decoded = value.decode('UTF-8')
|
||||
decoded = value.decode("UTF-8")
|
||||
try:
|
||||
return self.enum_cls[decoded]
|
||||
except Exception:
|
||||
|
@ -573,13 +579,13 @@ class BaseEnumField(Field):
|
|||
return self.enum_cls(value)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
|
||||
raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))
|
||||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
return escape(value.name, quote)
|
||||
|
||||
def get_db_type_args(self):
|
||||
return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
|
||||
return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]
|
||||
|
||||
@classmethod
|
||||
def create_ad_hoc_field(cls, db_type) -> BaseEnumField:
|
||||
|
@ -590,17 +596,17 @@ class BaseEnumField(Field):
|
|||
members = {}
|
||||
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
|
||||
members[match.group(1)] = int(match.group(2))
|
||||
enum_cls = Enum('AdHocEnum', members)
|
||||
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
|
||||
enum_cls = Enum("AdHocEnum", members)
|
||||
field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
|
||||
return field_class(enum_cls)
|
||||
|
||||
|
||||
class Enum8Field(BaseEnumField):
|
||||
db_type = 'Enum8'
|
||||
db_type = "Enum8"
|
||||
|
||||
|
||||
class Enum16Field(BaseEnumField):
|
||||
db_type = 'Enum16'
|
||||
db_type = "Enum16"
|
||||
|
||||
|
||||
class ArrayField(Field):
|
||||
|
@ -614,12 +620,14 @@ class ArrayField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
assert isinstance(inner_field, Field), \
|
||||
"The first argument of ArrayField must be a Field instance"
|
||||
assert not isinstance(inner_field, ArrayField), \
|
||||
"Multidimensional array fields are not supported by the ORM"
|
||||
assert isinstance(
|
||||
inner_field, Field
|
||||
), "The first argument of ArrayField must be a Field instance"
|
||||
assert not isinstance(
|
||||
inner_field, ArrayField
|
||||
), "Multidimensional array fields are not supported by the ORM"
|
||||
self.inner_field = inner_field
|
||||
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column)
|
||||
|
||||
|
@ -627,9 +635,9 @@ class ArrayField(Field):
|
|||
if isinstance(value, str):
|
||||
value = parse_array(value)
|
||||
elif isinstance(value, bytes):
|
||||
value = parse_array(value.decode('UTF-8'))
|
||||
value = parse_array(value.decode("UTF-8"))
|
||||
elif not isinstance(value, (list, tuple)):
|
||||
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
|
||||
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
|
||||
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
||||
|
||||
def validate(self, value):
|
||||
|
@ -638,12 +646,12 @@ class ArrayField(Field):
|
|||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
|
||||
return '[' + comma_join(array) + ']'
|
||||
return "[" + comma_join(array) + "]"
|
||||
|
||||
def get_sql(self, with_default_expression=True, db=None) -> str:
|
||||
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||
if with_default_expression and self.codec and db and db.has_codec_support:
|
||||
sql += ' CODEC(%s)' % self.codec
|
||||
sql += " CODEC(%s)" % self.codec
|
||||
return sql
|
||||
|
||||
|
||||
|
@ -658,17 +666,19 @@ class TupleField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: bool = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
self.names = {}
|
||||
self.inner_fields = []
|
||||
for (name, field) in name_fields:
|
||||
if name in self.names:
|
||||
raise ValueError('The Field name conflict')
|
||||
assert isinstance(field, Field), \
|
||||
"The first argument of TupleField must be a Field instance"
|
||||
assert not isinstance(field, (ArrayField, TupleField)), \
|
||||
"Multidimensional array fields are not supported by the ORM"
|
||||
raise ValueError("The Field name conflict")
|
||||
assert isinstance(
|
||||
field, Field
|
||||
), "The first argument of TupleField must be a Field instance"
|
||||
assert not isinstance(
|
||||
field, (ArrayField, TupleField)
|
||||
), "Multidimensional array fields are not supported by the ORM"
|
||||
self.names[name] = field
|
||||
self.inner_fields.append(field)
|
||||
self.class_default = tuple(field.class_default for field in self.inner_fields)
|
||||
|
@ -677,16 +687,19 @@ class TupleField(Field):
|
|||
def to_python(self, value, timezone_in_use) -> tuple:
|
||||
if isinstance(value, str):
|
||||
value = parse_array(value)
|
||||
value = (self.inner_fields[i].to_python(v, timezone_in_use)
|
||||
for i, v in enumerate(value))
|
||||
value = (
|
||||
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
|
||||
)
|
||||
elif isinstance(value, bytes):
|
||||
value = parse_array(value.decode('UTF-8'))
|
||||
value = (self.inner_fields[i].to_python(v, timezone_in_use)
|
||||
for i, v in enumerate(value))
|
||||
value = parse_array(value.decode("UTF-8"))
|
||||
value = (
|
||||
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
|
||||
)
|
||||
elif not isinstance(value, (list, tuple)):
|
||||
raise ValueError('TupleField expects list or tuple, not %s' % type(value))
|
||||
return tuple(self.inner_fields[i].to_python(v, timezone_in_use)
|
||||
for i, v in enumerate(value))
|
||||
raise ValueError("TupleField expects list or tuple, not %s" % type(value))
|
||||
return tuple(
|
||||
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
for i, v in enumerate(value):
|
||||
|
@ -694,21 +707,22 @@ class TupleField(Field):
|
|||
|
||||
def to_db_string(self, value, quote=True) -> str:
|
||||
array = [self.inner_fields[i].to_db_string(v, quote=True) for i, v in enumerate(value)]
|
||||
return '(' + comma_join(array) + ')'
|
||||
return "(" + comma_join(array) + ")"
|
||||
|
||||
def get_sql(self, with_default_expression=True, db=None) -> str:
|
||||
inner_sql = ', '.join('%s %s' % (name, field.get_sql(False))
|
||||
for name, field in self.names.items())
|
||||
inner_sql = ", ".join(
|
||||
"%s %s" % (name, field.get_sql(False)) for name, field in self.names.items()
|
||||
)
|
||||
|
||||
sql = 'Tuple(%s)' % inner_sql
|
||||
sql = "Tuple(%s)" % inner_sql
|
||||
if with_default_expression and self.codec and db and db.has_codec_support:
|
||||
sql += ' CODEC(%s)' % self.codec
|
||||
sql += " CODEC(%s)" % self.codec
|
||||
return sql
|
||||
|
||||
|
||||
class UUIDField(Field):
|
||||
class_default = UUID(int=0)
|
||||
db_type = 'UUID'
|
||||
db_type = "UUID"
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> UUID:
|
||||
if isinstance(value, UUID):
|
||||
|
@ -722,7 +736,7 @@ class UUIDField(Field):
|
|||
elif isinstance(value, tuple):
|
||||
return UUID(fields=value)
|
||||
else:
|
||||
raise ValueError('Invalid value for UUIDField: %r' % value)
|
||||
raise ValueError("Invalid value for UUIDField: %r" % value)
|
||||
|
||||
def to_db_string(self, value, quote=True):
|
||||
return escape(str(value), quote)
|
||||
|
@ -730,7 +744,7 @@ class UUIDField(Field):
|
|||
|
||||
class IPv4Field(Field):
|
||||
class_default = 0
|
||||
db_type = 'IPv4'
|
||||
db_type = "IPv4"
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> IPv4Address:
|
||||
if isinstance(value, IPv4Address):
|
||||
|
@ -738,7 +752,7 @@ class IPv4Field(Field):
|
|||
elif isinstance(value, (bytes, str, int)):
|
||||
return IPv4Address(value)
|
||||
else:
|
||||
raise ValueError('Invalid value for IPv4Address: %r' % value)
|
||||
raise ValueError("Invalid value for IPv4Address: %r" % value)
|
||||
|
||||
def to_db_string(self, value, quote=True):
|
||||
return escape(str(value), quote)
|
||||
|
@ -746,7 +760,7 @@ class IPv4Field(Field):
|
|||
|
||||
class IPv6Field(Field):
|
||||
class_default = 0
|
||||
db_type = 'IPv6'
|
||||
db_type = "IPv6"
|
||||
|
||||
def to_python(self, value, timezone_in_use) -> IPv6Address:
|
||||
if isinstance(value, IPv6Address):
|
||||
|
@ -754,7 +768,7 @@ class IPv6Field(Field):
|
|||
elif isinstance(value, (bytes, str, int)):
|
||||
return IPv6Address(value)
|
||||
else:
|
||||
raise ValueError('Invalid value for IPv6Address: %r' % value)
|
||||
raise ValueError("Invalid value for IPv6Address: %r" % value)
|
||||
|
||||
def to_db_string(self, value, quote=True):
|
||||
return escape(str(value), quote)
|
||||
|
@ -771,11 +785,13 @@ class NullableField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
extra_null_values: Optional[Iterable] = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
assert isinstance(inner_field, Field), \
|
||||
"The first argument of NullableField must be a Field instance." \
|
||||
" Not: {}".format(inner_field)
|
||||
assert isinstance(
|
||||
inner_field, Field
|
||||
), "The first argument of NullableField must be a Field instance." " Not: {}".format(
|
||||
inner_field
|
||||
)
|
||||
self.inner_field = inner_field
|
||||
self._null_values = [None]
|
||||
if extra_null_values:
|
||||
|
@ -785,7 +801,7 @@ class NullableField(Field):
|
|||
)
|
||||
|
||||
def to_python(self, value, timezone_in_use):
|
||||
if value == '\\N' or value in self._null_values:
|
||||
if value == "\\N" or value in self._null_values:
|
||||
return None
|
||||
return self.inner_field.to_python(value, timezone_in_use)
|
||||
|
||||
|
@ -794,18 +810,17 @@ class NullableField(Field):
|
|||
|
||||
def to_db_string(self, value, quote=True):
|
||||
if value in self._null_values:
|
||||
return '\\N'
|
||||
return "\\N"
|
||||
return self.inner_field.to_db_string(value, quote=quote)
|
||||
|
||||
def get_sql(self, with_default_expression=True, db=None):
|
||||
sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
|
||||
if with_default_expression:
|
||||
sql += self._extra_params(db)
|
||||
return sql
|
||||
|
||||
|
||||
class LowCardinalityField(Field):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_field: Field,
|
||||
|
@ -814,16 +829,20 @@ class LowCardinalityField(Field):
|
|||
materialized: Optional[Union[F, str]] = None,
|
||||
readonly: Optional[bool] = None,
|
||||
codec: Optional[str] = None,
|
||||
db_column: Optional[str] = None
|
||||
db_column: Optional[str] = None,
|
||||
):
|
||||
assert isinstance(inner_field, Field), \
|
||||
"The first argument of LowCardinalityField must be a Field instance." \
|
||||
" Not: {}".format(inner_field)
|
||||
assert not isinstance(inner_field, LowCardinalityField), \
|
||||
"LowCardinality inner fields are not supported by the ORM"
|
||||
assert not isinstance(inner_field, ArrayField), \
|
||||
"Array field inside LowCardinality are not supported by the ORM." \
|
||||
assert isinstance(
|
||||
inner_field, Field
|
||||
), "The first argument of LowCardinalityField must be a Field instance." " Not: {}".format(
|
||||
inner_field
|
||||
)
|
||||
assert not isinstance(
|
||||
inner_field, LowCardinalityField
|
||||
), "LowCardinality inner fields are not supported by the ORM"
|
||||
assert not isinstance(inner_field, ArrayField), (
|
||||
"Array field inside LowCardinality are not supported by the ORM."
|
||||
" Use Array(LowCardinality) instead"
|
||||
)
|
||||
self.inner_field = inner_field
|
||||
self.class_default = self.inner_field.class_default
|
||||
super().__init__(default, alias, materialized, readonly, codec, db_column)
|
||||
|
@ -839,12 +858,12 @@ class LowCardinalityField(Field):
|
|||
|
||||
def get_sql(self, with_default_expression=True, db=None):
|
||||
if db and db.has_low_cardinality_support:
|
||||
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
|
||||
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
|
||||
else:
|
||||
sql = self.inner_field.get_sql(with_default_expression=False)
|
||||
logger.warning(
|
||||
'LowCardinalityField not supported on clickhouse-server version < 19.0'
|
||||
' using {} as fallback'.format(self.inner_field.__class__.__name__)
|
||||
"LowCardinalityField not supported on clickhouse-server version < 19.0"
|
||||
" using {} as fallback".format(self.inner_field.__class__.__name__)
|
||||
)
|
||||
if with_default_expression:
|
||||
sql += self._extra_params(db)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,7 +5,7 @@ from .utils import get_subclass_names
|
|||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('migrations')
|
||||
logger = logging.getLogger("migrations")
|
||||
|
||||
|
||||
class Operation:
|
||||
|
@ -14,7 +14,7 @@ class Operation:
|
|||
"""
|
||||
|
||||
def apply(self, database):
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class ModelOperation(Operation):
|
||||
|
@ -30,9 +30,9 @@ class ModelOperation(Operation):
|
|||
self.table_name = model_class.table_name()
|
||||
|
||||
def _alter_table(self, database, cmd):
|
||||
'''
|
||||
"""
|
||||
Utility for running ALTER TABLE commands.
|
||||
'''
|
||||
"""
|
||||
cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
|
||||
logger.debug(cmd)
|
||||
database.raw(cmd)
|
||||
|
@ -44,7 +44,7 @@ class CreateTable(ModelOperation):
|
|||
"""
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Create table %s', self.table_name)
|
||||
logger.info(" Create table %s", self.table_name)
|
||||
if issubclass(self.model_class, BufferModel):
|
||||
database.create_table(self.model_class.engine.main_model)
|
||||
database.create_table(self.model_class)
|
||||
|
@ -65,7 +65,7 @@ class AlterTable(ModelOperation):
|
|||
return [(row.name, row.type) for row in database.select(query)]
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Alter table %s', self.table_name)
|
||||
logger.info(" Alter table %s", self.table_name)
|
||||
|
||||
# Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
|
||||
# ADD COLUMN ... AFTER doesn't affect it
|
||||
|
@ -74,8 +74,8 @@ class AlterTable(ModelOperation):
|
|||
# Identify fields that were deleted from the model
|
||||
deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
|
||||
for name in deleted_fields:
|
||||
logger.info(' Drop column %s', name)
|
||||
self._alter_table(database, 'DROP COLUMN %s' % name)
|
||||
logger.info(" Drop column %s", name)
|
||||
self._alter_table(database, "DROP COLUMN %s" % name)
|
||||
del table_fields[name]
|
||||
|
||||
# Identify fields that were added to the model
|
||||
|
@ -83,13 +83,13 @@ class AlterTable(ModelOperation):
|
|||
for name, field in self.model_class.fields().items():
|
||||
is_regular_field = not (field.materialized or field.alias)
|
||||
if name not in table_fields:
|
||||
logger.info(' Add column %s', name)
|
||||
cmd = 'ADD COLUMN %s %s' % (name, field.get_sql(db=database))
|
||||
logger.info(" Add column %s", name)
|
||||
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
|
||||
if is_regular_field:
|
||||
if prev_name:
|
||||
cmd += ' AFTER %s' % prev_name
|
||||
cmd += " AFTER %s" % prev_name
|
||||
else:
|
||||
cmd += ' FIRST'
|
||||
cmd += " FIRST"
|
||||
self._alter_table(database, cmd)
|
||||
|
||||
if is_regular_field:
|
||||
|
@ -101,16 +101,24 @@ class AlterTable(ModelOperation):
|
|||
# The order of class attributes can be changed any time, so we can't count on it
|
||||
# Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save
|
||||
# attribute position. Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47
|
||||
model_fields = {name: field.get_sql(with_default_expression=False, db=database)
|
||||
for name, field in self.model_class.fields().items()}
|
||||
model_fields = {
|
||||
name: field.get_sql(with_default_expression=False, db=database)
|
||||
for name, field in self.model_class.fields().items()
|
||||
}
|
||||
for field_name, field_sql in self._get_table_fields(database):
|
||||
# All fields must have been created and dropped by this moment
|
||||
assert field_name in model_fields, 'Model fields and table columns in disagreement'
|
||||
assert field_name in model_fields, "Model fields and table columns in disagreement"
|
||||
|
||||
if field_sql != model_fields[field_name]:
|
||||
logger.info(' Change type of column %s from %s to %s', field_name, field_sql,
|
||||
model_fields[field_name])
|
||||
self._alter_table(database, 'MODIFY COLUMN %s %s' % (field_name, model_fields[field_name]))
|
||||
logger.info(
|
||||
" Change type of column %s from %s to %s",
|
||||
field_name,
|
||||
field_sql,
|
||||
model_fields[field_name],
|
||||
)
|
||||
self._alter_table(
|
||||
database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name])
|
||||
)
|
||||
|
||||
|
||||
class AlterTableWithBuffer(ModelOperation):
|
||||
|
@ -135,7 +143,7 @@ class DropTable(ModelOperation):
|
|||
"""
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Drop table %s', self.table_name)
|
||||
logger.info(" Drop table %s", self.table_name)
|
||||
database.drop_table(self.model_class)
|
||||
|
||||
|
||||
|
@ -148,28 +156,29 @@ class AlterConstraints(ModelOperation):
|
|||
"""
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Alter constraints for %s', self.table_name)
|
||||
logger.info(" Alter constraints for %s", self.table_name)
|
||||
existing = self._get_constraint_names(database)
|
||||
# Go over constraints in the model
|
||||
for constraint in self.model_class._constraints.values():
|
||||
# Check if it's a new constraint
|
||||
if constraint.name not in existing:
|
||||
logger.info(' Add constraint %s', constraint.name)
|
||||
self._alter_table(database, 'ADD %s' % constraint.create_table_sql())
|
||||
logger.info(" Add constraint %s", constraint.name)
|
||||
self._alter_table(database, "ADD %s" % constraint.create_table_sql())
|
||||
else:
|
||||
existing.remove(constraint.name)
|
||||
# Remaining constraints in `existing` are obsolete
|
||||
for name in existing:
|
||||
logger.info(' Drop constraint %s', name)
|
||||
self._alter_table(database, 'DROP CONSTRAINT `%s`' % name)
|
||||
logger.info(" Drop constraint %s", name)
|
||||
self._alter_table(database, "DROP CONSTRAINT `%s`" % name)
|
||||
|
||||
def _get_constraint_names(self, database):
|
||||
"""
|
||||
Returns a set containing the names of existing constraints in the table.
|
||||
"""
|
||||
import re
|
||||
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
|
||||
matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def)
|
||||
|
||||
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
|
||||
matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
|
||||
return set(matches)
|
||||
|
||||
|
||||
|
@ -191,33 +200,34 @@ class AlterIndexes(ModelOperation):
|
|||
self.reindex = reindex
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Alter indexes for %s', self.table_name)
|
||||
logger.info(" Alter indexes for %s", self.table_name)
|
||||
existing = self._get_index_names(database)
|
||||
logger.info(existing)
|
||||
# Go over indexes in the model
|
||||
for index in self.model_class._indexes.values():
|
||||
# Check if it's a new index
|
||||
if index.name not in existing:
|
||||
logger.info(' Add index %s', index.name)
|
||||
self._alter_table(database, 'ADD %s' % index.create_table_sql())
|
||||
logger.info(" Add index %s", index.name)
|
||||
self._alter_table(database, "ADD %s" % index.create_table_sql())
|
||||
else:
|
||||
existing.remove(index.name)
|
||||
# Remaining indexes in `existing` are obsolete
|
||||
for name in existing:
|
||||
logger.info(' Drop index %s', name)
|
||||
self._alter_table(database, 'DROP INDEX `%s`' % name)
|
||||
logger.info(" Drop index %s", name)
|
||||
self._alter_table(database, "DROP INDEX `%s`" % name)
|
||||
# Reindex
|
||||
if self.reindex:
|
||||
logger.info(' Build indexes on table')
|
||||
database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name)
|
||||
logger.info(" Build indexes on table")
|
||||
database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)
|
||||
|
||||
def _get_index_names(self, database):
|
||||
"""
|
||||
Returns a set containing the names of existing indexes in the table.
|
||||
"""
|
||||
import re
|
||||
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
|
||||
matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def)
|
||||
|
||||
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
|
||||
matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
|
||||
return set(matches)
|
||||
|
||||
|
||||
|
@ -225,16 +235,17 @@ class RunPython(Operation):
|
|||
"""
|
||||
A migration operation that executes a Python function.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
'''
|
||||
"""
|
||||
Initializer. The given Python function will be called with a single
|
||||
argument - the Database instance to apply the migration to.
|
||||
'''
|
||||
"""
|
||||
assert callable(func), "'func' argument must be function"
|
||||
self._func = func
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Executing python operation %s', self._func.__name__)
|
||||
logger.info(" Executing python operation %s", self._func.__name__)
|
||||
self._func(database)
|
||||
|
||||
|
||||
|
@ -244,17 +255,17 @@ class RunSQL(Operation):
|
|||
"""
|
||||
|
||||
def __init__(self, sql):
|
||||
'''
|
||||
"""
|
||||
Initializer. The given sql argument must be a valid SQL statement or
|
||||
list of statements.
|
||||
'''
|
||||
"""
|
||||
if isinstance(sql, str):
|
||||
sql = [sql]
|
||||
assert isinstance(sql, list), "'sql' argument must be string or list of strings"
|
||||
self._sql = sql
|
||||
|
||||
def apply(self, database):
|
||||
logger.info(' Executing raw SQL operations')
|
||||
logger.info(" Executing raw SQL operations")
|
||||
for item in self._sql:
|
||||
database.raw(item)
|
||||
|
||||
|
@ -268,11 +279,11 @@ class MigrationHistory(Model):
|
|||
module_name = StringField()
|
||||
applied = DateField()
|
||||
|
||||
engine = MergeTree('applied', ('package_name', 'module_name'))
|
||||
engine = MergeTree("applied", ("package_name", "module_name"))
|
||||
|
||||
@classmethod
|
||||
def table_name(cls):
|
||||
return 'infi_clickhouse_orm_migrations'
|
||||
return "infi_clickhouse_orm_migrations"
|
||||
|
||||
|
||||
# Expose only relevant classes in import *
|
||||
|
|
|
@ -17,7 +17,7 @@ from .engines import Merge, Distributed, Memory
|
|||
if TYPE_CHECKING:
|
||||
from clickhouse_orm.database import Database
|
||||
|
||||
logger = getLogger('clickhouse_orm')
|
||||
logger = getLogger("clickhouse_orm")
|
||||
|
||||
|
||||
class Constraint:
|
||||
|
@ -38,7 +38,7 @@ class Constraint:
|
|||
"""
|
||||
Returns the SQL statement for defining this constraint during table creation.
|
||||
"""
|
||||
return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
|
||||
return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))
|
||||
|
||||
|
||||
class Index:
|
||||
|
@ -66,8 +66,11 @@ class Index:
|
|||
"""
|
||||
Returns the SQL statement for defining this index during table creation.
|
||||
"""
|
||||
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (
|
||||
self.name, arg_to_sql(self.expr), self.type, self.granularity
|
||||
return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (
|
||||
self.name,
|
||||
arg_to_sql(self.expr),
|
||||
self.type,
|
||||
self.granularity,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -76,7 +79,7 @@ class Index:
|
|||
An index that stores extremes of the specified expression (if the expression is tuple, then it stores
|
||||
extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key.
|
||||
"""
|
||||
return 'minmax'
|
||||
return "minmax"
|
||||
|
||||
@staticmethod
|
||||
def set(max_rows: int) -> str:
|
||||
|
@ -85,11 +88,12 @@ class Index:
|
|||
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
|
||||
on a block of data.
|
||||
"""
|
||||
return 'set(%d)' % max_rows
|
||||
return "set(%d)" % max_rows
|
||||
|
||||
@staticmethod
|
||||
def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int,
|
||||
number_of_hash_functions: int, random_seed: int) -> str:
|
||||
def ngrambf_v1(
|
||||
n: int, size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
|
||||
) -> str:
|
||||
"""
|
||||
An index that stores a Bloom filter containing all ngrams from a block of data.
|
||||
Works only with strings. Can be used for optimization of equals, like and in expressions.
|
||||
|
@ -100,13 +104,17 @@ class Index:
|
|||
- `number_of_hash_functions` — The number of hash functions used in the Bloom filter.
|
||||
- `random_seed` — The seed for Bloom filter hash functions.
|
||||
"""
|
||||
return 'ngrambf_v1(%d, %d, %d, %d)' % (
|
||||
n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
|
||||
return "ngrambf_v1(%d, %d, %d, %d)" % (
|
||||
n,
|
||||
size_of_bloom_filter_in_bytes,
|
||||
number_of_hash_functions,
|
||||
random_seed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int,
|
||||
random_seed: int) -> str:
|
||||
def tokenbf_v1(
|
||||
size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
|
||||
) -> str:
|
||||
"""
|
||||
An index that stores a Bloom filter containing string tokens. Tokens are sequences
|
||||
separated by non-alphanumeric characters.
|
||||
|
@ -116,8 +124,10 @@ class Index:
|
|||
- `number_of_hash_functions` — The number of hash functions used in the Bloom filter.
|
||||
- `random_seed` — The seed for Bloom filter hash functions.
|
||||
"""
|
||||
return 'tokenbf_v1(%d, %d, %d)' % (
|
||||
size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
|
||||
return "tokenbf_v1(%d, %d, %d)" % (
|
||||
size_of_bloom_filter_in_bytes,
|
||||
number_of_hash_functions,
|
||||
random_seed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -128,7 +138,7 @@ class Index:
|
|||
- `false_positive` - the probability (between 0 and 1) of receiving a false positive
|
||||
response from the filter
|
||||
"""
|
||||
return 'bloom_filter(%f)' % false_positive
|
||||
return "bloom_filter(%f)" % false_positive
|
||||
|
||||
|
||||
class ModelBase(type):
|
||||
|
@ -183,23 +193,23 @@ class ModelBase(type):
|
|||
_indexes=indexes,
|
||||
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
|
||||
_defaults=defaults,
|
||||
_has_funcs_as_defaults=has_funcs_as_defaults
|
||||
_has_funcs_as_defaults=has_funcs_as_defaults,
|
||||
)
|
||||
model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs)
|
||||
|
||||
# Let each field, constraint and index know its parent and its own name
|
||||
for n, obj in chain(fields, constraints.items(), indexes.items()):
|
||||
setattr(obj, 'parent', model)
|
||||
setattr(obj, 'name', n)
|
||||
setattr(obj, "parent", model)
|
||||
setattr(obj, "name", n)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):
|
||||
def create_ad_hoc_model(cls, fields, model_name="AdHocModel"):
|
||||
# fields is a list of tuples (name, db_type)
|
||||
# Check if model exists in cache
|
||||
fields = list(fields)
|
||||
cache_key = model_name + ' ' + str(fields)
|
||||
cache_key = model_name + " " + str(fields)
|
||||
if cache_key in cls.ad_hoc_model_cache:
|
||||
return cls.ad_hoc_model_cache[cache_key]
|
||||
# Create an ad hoc model class
|
||||
|
@ -217,28 +227,25 @@ class ModelBase(type):
|
|||
import clickhouse_orm.contrib.geo.fields as geo_fields
|
||||
|
||||
# Enums
|
||||
if db_type.startswith('Enum'):
|
||||
if db_type.startswith("Enum"):
|
||||
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
|
||||
# DateTime with timezone
|
||||
if db_type.startswith('DateTime('):
|
||||
if db_type.startswith("DateTime("):
|
||||
timezone = db_type[9:-1]
|
||||
return orm_fields.DateTimeField(
|
||||
timezone=timezone[1:-1] if timezone else None
|
||||
)
|
||||
return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
|
||||
# DateTime64
|
||||
if db_type.startswith('DateTime64('):
|
||||
precision, *timezone = [s.strip() for s in db_type[11:-1].split(',')]
|
||||
if db_type.startswith("DateTime64("):
|
||||
precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
|
||||
return orm_fields.DateTime64Field(
|
||||
precision=int(precision),
|
||||
timezone=timezone[0][1:-1] if timezone else None
|
||||
precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
|
||||
)
|
||||
# Arrays
|
||||
if db_type.startswith('Array'):
|
||||
if db_type.startswith("Array"):
|
||||
inner_field = cls.create_ad_hoc_field(db_type[6:-1])
|
||||
return orm_fields.ArrayField(inner_field)
|
||||
# Tuples
|
||||
if db_type.startswith('Tuple'):
|
||||
types = [s.strip().split(' ') for s in db_type[6:-1].split(',')]
|
||||
if db_type.startswith("Tuple"):
|
||||
types = [s.strip().split(" ") for s in db_type[6:-1].split(",")]
|
||||
name_fields = []
|
||||
for i, tp in enumerate(types):
|
||||
if len(tp) == 2:
|
||||
|
@ -247,27 +254,27 @@ class ModelBase(type):
|
|||
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
|
||||
return orm_fields.TupleField(name_fields=name_fields)
|
||||
# FixedString
|
||||
if db_type.startswith('FixedString'):
|
||||
if db_type.startswith("FixedString"):
|
||||
length = int(db_type[12:-1])
|
||||
return orm_fields.FixedStringField(length)
|
||||
# Decimal / Decimal32 / Decimal64 / Decimal128
|
||||
if db_type.startswith('Decimal'):
|
||||
p = db_type.index('(')
|
||||
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(',')]
|
||||
field_class = getattr(orm_fields, db_type[:p] + 'Field')
|
||||
if db_type.startswith("Decimal"):
|
||||
p = db_type.index("(")
|
||||
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
|
||||
field_class = getattr(orm_fields, db_type[:p] + "Field")
|
||||
return field_class(*args)
|
||||
# Nullable
|
||||
if db_type.startswith('Nullable'):
|
||||
inner_field = cls.create_ad_hoc_field(db_type[9 : -1])
|
||||
if db_type.startswith("Nullable"):
|
||||
inner_field = cls.create_ad_hoc_field(db_type[9:-1])
|
||||
return orm_fields.NullableField(inner_field)
|
||||
# LowCardinality
|
||||
if db_type.startswith('LowCardinality'):
|
||||
inner_field = cls.create_ad_hoc_field(db_type[15 : -1])
|
||||
if db_type.startswith("LowCardinality"):
|
||||
inner_field = cls.create_ad_hoc_field(db_type[15:-1])
|
||||
return orm_fields.LowCardinalityField(inner_field)
|
||||
# Simple fields
|
||||
name = db_type + 'Field'
|
||||
name = db_type + "Field"
|
||||
if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)):
|
||||
raise NotImplementedError('No field class for %s' % db_type)
|
||||
raise NotImplementedError("No field class for %s" % db_type)
|
||||
field_class = getattr(orm_fields, name, None) or getattr(geo_fields, name, None)
|
||||
return field_class()
|
||||
|
||||
|
@ -282,6 +289,7 @@ class Model(metaclass=ModelBase):
|
|||
cpu_percent = Float32Field()
|
||||
engine = Memory()
|
||||
"""
|
||||
|
||||
_has_funcs_as_defaults: bool
|
||||
_constraints: dict[str, Constraint]
|
||||
_indexes: dict[str, Index]
|
||||
|
@ -318,7 +326,7 @@ class Model(metaclass=ModelBase):
|
|||
setattr(self, name, value)
|
||||
else:
|
||||
raise AttributeError(
|
||||
'%s does not have a field called %s' % (self.__class__.__name__, name)
|
||||
"%s does not have a field called %s" % (self.__class__.__name__, name)
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
|
@ -383,29 +391,29 @@ class Model(metaclass=ModelBase):
|
|||
"""
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
"""
|
||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
|
||||
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
|
||||
# Fields
|
||||
items = []
|
||||
for name, field in cls.fields().items():
|
||||
items.append(' %s %s' % (name, field.get_sql(db=db)))
|
||||
items.append(" %s %s" % (name, field.get_sql(db=db)))
|
||||
# Constraints
|
||||
for c in cls._constraints.values():
|
||||
items.append(' %s' % c.create_table_sql())
|
||||
items.append(" %s" % c.create_table_sql())
|
||||
# Indexes
|
||||
for i in cls._indexes.values():
|
||||
items.append(' %s' % i.create_table_sql())
|
||||
parts.append(',\n'.join(items))
|
||||
items.append(" %s" % i.create_table_sql())
|
||||
parts.append(",\n".join(items))
|
||||
# Engine
|
||||
parts.append(')')
|
||||
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
|
||||
return '\n'.join(parts)
|
||||
parts.append(")")
|
||||
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
|
||||
return "\n".join(parts)
|
||||
|
||||
@classmethod
|
||||
def drop_table_sql(cls, db: Database) -> str:
|
||||
"""
|
||||
Returns the SQL command for deleting this model's table.
|
||||
"""
|
||||
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
|
||||
return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
|
||||
|
||||
@classmethod
|
||||
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
|
||||
|
@ -422,7 +430,7 @@ class Model(metaclass=ModelBase):
|
|||
kwargs = {}
|
||||
for name in field_names:
|
||||
field = getattr(cls, name)
|
||||
field_timezone = getattr(field, 'timezone', None) or timezone_in_use
|
||||
field_timezone = getattr(field, "timezone", None) or timezone_in_use
|
||||
kwargs[name] = field.to_python(next(values), field_timezone)
|
||||
|
||||
obj = cls(**kwargs)
|
||||
|
@ -439,7 +447,9 @@ class Model(metaclass=ModelBase):
|
|||
"""
|
||||
data = self.__dict__
|
||||
fields = self.fields(writable=not include_readonly)
|
||||
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
|
||||
return "\t".join(
|
||||
field.to_db_string(data[name], quote=False) for name, field in fields.items()
|
||||
)
|
||||
|
||||
def to_tskv(self, include_readonly=True):
|
||||
"""
|
||||
|
@ -453,16 +463,16 @@ class Model(metaclass=ModelBase):
|
|||
parts = []
|
||||
for name, field in fields.items():
|
||||
if data[name] != NO_VALUE:
|
||||
parts.append(name + '=' + field.to_db_string(data[name], quote=False))
|
||||
return '\t'.join(parts)
|
||||
parts.append(name + "=" + field.to_db_string(data[name], quote=False))
|
||||
return "\t".join(parts)
|
||||
|
||||
def to_db_string(self) -> bytes:
|
||||
"""
|
||||
Returns the instance as a bytestring ready to be inserted into the database.
|
||||
"""
|
||||
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
|
||||
s += '\n'
|
||||
return s.encode('utf-8')
|
||||
s += "\n"
|
||||
return s.encode("utf-8")
|
||||
|
||||
def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]:
|
||||
"""
|
||||
|
@ -519,19 +529,18 @@ class Model(metaclass=ModelBase):
|
|||
|
||||
|
||||
class BufferModel(Model):
|
||||
|
||||
@classmethod
|
||||
def create_table_sql(cls, db: Database) -> str:
|
||||
"""
|
||||
Returns the SQL statement for creating a table for this model.
|
||||
"""
|
||||
parts = [
|
||||
'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (
|
||||
db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
|
||||
"CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
|
||||
% (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
|
||||
]
|
||||
engine_str = cls.engine.create_table_sql(db)
|
||||
parts.append(engine_str)
|
||||
return ' '.join(parts)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
class MergeModel(Model):
|
||||
|
@ -540,6 +549,7 @@ class MergeModel(Model):
|
|||
Predefines virtual _table column an controls that rows can't be inserted to this table type
|
||||
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
|
||||
"""
|
||||
|
||||
readonly = True
|
||||
|
||||
# Virtual fields can't be inserted into database
|
||||
|
@ -551,15 +561,16 @@ class MergeModel(Model):
|
|||
Returns the SQL statement for creating a table for this model.
|
||||
"""
|
||||
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
|
||||
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
|
||||
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
|
||||
cols = []
|
||||
for name, field in cls.fields().items():
|
||||
if name != '_table':
|
||||
cols.append(' %s %s' % (name, field.get_sql(db=db)))
|
||||
parts.append(',\n'.join(cols))
|
||||
parts.append(')')
|
||||
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
|
||||
return '\n'.join(parts)
|
||||
if name != "_table":
|
||||
cols.append(" %s %s" % (name, field.get_sql(db=db)))
|
||||
parts.append(",\n".join(cols))
|
||||
parts.append(")")
|
||||
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# TODO: base class for models that require specific engine
|
||||
|
||||
|
@ -574,8 +585,9 @@ class DistributedModel(Model):
|
|||
Sets the `Database` that this model instance belongs to.
|
||||
This is done automatically when the instance is read from the database or written to it.
|
||||
"""
|
||||
assert isinstance(self.engine, Distributed),\
|
||||
"engine must be an instance of engines.Distributed"
|
||||
assert isinstance(
|
||||
self.engine, Distributed
|
||||
), "engine must be an instance of engines.Distributed"
|
||||
super().set_database(db)
|
||||
|
||||
@classmethod
|
||||
|
@ -616,15 +628,20 @@ class DistributedModel(Model):
|
|||
return
|
||||
|
||||
# find out all the superclasses of the Model that store any data
|
||||
storage_models = [b for b in cls.__bases__ if issubclass(b, Model)
|
||||
and not issubclass(b, DistributedModel)]
|
||||
storage_models = [
|
||||
b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)
|
||||
]
|
||||
if not storage_models:
|
||||
raise TypeError("When defining Distributed engine without the table_name "
|
||||
"ensure that your model has a parent model")
|
||||
raise TypeError(
|
||||
"When defining Distributed engine without the table_name "
|
||||
"ensure that your model has a parent model"
|
||||
)
|
||||
|
||||
if len(storage_models) > 1:
|
||||
raise TypeError("When defining Distributed engine without the table_name "
|
||||
"ensure that your model has exactly one non-distributed superclass")
|
||||
raise TypeError(
|
||||
"When defining Distributed engine without the table_name "
|
||||
"ensure that your model has exactly one non-distributed superclass"
|
||||
)
|
||||
|
||||
# enable correct SQL for engine
|
||||
cls.engine.table = storage_models[0]
|
||||
|
@ -637,10 +654,12 @@ class DistributedModel(Model):
|
|||
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
|
||||
|
||||
parts = [
|
||||
'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format(
|
||||
db.db_name, cls.table_name(), cls.engine.table_name),
|
||||
'ENGINE = ' + cls.engine.create_table_sql(db)]
|
||||
return '\n'.join(parts)
|
||||
"CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
|
||||
db.db_name, cls.table_name(), cls.engine.table_name
|
||||
),
|
||||
"ENGINE = " + cls.engine.create_table_sql(db),
|
||||
]
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class TemporaryModel(Model):
|
||||
|
@ -657,30 +676,31 @@ class TemporaryModel(Model):
|
|||
|
||||
https://clickhouse.com/docs/en/sql-reference/statements/create/table/#temporary-tables
|
||||
"""
|
||||
|
||||
_temporary = True
|
||||
|
||||
@classmethod
|
||||
def create_table_sql(cls, db: Database) -> str:
|
||||
assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance"
|
||||
|
||||
parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()]
|
||||
parts = ["CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (" % cls.table_name()]
|
||||
# Fields
|
||||
items = []
|
||||
for name, field in cls.fields().items():
|
||||
items.append(' %s %s' % (name, field.get_sql(db=db)))
|
||||
items.append(" %s %s" % (name, field.get_sql(db=db)))
|
||||
# Constraints
|
||||
for c in cls._constraints.values():
|
||||
items.append(' %s' % c.create_table_sql())
|
||||
items.append(" %s" % c.create_table_sql())
|
||||
# Indexes
|
||||
for i in cls._indexes.values():
|
||||
items.append(' %s' % i.create_table_sql())
|
||||
parts.append(',\n'.join(items))
|
||||
items.append(" %s" % i.create_table_sql())
|
||||
parts.append(",\n".join(items))
|
||||
# Engine
|
||||
parts.append(')')
|
||||
parts.append('ENGINE = Memory')
|
||||
return '\n'.join(parts)
|
||||
parts.append(")")
|
||||
parts.append("ENGINE = Memory")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# Expose only relevant classes in import *
|
||||
MODEL = TypeVar('MODEL', bound=Model)
|
||||
MODEL = TypeVar("MODEL", bound=Model)
|
||||
__all__ = get_subclass_names(locals(), (Model, Constraint, Index))
|
||||
|
|
|
@ -11,7 +11,7 @@ from typing import (
|
|||
Generic,
|
||||
TypeVar,
|
||||
AsyncIterator,
|
||||
Iterator
|
||||
Iterator,
|
||||
)
|
||||
|
||||
import pytz
|
||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|||
from clickhouse_orm.models import Model
|
||||
from clickhouse_orm.database import Database, Page
|
||||
|
||||
MODEL = TypeVar('MODEL', bound='Model')
|
||||
MODEL = TypeVar("MODEL", bound="Model")
|
||||
|
||||
|
||||
class Operator:
|
||||
|
@ -59,9 +59,9 @@ class SimpleOperator(Operator):
|
|||
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
|
||||
field = getattr(model_cls, field_name)
|
||||
value = self._value_to_sql(field, value)
|
||||
if value == '\\N' and self._sql_for_null is not None:
|
||||
return ' '.join([field_name, self._sql_for_null])
|
||||
return ' '.join([field.name, self._sql_operator, value])
|
||||
if value == "\\N" and self._sql_for_null is not None:
|
||||
return " ".join([field_name, self._sql_for_null])
|
||||
return " ".join([field.name, self._sql_operator, value])
|
||||
|
||||
|
||||
class InOperator(Operator):
|
||||
|
@ -81,7 +81,7 @@ class InOperator(Operator):
|
|||
pass
|
||||
else:
|
||||
value = comma_join([self._value_to_sql(field, v) for v in value])
|
||||
return '%s IN (%s)' % (field.name, value)
|
||||
return "%s IN (%s)" % (field.name, value)
|
||||
|
||||
|
||||
class GlobalInOperator(Operator):
|
||||
|
@ -95,7 +95,7 @@ class GlobalInOperator(Operator):
|
|||
pass
|
||||
else:
|
||||
value = comma_join([self._value_to_sql(field, v) for v in value])
|
||||
return '%s GLOBAL IN (%s)' % (field.name, value)
|
||||
return "%s GLOBAL IN (%s)" % (field.name, value)
|
||||
|
||||
|
||||
class LikeOperator(Operator):
|
||||
|
@ -111,11 +111,11 @@ class LikeOperator(Operator):
|
|||
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
|
||||
field = getattr(model_cls, field_name)
|
||||
value = self._value_to_sql(field, value, quote=False)
|
||||
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
|
||||
value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
|
||||
pattern = self._pattern.format(value)
|
||||
if self._case_sensitive:
|
||||
return '%s LIKE \'%s\'' % (field.name, pattern)
|
||||
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
|
||||
return "%s LIKE '%s'" % (field.name, pattern)
|
||||
return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field.name, pattern)
|
||||
|
||||
|
||||
class IExactOperator(Operator):
|
||||
|
@ -126,7 +126,7 @@ class IExactOperator(Operator):
|
|||
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
|
||||
field = getattr(model_cls, field_name)
|
||||
value = self._value_to_sql(field, value)
|
||||
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
|
||||
return "lowerUTF8(%s) = lowerUTF8(%s)" % (field.name, value)
|
||||
|
||||
|
||||
class NotOperator(Operator):
|
||||
|
@ -139,7 +139,7 @@ class NotOperator(Operator):
|
|||
|
||||
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
|
||||
# Negate the base operator
|
||||
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
|
||||
return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
|
||||
|
||||
|
||||
class BetweenOperator(Operator):
|
||||
|
@ -154,16 +154,22 @@ class BetweenOperator(Operator):
|
|||
|
||||
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
|
||||
field = getattr(model_cls, field_name)
|
||||
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(
|
||||
str(value[0])) > 0 else None
|
||||
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
|
||||
str(value[1])) > 0 else None
|
||||
value0 = (
|
||||
self._value_to_sql(field, value[0])
|
||||
if value[0] is not None or len(str(value[0])) > 0
|
||||
else None
|
||||
)
|
||||
value1 = (
|
||||
self._value_to_sql(field, value[1])
|
||||
if value[1] is not None or len(str(value[1])) > 0
|
||||
else None
|
||||
)
|
||||
if value0 and value1:
|
||||
return '%s BETWEEN %s AND %s' % (field.name, value0, value1)
|
||||
return "%s BETWEEN %s AND %s" % (field.name, value0, value1)
|
||||
if value0 and not value1:
|
||||
return ' '.join([field.name, '>=', value0])
|
||||
return " ".join([field.name, ">=", value0])
|
||||
if value1 and not value0:
|
||||
return ' '.join([field.name, '<=', value1])
|
||||
return " ".join([field.name, "<=", value1])
|
||||
|
||||
|
||||
# Define the set of builtin operators
|
||||
|
@ -175,24 +181,24 @@ def register_operator(name: str, sql: Operator):
|
|||
_operators[name] = sql
|
||||
|
||||
|
||||
register_operator('eq', SimpleOperator('=', 'IS NULL'))
|
||||
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
|
||||
register_operator('gt', SimpleOperator('>'))
|
||||
register_operator('gte', SimpleOperator('>='))
|
||||
register_operator('lt', SimpleOperator('<'))
|
||||
register_operator('lte', SimpleOperator('<='))
|
||||
register_operator('between', BetweenOperator())
|
||||
register_operator('in', InOperator())
|
||||
register_operator('gin', GlobalInOperator())
|
||||
register_operator('not_in', NotOperator(InOperator()))
|
||||
register_operator('not_gin', NotOperator(GlobalInOperator()))
|
||||
register_operator('contains', LikeOperator('%{}%'))
|
||||
register_operator('startswith', LikeOperator('{}%'))
|
||||
register_operator('endswith', LikeOperator('%{}'))
|
||||
register_operator('icontains', LikeOperator('%{}%', False))
|
||||
register_operator('istartswith', LikeOperator('{}%', False))
|
||||
register_operator('iendswith', LikeOperator('%{}', False))
|
||||
register_operator('iexact', IExactOperator())
|
||||
register_operator("eq", SimpleOperator("=", "IS NULL"))
|
||||
register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
|
||||
register_operator("gt", SimpleOperator(">"))
|
||||
register_operator("gte", SimpleOperator(">="))
|
||||
register_operator("lt", SimpleOperator("<"))
|
||||
register_operator("lte", SimpleOperator("<="))
|
||||
register_operator("between", BetweenOperator())
|
||||
register_operator("in", InOperator())
|
||||
register_operator("gin", GlobalInOperator())
|
||||
register_operator("not_in", NotOperator(InOperator()))
|
||||
register_operator("not_gin", NotOperator(GlobalInOperator()))
|
||||
register_operator("contains", LikeOperator("%{}%"))
|
||||
register_operator("startswith", LikeOperator("{}%"))
|
||||
register_operator("endswith", LikeOperator("%{}"))
|
||||
register_operator("icontains", LikeOperator("%{}%", False))
|
||||
register_operator("istartswith", LikeOperator("{}%", False))
|
||||
register_operator("iendswith", LikeOperator("%{}", False))
|
||||
register_operator("iexact", IExactOperator())
|
||||
|
||||
|
||||
class Cond:
|
||||
|
@ -214,8 +220,8 @@ class FieldCond(Cond):
|
|||
self._operator = _operators.get(operator)
|
||||
if self._operator is None:
|
||||
# The field name contains __ like my__field
|
||||
self._field_name = field_name + '__' + operator
|
||||
self._operator = _operators['eq']
|
||||
self._field_name = field_name + "__" + operator
|
||||
self._operator = _operators["eq"]
|
||||
self._value = value
|
||||
|
||||
def to_sql(self, model_cls: type[Model]) -> str:
|
||||
|
@ -228,12 +234,13 @@ class FieldCond(Cond):
|
|||
|
||||
|
||||
class Q:
|
||||
AND_MODE = 'AND'
|
||||
OR_MODE = 'OR'
|
||||
AND_MODE = "AND"
|
||||
OR_MODE = "OR"
|
||||
|
||||
def __init__(self, *filter_funcs, **filter_fields):
|
||||
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in
|
||||
filter_fields.items()]
|
||||
self._conds = list(filter_funcs) + [
|
||||
self._build_cond(k, v) for k, v in filter_fields.items()
|
||||
]
|
||||
self._children = []
|
||||
self._negate = False
|
||||
self._mode = self.AND_MODE
|
||||
|
@ -263,10 +270,10 @@ class Q:
|
|||
return q
|
||||
|
||||
def _build_cond(self, key, value):
|
||||
if '__' in key:
|
||||
field_name, operator = key.rsplit('__', 1)
|
||||
if "__" in key:
|
||||
field_name, operator = key.rsplit("__", 1)
|
||||
else:
|
||||
field_name, operator = key, 'eq'
|
||||
field_name, operator = key, "eq"
|
||||
return FieldCond(field_name, operator, value)
|
||||
|
||||
def to_sql(self, model_cls: type[Model]) -> str:
|
||||
|
@ -280,16 +287,16 @@ class Q:
|
|||
|
||||
if not condition_sql:
|
||||
# Empty Q() object returns everything
|
||||
sql = '1'
|
||||
sql = "1"
|
||||
elif len(condition_sql) == 1:
|
||||
# Skip not needed brackets over single condition
|
||||
sql = condition_sql[0]
|
||||
else:
|
||||
# Each condition must be enclosed in brackets, or order of operations may be wrong
|
||||
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql)
|
||||
sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
|
||||
|
||||
if self._negate:
|
||||
sql = 'NOT (%s)' % sql
|
||||
sql = "NOT (%s)" % sql
|
||||
|
||||
return sql
|
||||
|
||||
|
@ -400,16 +407,16 @@ class QuerySet(Generic[MODEL]):
|
|||
def __getitem__(self, s):
|
||||
if isinstance(s, int):
|
||||
# Single index
|
||||
assert s >= 0, 'negative indexes are not supported'
|
||||
assert s >= 0, "negative indexes are not supported"
|
||||
queryset = self._clone()
|
||||
queryset._limits = (s, 1)
|
||||
return next(iter(queryset))
|
||||
# Slice
|
||||
assert s.step in (None, 1), 'step is not supported in slices'
|
||||
assert s.step in (None, 1), "step is not supported in slices"
|
||||
start = s.start or 0
|
||||
stop = s.stop or 2 ** 63 - 1
|
||||
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
|
||||
assert start <= stop, 'start of slice cannot be smaller than its end'
|
||||
stop = s.stop or 2**63 - 1
|
||||
assert start >= 0 and stop >= 0, "negative indexes are not supported"
|
||||
assert start <= stop, "start of slice cannot be smaller than its end"
|
||||
queryset = self._clone()
|
||||
queryset._limits = (start, stop - start)
|
||||
return queryset
|
||||
|
@ -425,7 +432,7 @@ class QuerySet(Generic[MODEL]):
|
|||
offset_limit = (0, offset_limit)
|
||||
offset = offset_limit[0]
|
||||
limit = offset_limit[1]
|
||||
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
|
||||
assert offset >= 0 and limit >= 0, "negative limits are not supported"
|
||||
queryset = self._clone()
|
||||
queryset._limit_by = (offset, limit)
|
||||
queryset._limit_by_fields = fields_or_expr
|
||||
|
@ -435,44 +442,44 @@ class QuerySet(Generic[MODEL]):
|
|||
"""
|
||||
Returns the selected fields or expressions as a SQL string.
|
||||
"""
|
||||
fields = '*'
|
||||
fields = "*"
|
||||
if self._fields:
|
||||
fields = comma_join('`%s`' % field for field in self._fields)
|
||||
fields = comma_join("`%s`" % field for field in self._fields)
|
||||
return fields
|
||||
|
||||
def as_sql(self) -> str:
|
||||
"""
|
||||
Returns the whole query as a SQL string.
|
||||
"""
|
||||
distinct = 'DISTINCT ' if self._distinct else ''
|
||||
final = ' FINAL' if self._final else ''
|
||||
table_name = '`%s`' % self._model_cls.table_name()
|
||||
distinct = "DISTINCT " if self._distinct else ""
|
||||
final = " FINAL" if self._final else ""
|
||||
table_name = "`%s`" % self._model_cls.table_name()
|
||||
if self._model_cls.is_system_model():
|
||||
table_name = '`system`.' + table_name
|
||||
table_name = "`system`." + table_name
|
||||
params = (distinct, self.select_fields_as_sql(), table_name, final)
|
||||
sql = 'SELECT %s%s\nFROM %s%s' % params
|
||||
sql = "SELECT %s%s\nFROM %s%s" % params
|
||||
|
||||
if self._prewhere_q and not self._prewhere_q.is_empty:
|
||||
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
|
||||
sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
|
||||
|
||||
if self._where_q and not self._where_q.is_empty:
|
||||
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
|
||||
sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
|
||||
|
||||
if self._grouping_fields:
|
||||
sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields)
|
||||
sql += "\nGROUP BY %s" % comma_join("%s" % field for field in self._grouping_fields)
|
||||
|
||||
if self._grouping_with_totals:
|
||||
sql += ' WITH TOTALS'
|
||||
sql += " WITH TOTALS"
|
||||
|
||||
if self._order_by:
|
||||
sql += '\nORDER BY ' + self.order_by_as_sql()
|
||||
sql += "\nORDER BY " + self.order_by_as_sql()
|
||||
|
||||
if self._limit_by:
|
||||
sql += '\nLIMIT %d, %d' % self._limit_by
|
||||
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
|
||||
sql += "\nLIMIT %d, %d" % self._limit_by
|
||||
sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
|
||||
|
||||
if self._limits:
|
||||
sql += '\nLIMIT %d, %d' % self._limits
|
||||
sql += "\nLIMIT %d, %d" % self._limits
|
||||
|
||||
return sql
|
||||
|
||||
|
@ -480,10 +487,12 @@ class QuerySet(Generic[MODEL]):
|
|||
"""
|
||||
Returns the contents of the query's `ORDER BY` clause as a string.
|
||||
"""
|
||||
return comma_join([
|
||||
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field)
|
||||
for field in self._order_by
|
||||
])
|
||||
return comma_join(
|
||||
[
|
||||
"%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
|
||||
for field in self._order_by
|
||||
]
|
||||
)
|
||||
|
||||
def conditions_as_sql(self, prewhere=False) -> str:
|
||||
"""
|
||||
|
@ -498,7 +507,7 @@ class QuerySet(Generic[MODEL]):
|
|||
"""
|
||||
if self._distinct or self._limits:
|
||||
# Use a subquery, since a simple count won't be accurate
|
||||
sql = 'SELECT count() FROM (%s)' % self.as_sql()
|
||||
sql = "SELECT count() FROM (%s)" % self.as_sql()
|
||||
raw = self._database.raw(sql)
|
||||
return int(raw) if raw else 0
|
||||
|
||||
|
@ -527,8 +536,8 @@ class QuerySet(Generic[MODEL]):
|
|||
def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
|
||||
from clickhouse_orm.funcs import F
|
||||
|
||||
inverse = kwargs.pop('_inverse', False)
|
||||
prewhere = kwargs.pop('prewhere', False)
|
||||
inverse = kwargs.pop("_inverse", False)
|
||||
prewhere = kwargs.pop("prewhere", False)
|
||||
|
||||
queryset = self._clone()
|
||||
|
||||
|
@ -588,14 +597,14 @@ class QuerySet(Generic[MODEL]):
|
|||
if page_num == -1:
|
||||
page_num = pages_total
|
||||
elif page_num < 1:
|
||||
raise ValueError('Invalid page number: %d' % page_num)
|
||||
raise ValueError("Invalid page number: %d" % page_num)
|
||||
offset = (page_num - 1) * page_size
|
||||
return Page(
|
||||
objects=list(self[offset: offset + page_size]),
|
||||
objects=list(self[offset : offset + page_size]),
|
||||
number_of_objects=count,
|
||||
pages_total=pages_total,
|
||||
number=page_num,
|
||||
page_size=page_size
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def distinct(self) -> "QuerySet[MODEL]":
|
||||
|
@ -616,8 +625,8 @@ class QuerySet(Generic[MODEL]):
|
|||
|
||||
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
|
||||
raise TypeError(
|
||||
'final() method can be used only with the CollapsingMergeTree'
|
||||
' and ReplacingMergeTree engines'
|
||||
"final() method can be used only with the CollapsingMergeTree"
|
||||
" and ReplacingMergeTree engines"
|
||||
)
|
||||
|
||||
queryset = self._clone()
|
||||
|
@ -631,7 +640,7 @@ class QuerySet(Generic[MODEL]):
|
|||
"""
|
||||
self._verify_mutation_allowed()
|
||||
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
|
||||
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions)
|
||||
sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
|
||||
self._database.raw(sql)
|
||||
return self
|
||||
|
||||
|
@ -641,12 +650,14 @@ class QuerySet(Generic[MODEL]):
|
|||
Keyword arguments specify the field names and expressions to use for the update.
|
||||
Note that ClickHouse performs updates in the background, so they are not immediate.
|
||||
"""
|
||||
assert kwargs, 'No fields specified for update'
|
||||
assert kwargs, "No fields specified for update"
|
||||
self._verify_mutation_allowed()
|
||||
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
|
||||
fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
|
||||
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
|
||||
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
|
||||
self._model_cls.table_name(), fields, conditions
|
||||
sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (
|
||||
self._model_cls.table_name(),
|
||||
fields,
|
||||
conditions,
|
||||
)
|
||||
self._database.raw(sql)
|
||||
return self
|
||||
|
@ -655,10 +666,10 @@ class QuerySet(Generic[MODEL]):
|
|||
"""
|
||||
Checks that the queryset's state allows mutations. Raises an AssertionError if not.
|
||||
"""
|
||||
assert not self._limits, 'Mutations are not allowed after slicing the queryset'
|
||||
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
|
||||
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
|
||||
assert not self._final, 'Mutations are not allowed after calling final()'
|
||||
assert not self._limits, "Mutations are not allowed after slicing the queryset"
|
||||
assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
|
||||
assert not self._distinct, "Mutations are not allowed after calling distinct()"
|
||||
assert not self._final, "Mutations are not allowed after calling final()"
|
||||
|
||||
def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
|
||||
"""
|
||||
|
@ -687,7 +698,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
self,
|
||||
base_queryset: QuerySet,
|
||||
grouping_fields: tuple[Any],
|
||||
calculated_fields: dict[str, str]
|
||||
calculated_fields: dict[str, str],
|
||||
):
|
||||
"""
|
||||
Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`.
|
||||
|
@ -705,7 +716,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
At least one calculated field is required.
|
||||
"""
|
||||
super().__init__(base_queryset._model_cls, base_queryset._database)
|
||||
assert calculated_fields, 'No calculated fields specified for aggregation'
|
||||
assert calculated_fields, "No calculated fields specified for aggregation"
|
||||
self._fields = grouping_fields
|
||||
self._grouping_fields = grouping_fields
|
||||
self._calculated_fields = calculated_fields
|
||||
|
@ -734,8 +745,9 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
created with.
|
||||
"""
|
||||
for name in args:
|
||||
assert name in self._fields or name in self._calculated_fields, \
|
||||
'Cannot group by `%s` since it is not included in the query' % name
|
||||
assert name in self._fields or name in self._calculated_fields, (
|
||||
"Cannot group by `%s` since it is not included in the query" % name
|
||||
)
|
||||
queryset = copy(self)
|
||||
queryset._grouping_fields = args
|
||||
return queryset
|
||||
|
@ -750,14 +762,16 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
"""
|
||||
This method is not supported on `AggregateQuerySet`.
|
||||
"""
|
||||
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
|
||||
raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
|
||||
|
||||
def select_fields_as_sql(self) -> str:
|
||||
"""
|
||||
Returns the selected fields or expressions as a SQL string.
|
||||
"""
|
||||
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
|
||||
self._calculated_fields.items()])
|
||||
return comma_join(
|
||||
[str(f) for f in self._fields]
|
||||
+ ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
|
||||
)
|
||||
|
||||
def __iter__(self) -> Iterator[Model]:
|
||||
"""
|
||||
|
@ -778,7 +792,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
"""
|
||||
Returns the number of rows after aggregation.
|
||||
"""
|
||||
sql = 'SELECT count() FROM (%s)' % self.as_sql()
|
||||
sql = "SELECT count() FROM (%s)" % self.as_sql()
|
||||
raw = self._database.raw(sql)
|
||||
if isinstance(raw, CoroutineType):
|
||||
return raw
|
||||
|
@ -795,7 +809,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
|
|||
return queryset
|
||||
|
||||
def _verify_mutation_allowed(self):
|
||||
raise AssertionError('Cannot mutate an AggregateQuerySet')
|
||||
raise AssertionError("Cannot mutate an AggregateQuerySet")
|
||||
|
||||
|
||||
# Expose only relevant classes in import *
|
||||
|
|
|
@ -2,8 +2,8 @@ import uuid
|
|||
from typing import Optional
|
||||
from contextvars import ContextVar, Token
|
||||
|
||||
ctx_session_id: ContextVar[str] = ContextVar('ck.session_id')
|
||||
ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout')
|
||||
ctx_session_id: ContextVar[str] = ContextVar("ck.session_id")
|
||||
ctx_session_timeout: ContextVar[float] = ContextVar("ck.session_timeout")
|
||||
|
||||
|
||||
class SessionContext:
|
||||
|
|
|
@ -16,12 +16,15 @@ class SystemPart(Model):
|
|||
This model operates only fields, described in the reference. Other fields are ignored.
|
||||
https://clickhouse.tech/docs/en/system_tables/system.parts/
|
||||
"""
|
||||
OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'})
|
||||
|
||||
OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})
|
||||
|
||||
_readonly = True
|
||||
_system = True
|
||||
|
||||
database = StringField() # Name of the database where the table that this part belongs to is located.
|
||||
database = (
|
||||
StringField()
|
||||
) # Name of the database where the table that this part belongs to is located.
|
||||
table = StringField() # Name of the table that this part belongs to.
|
||||
engine = StringField() # Name of the table engine, without parameters.
|
||||
partition = StringField() # Name of the partition, in the format YYYYMM.
|
||||
|
@ -43,7 +46,9 @@ class SystemPart(Model):
|
|||
|
||||
# Time the directory with the part was modified. Usually corresponds to the part's creation time.
|
||||
modification_time = DateTimeField()
|
||||
remove_time = DateTimeField() # For inactive parts only - the time when the part became inactive.
|
||||
remove_time = (
|
||||
DateTimeField()
|
||||
) # For inactive parts only - the time when the part became inactive.
|
||||
|
||||
# The number of places where the part is used. A value greater than 2 indicates
|
||||
# that this part participates in queries or merges.
|
||||
|
@ -51,12 +56,13 @@ class SystemPart(Model):
|
|||
|
||||
@classmethod
|
||||
def table_name(cls):
|
||||
return 'parts'
|
||||
return "parts"
|
||||
|
||||
"""
|
||||
Next methods return SQL for some operations, which can be done with partitions
|
||||
https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
|
||||
"""
|
||||
|
||||
def _partition_operation_sql(self, operation, settings=None, from_part=None):
|
||||
"""
|
||||
Performs some operation over partition
|
||||
|
@ -68,9 +74,16 @@ class SystemPart(Model):
|
|||
Returns: Operation execution result
|
||||
"""
|
||||
operation = operation.upper()
|
||||
assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(self.OPERATIONS)
|
||||
assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(
|
||||
self.OPERATIONS
|
||||
)
|
||||
|
||||
sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (self._database.db_name, self.table, operation, self.partition)
|
||||
sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (
|
||||
self._database.db_name,
|
||||
self.table,
|
||||
operation,
|
||||
self.partition,
|
||||
)
|
||||
if from_part is not None:
|
||||
sql += " FROM %s" % from_part
|
||||
self._database.raw(sql, settings=settings, stream=False)
|
||||
|
@ -83,7 +96,7 @@ class SystemPart(Model):
|
|||
|
||||
Returns: SQL Query
|
||||
"""
|
||||
return self._partition_operation_sql('DETACH', settings=settings)
|
||||
return self._partition_operation_sql("DETACH", settings=settings)
|
||||
|
||||
def drop(self, settings=None):
|
||||
"""
|
||||
|
@ -93,7 +106,7 @@ class SystemPart(Model):
|
|||
|
||||
Returns: SQL Query
|
||||
"""
|
||||
return self._partition_operation_sql('DROP', settings=settings)
|
||||
return self._partition_operation_sql("DROP", settings=settings)
|
||||
|
||||
def attach(self, settings=None):
|
||||
"""
|
||||
|
@ -103,7 +116,7 @@ class SystemPart(Model):
|
|||
|
||||
Returns: SQL Query
|
||||
"""
|
||||
return self._partition_operation_sql('ATTACH', settings=settings)
|
||||
return self._partition_operation_sql("ATTACH", settings=settings)
|
||||
|
||||
def freeze(self, settings=None):
|
||||
"""
|
||||
|
@ -113,7 +126,7 @@ class SystemPart(Model):
|
|||
|
||||
Returns: SQL Query
|
||||
"""
|
||||
return self._partition_operation_sql('FREEZE', settings=settings)
|
||||
return self._partition_operation_sql("FREEZE", settings=settings)
|
||||
|
||||
def fetch(self, zookeeper_path, settings=None):
|
||||
"""
|
||||
|
@ -124,7 +137,7 @@ class SystemPart(Model):
|
|||
|
||||
Returns: SQL Query
|
||||
"""
|
||||
return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path)
|
||||
return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)
|
||||
|
||||
@classmethod
|
||||
def get(cls, database, conditions=""):
|
||||
|
@ -140,9 +153,12 @@ class SystemPart(Model):
|
|||
assert isinstance(conditions, str), "conditions must be a string"
|
||||
if conditions:
|
||||
conditions += " AND"
|
||||
field_names = ','.join(cls.fields())
|
||||
return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" %
|
||||
(field_names, cls.table_name(), conditions, database.db_name), model_class=cls)
|
||||
field_names = ",".join(cls.fields())
|
||||
return database.select(
|
||||
"SELECT %s FROM `system`.%s WHERE %s database='%s'"
|
||||
% (field_names, cls.table_name(), conditions, database.db_name),
|
||||
model_class=cls,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_active(cls, database, conditions=""):
|
||||
|
@ -155,8 +171,8 @@ class SystemPart(Model):
|
|||
Returns: A list of SystemPart objects
|
||||
"""
|
||||
if conditions:
|
||||
conditions += ' AND '
|
||||
conditions += 'active'
|
||||
conditions += " AND "
|
||||
conditions += "active"
|
||||
return SystemPart.get(database, conditions=conditions)
|
||||
|
||||
|
||||
|
|
|
@ -10,10 +10,10 @@ SPECIAL_CHARS = {
|
|||
"\t": "\\t",
|
||||
"\0": "\\0",
|
||||
"\\": "\\\\",
|
||||
"'": "\\'"
|
||||
"'": "\\'",
|
||||
}
|
||||
|
||||
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
|
||||
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
|
||||
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
|
||||
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
|
||||
|
||||
|
@ -36,11 +36,11 @@ def escape(value, quote=True):
|
|||
|
||||
|
||||
def unescape(value):
|
||||
return codecs.escape_decode(value)[0].decode('utf-8')
|
||||
return codecs.escape_decode(value)[0].decode("utf-8")
|
||||
|
||||
|
||||
def string_or_func(obj):
|
||||
return obj.to_sql() if hasattr(obj, 'to_sql') else obj
|
||||
return obj.to_sql() if hasattr(obj, "to_sql") else obj
|
||||
|
||||
|
||||
def arg_to_sql(arg):
|
||||
|
@ -50,6 +50,7 @@ def arg_to_sql(arg):
|
|||
None, numbers, timezones, arrays/iterables.
|
||||
"""
|
||||
from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
|
||||
|
||||
if isinstance(arg, F):
|
||||
return arg.to_sql()
|
||||
if isinstance(arg, Field):
|
||||
|
@ -67,22 +68,22 @@ def arg_to_sql(arg):
|
|||
if isinstance(arg, tzinfo):
|
||||
return StringField().to_db_string(arg.tzname(None))
|
||||
if arg is None:
|
||||
return 'NULL'
|
||||
return "NULL"
|
||||
if isinstance(arg, QuerySet):
|
||||
return "(%s)" % arg
|
||||
if isinstance(arg, tuple):
|
||||
return '(' + comma_join(arg_to_sql(x) for x in arg) + ')'
|
||||
return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
|
||||
if is_iterable(arg):
|
||||
return '[' + comma_join(arg_to_sql(x) for x in arg) + ']'
|
||||
return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
|
||||
return str(arg)
|
||||
|
||||
|
||||
def parse_tsv(line):
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode()
|
||||
if line and line[-1] == '\n':
|
||||
if line and line[-1] == "\n":
|
||||
line = line[:-1]
|
||||
return [unescape(value) for value in line.split(str('\t'))]
|
||||
return [unescape(value) for value in line.split(str("\t"))]
|
||||
|
||||
|
||||
def parse_array(array_string):
|
||||
|
@ -92,17 +93,17 @@ def parse_array(array_string):
|
|||
"(1,2,3)" ==> [1, 2, 3]
|
||||
"""
|
||||
# Sanity check
|
||||
if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])':
|
||||
if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
|
||||
raise ValueError('Invalid array string: "%s"' % array_string)
|
||||
# Drop opening brace
|
||||
array_string = array_string[1:]
|
||||
# Go over the string, lopping off each value at the beginning until nothing is left
|
||||
values = []
|
||||
while True:
|
||||
if array_string in '])':
|
||||
if array_string in "])":
|
||||
# End of array
|
||||
return values
|
||||
elif array_string[0] in ', ':
|
||||
elif array_string[0] in ", ":
|
||||
# In between values
|
||||
array_string = array_string[1:]
|
||||
elif array_string[0] == "'":
|
||||
|
@ -110,13 +111,13 @@ def parse_array(array_string):
|
|||
match = re.search(r"[^\\]'", array_string)
|
||||
if match is None:
|
||||
raise ValueError('Missing closing quote: "%s"' % array_string)
|
||||
values.append(array_string[1: match.start() + 1])
|
||||
array_string = array_string[match.end():]
|
||||
values.append(array_string[1 : match.start() + 1])
|
||||
array_string = array_string[match.end() :]
|
||||
else:
|
||||
# Start of non-quoted value, find its end
|
||||
match = re.search(r",|\]|\)", array_string)
|
||||
values.append(array_string[0: match.start()])
|
||||
array_string = array_string[match.end() - 1:]
|
||||
values.append(array_string[0 : match.start()])
|
||||
array_string = array_string[match.end() - 1 :]
|
||||
|
||||
|
||||
def import_submodules(package_name):
|
||||
|
@ -124,9 +125,10 @@ def import_submodules(package_name):
|
|||
Import all submodules of a module.
|
||||
"""
|
||||
import importlib, pkgutil
|
||||
|
||||
package = importlib.import_module(package_name)
|
||||
return {
|
||||
name: importlib.import_module(package_name + '.' + name)
|
||||
name: importlib.import_module(package_name + "." + name)
|
||||
for _, name, _ in pkgutil.iter_modules(package.__path__)
|
||||
}
|
||||
|
||||
|
@ -136,9 +138,9 @@ def comma_join(items, stringify=False):
|
|||
Joins an iterable of strings with commas.
|
||||
"""
|
||||
if stringify:
|
||||
return ', '.join(str(item) for item in items)
|
||||
return ", ".join(str(item) for item in items)
|
||||
else:
|
||||
return ', '.join(items)
|
||||
return ", ".join(items)
|
||||
|
||||
|
||||
def is_iterable(obj):
|
||||
|
@ -154,6 +156,7 @@ def is_iterable(obj):
|
|||
|
||||
def get_subclass_names(locals, base_class):
|
||||
from inspect import isclass
|
||||
|
||||
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
|
||||
|
||||
|
||||
|
@ -164,7 +167,7 @@ class NoValue:
|
|||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return 'NO_VALUE'
|
||||
return "NO_VALUE"
|
||||
|
||||
|
||||
NO_VALUE = NoValue()
|
||||
|
|
Loading…
Reference in New Issue
Block a user