Migrate code style to Black

sw 2022-06-04 21:25:34 +08:00
parent df2d778919
commit d22683f28c
13 changed files with 1188 additions and 1548 deletions

View File

@ -1,526 +0,0 @@
Class Reference
===============
infi.clickhouse_orm.database
----------------------------
### Database
#### Database(db_name, db_url="http://localhost:8123/", username=None, password=None, readonly=False)
Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist.
- `db_name`: name of the database to connect to.
- `db_url`: URL of the ClickHouse server.
- `username`: optional connection credentials.
- `password`: optional connection credentials.
- `readonly`: use a read-only connection.
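For example, a minimal connection sketch (the database name is illustrative):

```python
from infi.clickhouse_orm.database import Database

# Connects to a local ClickHouse server; unless readonly=True is passed,
# the database is created on the server if it does not already exist.
db = Database("my_db", db_url="http://localhost:8123/")
```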
#### count(model_class, conditions=None)
Counts the number of records in the model's table.
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
#### create_database()
Creates the database on the ClickHouse server if it does not already exist.
#### create_table(model_class)
Creates a table for the given model class, if it does not exist already.
#### drop_database()
Deletes the database on the ClickHouse server.
#### drop_table(model_class)
Drops the database table of the given model class, if it exists.
#### insert(model_instances, batch_size=1000)
Inserts records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
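A minimal usage sketch, assuming the `Person` model defined under `Model` below:

```python
# Records are written in chunks of batch_size rows per request.
db.insert([
    Person(first_name="John", last_name="Smith", birthday="1980-01-01"),
    Person(first_name="Jane", last_name="Doe", birthday="1982-03-14"),
], batch_size=1000)
```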
#### migrate(migrations_package_name, up_to=9999)
Executes schema migrations.
- `migrations_package_name`: fully qualified name of the Python package
containing the migrations.
- `up_to`: number of the last migration to apply.
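For example, with a hypothetical `my_app.migrations` package:

```python
# Applies any unapplied migrations, up to and including number 0005.
db.migrate("my_app.migrations", up_to=5)
```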
#### paginate(model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None)
Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
- `page_num`: the page number (1-based), or -1 to get the last page.
- `page_size`: number of records to return per page.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
- `settings`: query settings to send as HTTP GET parameters.
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
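A short sketch, assuming the `Person` model defined under `Model` below:

```python
page = db.paginate(Person, order_by="last_name, first_name", page_num=1, page_size=100)
print(page.number_of_objects, page.pages_total)
for person in page.objects:
    print(person.first_name, person.last_name)
```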
#### raw(query, settings=None, stream=False)
Performs a query and returns its output as text.
- `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters.
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
#### select(query, model_class=None, settings=None)
Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters.
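A sketch of both forms, assuming the `Person` model defined under `Model` below (`$table` is expanded to the model's table name):

```python
# With a model class, rows come back as Person instances.
for person in db.select("SELECT * FROM $table WHERE last_name = 'Smith'", model_class=Person):
    print(person.first_name)

# Without one, rows come back as instances of an ad-hoc model.
for row in db.select("SELECT database, name FROM system.tables"):
    print(row.database, row.name)
```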
### DatabaseException
Extends Exception
Raised when a database operation fails.
infi.clickhouse_orm.models
--------------------------
### Model
A base class for ORM models.
#### Model(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
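A minimal model definition, used by the other examples in this reference (field and engine choices are illustrative):

```python
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import StringField, DateField
from infi.clickhouse_orm.engines import MergeTree

class Person(Model):
    first_name = StringField()
    last_name = StringField()
    birthday = DateField()

    engine = MergeTree("birthday", ("first_name", "last_name", "birthday"))

# Values are converted immediately: a bad date here raises ValueError,
# and an unknown keyword raises AttributeError.
person = Person(first_name="John", last_name="Smith", birthday="1980-01-01")
```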
#### Model.create_table_sql(db)
Returns the SQL command for creating a table for this model.
#### Model.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### Model.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
Creates a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
- `database`: if given, sets the database that this instance belongs to.
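For example, assuming the `Person` model above:

```python
# field_names must match the order of the values in the line.
line = "John\tSmith\t1980-01-01"
person = Person.from_tsv(line, field_names=["first_name", "last_name", "birthday"])
```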
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### Model.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### Model.table_name()
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
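For example, to store `Person` rows in a table named `people` instead of the default `person`:

```python
class Person(Model):
    # ... fields and engine as above ...

    @classmethod
    def table_name(cls):
        return "people"
```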
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
- `field_names`: an iterable of field names to return (optional).
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
### BufferModel
Extends Model
#### BufferModel(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
#### BufferModel.create_table_sql(db)
Returns the SQL command for creating a table for this model.
#### BufferModel.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### BufferModel.from_tsv(line, field_names=None, timezone_in_use=UTC, database=None)
Creates a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
- `database`: if given, sets the database that this instance belongs to.
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### BufferModel.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### BufferModel.table_name()
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
- `field_names`: an iterable of field names to return (optional).
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into the database.
infi.clickhouse_orm.fields
--------------------------
### Field
Abstract base class for all field types.
#### Field(default=None, alias=None, materialized=None)
### StringField
Extends Field
#### StringField(default=None, alias=None, materialized=None)
### DateField
Extends Field
#### DateField(default=None, alias=None, materialized=None)
### DateTimeField
Extends Field
#### DateTimeField(default=None, alias=None, materialized=None)
### BaseIntField
Extends Field
Abstract base class for all integer-type fields.
#### BaseIntField(default=None, alias=None, materialized=None)
### BaseFloatField
Extends Field
Abstract base class for all float-type fields.
#### BaseFloatField(default=None, alias=None, materialized=None)
### BaseEnumField
Extends Field
Abstract base class for all enum-type fields.
#### BaseEnumField(enum_cls, default=None, alias=None, materialized=None)
### ArrayField
Extends Field
#### ArrayField(inner_field, default=None, alias=None, materialized=None)
### FixedStringField
Extends StringField
#### FixedStringField(length, default=None, alias=None, materialized=None)
### UInt8Field
Extends BaseIntField
#### UInt8Field(default=None, alias=None, materialized=None)
### UInt16Field
Extends BaseIntField
#### UInt16Field(default=None, alias=None, materialized=None)
### UInt32Field
Extends BaseIntField
#### UInt32Field(default=None, alias=None, materialized=None)
### UInt64Field
Extends BaseIntField
#### UInt64Field(default=None, alias=None, materialized=None)
### Int8Field
Extends BaseIntField
#### Int8Field(default=None, alias=None, materialized=None)
### Int16Field
Extends BaseIntField
#### Int16Field(default=None, alias=None, materialized=None)
### Int32Field
Extends BaseIntField
#### Int32Field(default=None, alias=None, materialized=None)
### Int64Field
Extends BaseIntField
#### Int64Field(default=None, alias=None, materialized=None)
### Float32Field
Extends BaseFloatField
#### Float32Field(default=None, alias=None, materialized=None)
### Float64Field
Extends BaseFloatField
#### Float64Field(default=None, alias=None, materialized=None)
### Enum8Field
Extends BaseEnumField
#### Enum8Field(enum_cls, default=None, alias=None, materialized=None)
### Enum16Field
Extends BaseEnumField
#### Enum16Field(enum_cls, default=None, alias=None, materialized=None)
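A short sketch combining an enum field with an array field (model and values are illustrative):

```python
import enum

from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import Enum8Field, ArrayField, StringField
from infi.clickhouse_orm.engines import Memory

class Gender(enum.Enum):
    male = 0
    female = 1

class Profile(Model):
    # Enum fields take the Enum class as their first argument.
    gender = Enum8Field(Gender, default=Gender.male)
    # Array fields wrap an inner field instance.
    nicknames = ArrayField(StringField())

    engine = Memory()
```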
infi.clickhouse_orm.engines
---------------------------
### Engine
### TinyLog
Extends Engine
### Log
Extends Engine
### Memory
Extends Engine
### MergeTree
Extends Engine
#### MergeTree(date_col, key_cols, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
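For instance, with the optional arguments spelled out (column names and the sampling expression are illustrative):

```python
engine = MergeTree(
    date_col="birthday",
    key_cols=("first_name", "last_name", "birthday"),
    sampling_expr="intHash32(birthday)",
    index_granularity=8192,
)
```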
### Buffer
Extends Engine
Defines the Buffer table engine.
Read more here: https://clickhouse.tech/reference_en.html#Buffer
#### Buffer(main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000, min_bytes=10000000, max_bytes=100000000)
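The usual pairing is a `BufferModel` subclass that also inherits the main model, e.g. (assuming the `Person` model above):

```python
from infi.clickhouse_orm.models import BufferModel
from infi.clickhouse_orm.engines import Buffer

class PersonBuffer(BufferModel, Person):
    engine = Buffer(Person)

# Inserts go to the buffer table and are flushed to Person's table
# according to the time/rows/bytes thresholds above.
db.create_table(PersonBuffer)
```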
### CollapsingMergeTree
Extends MergeTree
#### CollapsingMergeTree(date_col, key_cols, sign_col, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
### SummingMergeTree
Extends MergeTree
#### SummingMergeTree(date_col, key_cols, summing_cols=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
### ReplacingMergeTree
Extends MergeTree
#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
infi.clickhouse_orm.query
-------------------------
### QuerySet
#### QuerySet(model_cls, database)
#### conditions_as_sql(prewhere=True)
Returns the contents of the queryset's `WHERE` or `PREWHERE` clause.
#### count()
Returns the number of matching model instances.
#### exclude(**kwargs)
Returns a new QuerySet instance that excludes all rows matching the conditions.
#### filter(**kwargs)
Returns a new QuerySet instance that includes only rows matching the conditions.
#### only(*field_names)
Limits the query to return only the specified field names.
Useful when there are large fields that are not needed,
or for creating a subquery to use with an IN operator.
#### order_by(*field_names)
Returns a new QuerySet instance with the ordering changed.
#### order_by_as_sql()
Returns the contents of the queryset's ORDER BY clause.
#### query()
Returns the queryset as SQL.
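Putting the pieces together, a typical queryset sketch (assuming the `Person` model above):

```python
qs = Person.objects_in(db).filter(last_name="Smith").exclude(first_name="John")
qs = qs.only("first_name", "last_name").order_by("last_name", "first_name")
print(qs.query())  # the generated SQL
print(qs.count())  # number of matching rows
for person in qs:
    print(person.first_name, person.last_name)
```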

View File

@ -15,6 +15,7 @@ from clickhouse_orm.database import Database, ServerError, DatabaseException, lo
# pylint: disable=C0116
class AioDatabase(Database):
_client_class = httpx.AsyncClient
@ -25,7 +26,7 @@ class AioDatabase(Database):
if self._readonly:
if not self.db_exists:
raise DatabaseException(
'Database does not exist, and cannot be created under readonly connection'
"Database does not exist, and cannot be created under readonly connection"
)
self.connection_readonly = await self._is_connection_readonly()
self.readonly = True
@ -44,10 +45,7 @@ class AioDatabase(Database):
await self.request_session.aclose()
async def _send(
self,
data: str | bytes | AsyncGenerator,
settings: dict = None,
stream: bool = False
self, data: str | bytes | AsyncGenerator, settings: dict = None, stream: bool = False
):
r = await super()._send(data, settings, stream)
if r.status_code != 200:
@ -55,11 +53,7 @@ class AioDatabase(Database):
raise ServerError(r.text)
return r
async def count(
self,
model_class: type[MODEL],
conditions=None
) -> int:
async def count(self, model_class: type[MODEL], conditions=None) -> int:
"""
Counts the number of records in the model's table.
@ -70,14 +64,14 @@ class AioDatabase(Database):
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the init method before it can be used'
"The AioDatabase object must execute the init method before it can be used"
)
query = 'SELECT count() FROM $table'
query = "SELECT count() FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += " WHERE " + str(conditions)
query = self._substitute(query, model_class)
r = await self._send(query)
return int(r.text) if r.text else 0
@ -86,14 +80,14 @@ class AioDatabase(Database):
"""
Creates the database on the ClickHouse server if it does not already exist.
"""
await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
await self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
self.db_exists = True
async def drop_database(self):
"""
Deletes the database on the ClickHouse server.
"""
await self._send('DROP DATABASE `%s`' % self.db_name)
await self._send("DROP DATABASE `%s`" % self.db_name)
self.db_exists = False
async def create_table(self, model_class: type[MODEL]) -> None:
@ -102,7 +96,7 @@ class AioDatabase(Database):
"""
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the init method before it can be used'
"The AioDatabase object must execute the init method before it can be used"
)
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
@ -110,7 +104,7 @@ class AioDatabase(Database):
raise DatabaseException(
"Creating a temporary table must be within the lifetime of a session "
)
if getattr(model_class, 'engine') is None:
if getattr(model_class, "engine") is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
await self._send(model_class.create_table_sql(self))
@ -121,7 +115,7 @@ class AioDatabase(Database):
"""
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the init method before it can be used'
"The AioDatabase object must execute the init method before it can be used"
)
await self._send(model_class.create_temporary_table_sql(self, table_name))
@ -132,7 +126,7 @@ class AioDatabase(Database):
"""
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the init method before it can be used'
"The AioDatabase object must execute the init method before it can be used"
)
if model_class.is_system_model():
@ -146,18 +140,14 @@ class AioDatabase(Database):
"""
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the init method before it can be used'
"The AioDatabase object must execute the init method before it can be used"
)
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = await self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1'
return r.text.strip() == "1"
async def get_model_for_table(
self,
table_name: str,
system_table: bool = False
):
async def get_model_for_table(self, table_name: str, system_table: bool = False):
"""
Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class,
@ -166,7 +156,7 @@ class AioDatabase(Database):
- `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database
"""
db_name = 'system' if system_table else self.db_name
db_name = "system" if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = await self._send(sql)
fields = [parse_tsv(line)[:2] async for line in lines.aiter_lines()]
@ -192,14 +182,13 @@ class AioDatabase(Database):
if first_instance.is_read_only() or first_instance.is_system_model():
raise DatabaseException("You can't insert into read only and system tables")
fields_list = ','.join(
['`%s`' % name for name in first_instance.fields(writable=True)])
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
async def gen():
buf = BytesIO()
buf.write(self._substitute(query, model_class).encode('utf-8'))
buf.write(self._substitute(query, model_class).encode("utf-8"))
first_instance.set_database(self)
buf.write(first_instance.to_db_string())
# Collect lines in batches of batch_size
@ -217,13 +206,11 @@ class AioDatabase(Database):
# Return any remaining lines in partial batch
if lines:
yield buf.getvalue()
await self._send(gen())
async def select(
self,
query: str,
model_class: Optional[type[MODEL]] = None,
settings: Optional[dict] = None
self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
) -> AsyncGenerator[MODEL, None]:
"""
Performs a query and returns a generator of model instances.
@ -233,7 +220,7 @@ class AioDatabase(Database):
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters
"""
query += ' FORMAT TabSeparatedWithNamesAndTypes'
query += " FORMAT TabSeparatedWithNamesAndTypes"
query = self._substitute(query, model_class)
r = await self._send(query, settings, True)
try:
@ -245,7 +232,8 @@ class AioDatabase(Database):
elif not field_types:
field_types = parse_tsv(line)
model_class = model_class or ModelBase.create_ad_hoc_model(
zip(field_names, field_types))
zip(field_names, field_types)
)
elif line.strip():
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
except StopIteration:
@ -271,7 +259,7 @@ class AioDatabase(Database):
page_num: int = 1,
page_size: int = 100,
conditions=None,
settings: Optional[dict] = None
settings: Optional[dict] = None,
):
"""
Selects records and returns a single page of model instances.
@ -294,22 +282,22 @@ class AioDatabase(Database):
if page_num == -1:
page_num = max(pages_total, 1)
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
query = 'SELECT * FROM $table'
query = "SELECT * FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size)
query += " WHERE " + str(conditions)
query += " ORDER BY %s" % order_by
query += " LIMIT %d, %d" % (offset, page_size)
query = self._substitute(query, model_class)
return Page(
objects=[r async for r in self.select(query, model_class, settings)] if count else [],
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
async def migrate(self, migrations_package_name, up_to=9999):
@ -322,19 +310,23 @@ class AioDatabase(Database):
"""
from ..migrations import MigrationHistory
logger = logging.getLogger('migrations')
logger = logging.getLogger("migrations")
applied_migrations = await self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name)
unapplied_migrations = set(modules.keys()) - applied_migrations
for name in sorted(unapplied_migrations):
logger.info('Applying migration %s...', name)
logger.info("Applying migration %s...", name)
for operation in modules[name].operations:
operation.apply(self)
await self.insert([MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today()
)])
await self.insert(
[
MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today(),
)
]
)
if int(name[:4]) >= up_to:
break
@ -342,28 +334,28 @@ class AioDatabase(Database):
r = await self._send(
"SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name
)
return r.text.strip() == '1'
return r.text.strip() == "1"
async def _is_connection_readonly(self):
r = await self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
return r.text.strip() != '0'
return r.text.strip() != "0"
async def _get_server_timezone(self):
try:
r = await self._send('SELECT timezone()')
r = await self._send("SELECT timezone()")
return pytz.timezone(r.text.strip())
except ServerError as err:
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
return pytz.utc
async def _get_server_version(self, as_tuple=True):
try:
r = await self._send('SELECT version();')
r = await self._send("SELECT version();")
ver = r.text
except ServerError as err:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
ver = "1.1.0"
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
async def _get_applied_migrations(self, migrations_package_name):
from ..migrations import MigrationHistory

View File

@ -11,10 +11,10 @@ class Point:
self.y = float(y)
def __repr__(self):
return f'<Point x={self.x} y={self.y}>'
return f"<Point x={self.x} y={self.y}>"
def to_db_string(self):
return f'({self.x},{self.y})'
return f"({self.x},{self.y})"
class Ring:
@ -29,16 +29,16 @@ class Ring:
return len(self.array)
def __repr__(self):
return f'<Ring {self.to_db_string()}>'
return f"<Ring {self.to_db_string()}>"
def to_db_string(self):
return f'[{",".join(pt.to_db_string() for pt in self.array)}]'
def parse_point(array_string: str) -> Point:
if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')':
if len(array_string) < 2 or array_string[0] != "(" or array_string[-1] != ")":
raise ValueError('Invalid point string: "%s"' % array_string)
x, y = array_string.strip('()').split(',')
x, y = array_string.strip("()").split(",")
return Point(x, y)
@ -47,14 +47,14 @@ def parse_ring(array_string: str) -> Ring:
raise ValueError('Invalid ring string: "%s"' % array_string)
ring = []
for point in POINT_REGEX.finditer(array_string):
x, y = point.group('x'), point.group('y')
x, y = point.group("x"), point.group("y")
ring.append(Point(x, y))
return Ring(ring)
class PointField(Field):
class_default = Point(0, 0)
db_type = 'Point'
db_type = "Point"
def __init__(
self,
@ -63,7 +63,7 @@ class PointField(Field):
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
super().__init__(default, alias, materialized, readonly, codec, db_column)
self.inner_field = Float64Field()
@ -73,10 +73,10 @@ class PointField(Field):
value = parse_point(value)
elif isinstance(value, (tuple, list)):
if len(value) != 2:
raise ValueError('PointField takes 2 value, but %s were given' % len(value))
raise ValueError("PointField takes 2 value, but %s were given" % len(value))
value = Point(value[0], value[1])
if not isinstance(value, Point):
raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
return value
def validate(self, value):
@ -91,7 +91,7 @@ class PointField(Field):
class RingField(Field):
class_default = [Point(0, 0)]
db_type = 'Ring'
db_type = "Ring"
def to_python(self, value, timezone_in_use):
if isinstance(value, str):
@ -100,11 +100,11 @@ class RingField(Field):
ring = []
for point in value:
if len(point) != 2:
raise ValueError('Point takes 2 value, but %s were given' % len(value))
raise ValueError("Point takes 2 value, but %s were given" % len(value))
ring.append(Point(point[0], point[1]))
value = Ring(ring)
if not isinstance(value, Ring):
raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))
raise ValueError("PointField expects list or tuple and Point, not %s" % type(value))
return value
def to_db_string(self, value, quote=True):

View File

@ -16,8 +16,8 @@ from .utils import parse_tsv, import_submodules
from .session import ctx_session_id, ctx_session_timeout
logger = logging.getLogger('clickhouse_orm')
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
logger = logging.getLogger("clickhouse_orm")
Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")
class DatabaseException(Exception):
@ -30,6 +30,7 @@ class ServerError(DatabaseException):
"""
Raised when a server returns an error.
"""
def __init__(self, message):
self.code = None
processed = self.get_error_code_msg(message)
@ -43,21 +44,30 @@ class ServerError(DatabaseException):
ERROR_PATTERNS = (
# ClickHouse prior to v19.3.3
re.compile(r'''
re.compile(
r"""
Code:\ (?P<code>\d+),
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
\ e.what\(\)\ =\ (?P<type2>[^ \n]+)
''', re.VERBOSE | re.DOTALL),
""",
re.VERBOSE | re.DOTALL,
),
# ClickHouse v19.3.3+
re.compile(r'''
re.compile(
r"""
Code:\ (?P<code>\d+),
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
''', re.VERBOSE | re.DOTALL),
""",
re.VERBOSE | re.DOTALL,
),
# ClickHouse v21+
re.compile(r'''
re.compile(
r"""
Code:\ (?P<code>\d+).
\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
''', re.VERBOSE | re.DOTALL),
""",
re.VERBOSE | re.DOTALL,
),
)
@classmethod
@ -72,7 +82,7 @@ class ServerError(DatabaseException):
match = pattern.match(full_error_message)
if match:
# assert match.group('type1') == match.group('type2')
return int(match.group('code')), match.group('msg').strip()
return int(match.group("code")), match.group("msg").strip()
return 0, full_error_message
@ -86,11 +96,21 @@ class Database:
Database instances connect to a specific ClickHouse database for running queries,
inserting data and other operations.
"""
_client_class = httpx.Client
def __init__(self, db_name, db_url='http://localhost:8123/',
username=None, password=None, readonly=False, auto_create=True,
timeout=60, verify_ssl_cert=True, log_statements=False):
def __init__(
self,
db_name,
db_url="http://localhost:8123/",
username=None,
password=None,
readonly=False,
auto_create=True,
timeout=60,
verify_ssl_cert=True,
log_statements=False,
):
"""
Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist.
@ -114,7 +134,7 @@ class Database:
self.timeout = timeout
self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout)
if username:
self.request_session.auth = (username, password or '')
self.request_session.auth = (username, password or "")
self.log_statements = log_statements
self.settings = {}
self.db_exists = False # this is required before running _is_existing_database
@ -134,7 +154,7 @@ class Database:
if self._readonly:
if not self.db_exists:
raise DatabaseException(
'Database does not exist, and cannot be created under readonly connection'
"Database does not exist, and cannot be created under readonly connection"
)
self.connection_readonly = self._is_connection_readonly()
self.readonly = True
@ -155,14 +175,14 @@ class Database:
"""
Creates the database on the ClickHouse server if it does not already exist.
"""
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
self.db_exists = True
def drop_database(self):
"""
Deletes the database on the ClickHouse server.
"""
self._send('DROP DATABASE `%s`' % self.db_name)
self._send("DROP DATABASE `%s`" % self.db_name)
self.db_exists = False
def create_table(self, model_class: type[MODEL]) -> None:
@ -171,7 +191,7 @@ class Database:
"""
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if getattr(model_class, 'engine') is None:
if getattr(model_class, "engine") is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
self._send(model_class.create_table_sql(self))
@ -190,13 +210,9 @@ class Database:
"""
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1'
return r.text.strip() == "1"
def get_model_for_table(
self,
table_name: str,
system_table: bool = False
):
def get_model_for_table(self, table_name: str, system_table: bool = False):
"""
Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class,
@ -205,7 +221,7 @@ class Database:
- `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database
"""
db_name = 'system' if system_table else self.db_name
db_name = "system" if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = self._send(sql).iter_lines()
fields = [parse_tsv(line)[:2] for line in lines]
@ -222,7 +238,7 @@ class Database:
The name must be string, and the value is converted to string in case
it isn't. To remove a setting, pass `None` as the value.
"""
assert isinstance(name, str), 'Setting name must be a string'
assert isinstance(name, str), "Setting name must be a string"
if value is None:
self.settings.pop(name, None)
else:
@ -246,14 +262,13 @@ class Database:
if first_instance.is_read_only() or first_instance.is_system_model():
raise DatabaseException("You can't insert into read only and system tables")
fields_list = ','.join(
['`%s`' % name for name in first_instance.fields(writable=True)])
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
def gen():
buf = BytesIO()
buf.write(self._substitute(query, model_class).encode('utf-8'))
buf.write(self._substitute(query, model_class).encode("utf-8"))
first_instance.set_database(self)
buf.write(first_instance.to_db_string())
# Collect lines in batches of batch_size
@ -271,12 +286,11 @@ class Database:
# Return any remaining lines in partial batch
if lines:
yield buf.getvalue()
self._send(gen())
def count(
self,
model_class: Optional[type[MODEL]],
conditions: Optional[Union[str, 'Q']] = None
self, model_class: Optional[type[MODEL]], conditions: Optional[Union[str, "Q"]] = None
) -> int:
"""
Counts the number of records in the model's table.
@ -286,20 +300,17 @@ class Database:
"""
from clickhouse_orm.query import Q
query = 'SELECT count() FROM $table'
query = "SELECT count() FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += " WHERE " + str(conditions)
query = self._substitute(query, model_class)
r = self._send(query)
return int(r.text) if r.text else 0
def select(
self,
query: str,
model_class: Optional[type[MODEL]] = None,
settings: Optional[dict] = None
self, query: str, model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
"""
Performs a query and returns a generator of model instances.
@ -309,7 +320,7 @@ class Database:
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters
"""
query += ' FORMAT TabSeparatedWithNamesAndTypes'
query += " FORMAT TabSeparatedWithNamesAndTypes"
query = self._substitute(query, model_class)
r = self._send(query, settings, True)
try:
@ -345,7 +356,7 @@ class Database:
page_num: int = 1,
page_size: int = 100,
conditions=None,
settings: Optional[dict] = None
settings: Optional[dict] = None,
):
"""
Selects records and returns a single page of model instances.
@ -362,27 +373,28 @@ class Database:
`pages_total`, `number` (of the current page), and `page_size`.
"""
from clickhouse_orm.query import Q
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = max(pages_total, 1)
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
query = 'SELECT * FROM $table'
query = "SELECT * FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size)
query += " WHERE " + str(conditions)
query += " ORDER BY %s" % order_by
query += " LIMIT %d, %d" % (offset, page_size)
query = self._substitute(query, model_class)
return Page(
objects=list(self.select(query, model_class, settings)) if count else [],
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
def migrate(self, migrations_package_name, up_to=9999):
@ -395,19 +407,23 @@ class Database:
"""
from .migrations import MigrationHistory # pylint: disable=C0415
logger = logging.getLogger('migrations')
logger = logging.getLogger("migrations")
applied_migrations = self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name)
unapplied_migrations = set(modules.keys()) - applied_migrations
for name in sorted(unapplied_migrations):
logger.info('Applying migration %s...', name)
logger.info("Applying migration %s...", name)
for operation in modules[name].operations:
operation.apply(self)
self.insert([MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today())
])
self.insert(
[
MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today(),
)
]
)
if int(name[:4]) >= up_to:
break
@ -432,19 +448,14 @@ class Database:
query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query))
def _send(
self,
data: str | bytes | Generator,
settings: dict = None,
stream: bool = False
):
def _send(self, data: str | bytes | Generator, settings: dict = None, stream: bool = False):
if isinstance(data, str):
data = data.encode('utf-8')
data = data.encode("utf-8")
if self.log_statements:
logger.info(data)
params = self._build_params(settings)
request = self.request_session.build_request(
method='POST', url=self.db_url, content=data, params=params
method="POST", url=self.db_url, content=data, params=params
)
r = self.request_session.send(request, stream=stream)
if isinstance(r, httpx.Response) and r.status_code != 200:
@ -457,52 +468,52 @@ class Database:
params.update(self.settings)
params.update(self._context_params)
if self.db_exists:
params['database'] = self.db_name
params["database"] = self.db_name
# Send the readonly flag, unless the connection is already readonly (to prevent db error)
if self.readonly and not self.connection_readonly:
params['readonly'] = '1'
params["readonly"] = "1"
return params
def _substitute(self, query, model_class=None):
"""
Replaces $db and $table placeholders in the query.
"""
if '$' in query:
if "$" in query:
mapping = dict(db="`%s`" % self.db_name)
if model_class:
if model_class.is_system_model():
mapping['table'] = "`system`.`%s`" % model_class.table_name()
mapping["table"] = "`system`.`%s`" % model_class.table_name()
elif model_class.is_temporary_model():
mapping['table'] = "`%s`" % model_class.table_name()
mapping["table"] = "`%s`" % model_class.table_name()
else:
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
query = Template(query).safe_substitute(mapping)
return query
def _get_server_timezone(self):
try:
r = self._send('SELECT timezone()')
r = self._send("SELECT timezone()")
return pytz.timezone(r.text.strip())
except ServerError as err:
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
logger.exception("Cannot determine server timezone (%s), assuming UTC", err)
return pytz.utc
def _get_server_version(self, as_tuple=True):
try:
r = self._send('SELECT version();')
r = self._send("SELECT version();")
ver = r.text
except ServerError as err:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
logger.exception("Cannot determine server version (%s), assuming 1.1.0", err)
ver = "1.1.0"
return tuple(int(n) for n in ver.split(".") if n.isdigit()) if as_tuple else ver
def _is_existing_database(self):
r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
return r.text.strip() == '1'
return r.text.strip() == "1"
def _is_connection_readonly(self):
r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
return r.text.strip() != '0'
return r.text.strip() != "0"
# Expose only relevant classes in import *
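As an aside, the `$db`/`$table` placeholder handling in `_substitute` above is plain `string.Template` substitution; a self-contained sketch:

```python
from string import Template

query = "SELECT count() FROM $table WHERE x > 1"
mapping = {"db": "`my_db`", "table": "`my_db`.`person`"}
print(Template(query).safe_substitute(mapping))
# -> SELECT count() FROM `my_db`.`person` WHERE x > 1
```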

View File

@ -11,35 +11,30 @@ if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.funcs import F
logger = logging.getLogger('clickhouse_orm')
logger = logging.getLogger("clickhouse_orm")
class Engine:
def create_table_sql(self, db: Database) -> str:
raise NotImplementedError() # pragma: no cover
raise NotImplementedError() # pragma: no cover
class TinyLog(Engine):
def create_table_sql(self, db):
return 'TinyLog'
return "TinyLog"
class Log(Engine):
def create_table_sql(self, db):
return 'Log'
return "Log"
class Memory(Engine):
def create_table_sql(self, db):
return 'Memory'
return "Memory"
class MergeTree(Engine):
def __init__(
self,
date_col: Optional[str] = None,
@ -49,22 +44,27 @@ class MergeTree(Engine):
replica_table_path: Optional[str] = None,
replica_name: Optional[str] = None,
partition_key: Optional[Union[list, tuple]] = None,
primary_key: Optional[Union[list, tuple]] = None
primary_key: Optional[Union[list, tuple]] = None,
):
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
assert primary_key is None or type(primary_key) in (list, tuple), \
'primary_key must be a list or tuple'
assert partition_key is None or type(partition_key) in (list, tuple),\
'partition_key must be tuple or list if present'
assert (replica_table_path is None) == (replica_name is None), \
'both replica_table_path and replica_name must be specified'
assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
assert primary_key is None or type(primary_key) in (
list,
tuple,
), "primary_key must be a list or tuple"
assert partition_key is None or type(partition_key) in (
list,
tuple,
), "partition_key must be tuple or list if present"
assert (replica_table_path is None) == (
replica_name is None
), "both replica_table_path and replica_name must be specified"
# These values conflict with each other (old and new syntax of table engines),
# so let's check that only one of them is given.
assert date_col or partition_key, "You must set either date_col or partition_key"
self.date_col = date_col
self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,)
self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
self.primary_key = primary_key
self.order_by = order_by
@ -76,28 +76,33 @@ class MergeTree(Engine):
# I changed field name for new reality and syntax
@property
def key_cols(self):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. '
'Use `order_by` attribute instead')
logger.warning(
"`key_cols` attribute is deprecated and may be removed in future. "
"Use `order_by` attribute instead"
)
return self.order_by
@key_cols.setter
def key_cols(self, value):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. '
'Use `order_by` attribute instead')
logger.warning(
"`key_cols` attribute is deprecated and may be removed in future. "
"Use `order_by` attribute instead"
)
self.order_by = value
def create_table_sql(self, db: Database) -> str:
name = self.__class__.__name__
if self.replica_name:
name = 'Replicated' + name
name = "Replicated" + name
# In ClickHouse 1.1.54310 custom partitioning key was introduced
# https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
# Let's check version and use new syntax if available
if db.server_version >= (1, 1, 54310):
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \
% (comma_join(self.partition_key, stringify=True),
comma_join(self.order_by, stringify=True))
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
comma_join(self.partition_key, stringify=True),
comma_join(self.order_by, stringify=True),
)
if self.primary_key:
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
@ -110,16 +115,17 @@ class MergeTree(Engine):
elif not self.date_col:
# Can't import it globally due to circular import
from clickhouse_orm.database import DatabaseException
raise DatabaseException(
"Custom partitioning is not supported before ClickHouse 1.1.54310. "
"Please update your server or use date_col syntax."
"https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
)
else:
partition_sql = ''
partition_sql = ""
params = self._build_sql_params(db)
return '%s(%s) %s' % (name, comma_join(params), partition_sql)
return "%s(%s) %s" % (name, comma_join(params), partition_sql)
def _build_sql_params(self, db: Database) -> list[str]:
params = []
@ -134,22 +140,34 @@ class MergeTree(Engine):
params.append(self.date_col)
if self.sampling_expr:
params.append(self.sampling_expr)
params.append('(%s)' % comma_join(self.order_by, stringify=True))
params.append("(%s)" % comma_join(self.order_by, stringify=True))
params.append(str(self.index_granularity))
return params
class CollapsingMergeTree(MergeTree):
def __init__(
self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None,
partition_key=None, primary_key=None
self,
date_col=None,
order_by=(),
sign_col="sign",
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
):
super(CollapsingMergeTree, self).__init__(
date_col, order_by, sampling_expr, index_granularity,
replica_table_path, replica_name, partition_key, primary_key
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
)
self.sign_col = sign_col
@ -160,37 +178,63 @@ class CollapsingMergeTree(MergeTree):
class SummingMergeTree(MergeTree):
def __init__(
self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None,
partition_key=None, primary_key=None
self,
date_col=None,
order_by=(),
summing_cols=None,
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
):
super(SummingMergeTree, self).__init__(
date_col, order_by, sampling_expr, index_granularity,
replica_table_path, replica_name, partition_key, primary_key
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
)
assert type is None or type(summing_cols) in (list, tuple), \
'summing_cols must be a list or tuple'
assert summing_cols is None or type(summing_cols) in (
list,
tuple,
), "summing_cols must be a list or tuple"
self.summing_cols = summing_cols
def _build_sql_params(self, db: Database) -> list[str]:
params = super(SummingMergeTree, self)._build_sql_params(db)
if self.summing_cols:
params.append('(%s)' % comma_join(self.summing_cols))
params.append("(%s)" % comma_join(self.summing_cols))
return params
class ReplacingMergeTree(MergeTree):
def __init__(
self, date_col=None, order_by=(), ver_col=None, sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None,
partition_key=None, primary_key=None
self,
date_col=None,
order_by=(),
ver_col=None,
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
):
super(ReplacingMergeTree, self).__init__(
date_col, order_by, sampling_expr, index_granularity,
replica_table_path, replica_name, partition_key, primary_key
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
)
self.ver_col = ver_col
@ -217,7 +261,7 @@ class Buffer(Engine):
min_rows: int = 10000,
max_rows: int = 1000000,
min_bytes: int = 10000000,
max_bytes: int = 100000000
max_bytes: int = 100000000,
):
self.main_model = main_model
self.num_layers = num_layers
@ -231,11 +275,17 @@ class Buffer(Engine):
def create_table_sql(self, db: Database) -> str:
# Overridden create_table_sql example:
# sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
db.db_name, self.main_model.table_name(), self.num_layers,
self.min_time, self.max_time, self.min_rows,
self.max_rows, self.min_bytes, self.max_bytes
)
sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
db.db_name,
self.main_model.table_name(),
self.num_layers,
self.min_time,
self.max_time,
self.min_rows,
self.max_rows,
self.min_bytes,
self.max_bytes,
)
return sql
@ -265,6 +315,7 @@ class Distributed(Engine):
See full documentation here
https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
"""
def __init__(self, cluster, table=None, sharding_key=None):
"""
- `cluster`: what cluster to access data from
@ -292,12 +343,15 @@ class Distributed(Engine):
def create_table_sql(self, db: Database) -> str:
name = self.__class__.__name__
params = self._build_sql_params(db)
return '%s(%s)' % (name, ', '.join(params))
return "%s(%s)" % (name, ", ".join(params))
def _build_sql_params(self, db: Database) -> list[str]:
if self.table_name is None:
raise ValueError("Cannot create {} engine: specify an underlying table".format(
self.__class__.__name__))
raise ValueError(
"Cannot create {} engine: specify an underlying table".format(
self.__class__.__name__
)
)
params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
if self.sharding_key:

View File

@ -21,13 +21,14 @@ if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.database import Database
logger = getLogger('clickhouse_orm')
logger = getLogger("clickhouse_orm")
class Field(FunctionOperatorsMixin):
"""
Abstract base class for all field types.
"""
name: str = None # this is set by the parent model
parent: type["Model"] = None # this is set by the parent model
creation_counter: int = 0 # used for keeping the model fields ordered
@ -41,21 +42,29 @@ class Field(FunctionOperatorsMixin):
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
"Alias parameter must be a string or function object, if given"
assert (materialized is None or isinstance(materialized, F) or
isinstance(materialized, str) and materialized != ""), \
"Materialized parameter must be a string or function object, if given"
assert readonly is None or type(
readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given"
assert db_column is None or isinstance(db_column, str) and db_column != "", \
"db_column field must be string, if given"
assert [default, alias, materialized].count(
None
) >= 2, "Only one of default, alias and materialized parameters can be given"
assert (
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
), "Alias parameter must be a string or function object, if given"
assert (
materialized is None
or isinstance(materialized, F)
or isinstance(materialized, str)
and materialized != ""
), "Materialized parameter must be a string or function object, if given"
assert (
readonly is None or type(readonly) is bool
), "readonly parameter must be bool if given"
assert (
codec is None or isinstance(codec, str) and codec != ""
), "Codec field must be string, if given"
assert (
db_column is None or isinstance(db_column, str) and db_column != ""
), "db_column field must be string, if given"
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
@ -70,7 +79,7 @@ class Field(FunctionOperatorsMixin):
return self.name
def __repr__(self):
return '<%s>' % self.__class__.__name__
return "<%s>" % self.__class__.__name__
def to_python(self, value, timezone_in_use):
"""
@ -92,9 +101,10 @@ class Field(FunctionOperatorsMixin):
Utility method to check that the given value is between min_value and max_value.
"""
if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (
self.__class__.__name__, value, min_value, max_value
))
raise ValueError(
"%s out of range - %s is not between %s and %s"
% (self.__class__.__name__, value, min_value, max_value)
)
def to_db_string(self, value, quote=True):
"""
@ -114,7 +124,7 @@ class Field(FunctionOperatorsMixin):
sql = self.db_type
args = self.get_db_type_args()
if args:
sql += '(%s)' % comma_join(args)
sql += "(%s)" % comma_join(args)
if with_default_expression:
sql += self._extra_params(db)
return sql
@ -124,18 +134,18 @@ class Field(FunctionOperatorsMixin):
return []
def _extra_params(self, db: Database) -> str:
sql = ''
sql = ""
if self.alias:
sql += ' ALIAS %s' % string_or_func(self.alias)
sql += " ALIAS %s" % string_or_func(self.alias)
elif self.materialized:
sql += ' MATERIALIZED %s' % string_or_func(self.materialized)
sql += " MATERIALIZED %s" % string_or_func(self.materialized)
elif isinstance(self.default, F):
sql += ' DEFAULT %s' % self.default.to_sql()
sql += " DEFAULT %s" % self.default.to_sql()
elif self.default:
default = self.to_db_string(self.default)
sql += ' DEFAULT %s' % default
sql += " DEFAULT %s" % default
if self.codec and db and db.has_codec_support and not self.alias:
sql += ' CODEC(%s)' % self.codec
sql += " CODEC(%s)" % self.codec
return sql
def isinstance(self, types) -> bool:
@ -149,28 +159,27 @@ class Field(FunctionOperatorsMixin):
"""
if isinstance(self, types):
return True
inner_field = getattr(self, 'inner_field', None)
inner_field = getattr(self, "inner_field", None)
while inner_field:
if isinstance(inner_field, types):
return True
inner_field = getattr(inner_field, 'inner_field', None)
inner_field = getattr(inner_field, "inner_field", None)
return False
class StringField(Field):
class_default = ''
db_type = 'String'
class_default = ""
db_type = "String"
def to_python(self, value, timezone_in_use) -> str:
if isinstance(value, str):
return value
if isinstance(value, bytes):
return value.decode('UTF-8')
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
return value.decode("UTF-8")
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
class FixedStringField(StringField):
def __init__(
self,
length: int,
@ -178,22 +187,22 @@ class FixedStringField(StringField):
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: Optional[bool] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
self._length = length
self.db_type = 'FixedString(%d)' % length
self.db_type = "FixedString(%d)" % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use) -> str:
value = super(FixedStringField, self).to_python(value, timezone_in_use)
return value.rstrip('\0')
return value.rstrip("\0")
def validate(self, value):
if isinstance(value, str):
value = value.encode('UTF-8')
value = value.encode("UTF-8")
if len(value) > self._length:
raise ValueError(
f'Value of {len(value)} bytes is too long for FixedStringField({self._length})'
f"Value of {len(value)} bytes is too long for FixedStringField({self._length})"
)
@ -201,7 +210,7 @@ class DateField(Field):
min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31)
class_default = min_value
db_type = 'Date'
db_type = "Date"
def to_python(self, value, timezone_in_use) -> datetime.date:
if isinstance(value, datetime.datetime):
@ -211,10 +220,10 @@ class DateField(Field):
if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, str):
if value == '0000-00-00':
if value == "0000-00-00":
return DateField.min_value
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value)
@ -225,7 +234,7 @@ class DateField(Field):
class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime'
db_type = "DateTime"
def __init__(
self,
@ -235,7 +244,7 @@ class DateTimeField(Field):
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None,
timezone: Optional[Union[BaseTzInfo, str]] = None
timezone: Optional[Union[BaseTzInfo, str]] = None,
):
super().__init__(default, alias, materialized, readonly, codec, db_column)
# assert not timezone, 'Temporarily field timezone is not supported'
@ -257,7 +266,7 @@ class DateTimeField(Field):
if isinstance(value, int):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, str):
if value == '0000-00-00 00:00:00':
if value == "0000-00-00 00:00:00":
return self.class_default
if len(value) == 10:
try:
@ -275,14 +284,14 @@ class DateTimeField(Field):
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
dt = timezone_in_use.localize(dt)
return dt
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True) -> str:
return escape('%010d' % timegm(value.utctimetuple()), quote)
return escape("%010d" % timegm(value.utctimetuple()), quote)
class DateTime64Field(DateTimeField):
db_type = 'DateTime64'
db_type = "DateTime64"
"""
@ -303,10 +312,10 @@ class DateTime64Field(DateTimeField):
codec: Optional[str] = None,
db_column: Optional[str] = None,
timezone: Optional[Union[BaseTzInfo, str]] = None,
precision: int = 6
precision: int = 6,
):
super().__init__(default, alias, materialized, readonly, codec, db_column, timezone)
assert precision is None or isinstance(precision, int), 'Precision must be int type'
assert precision is None or isinstance(precision, int), "Precision must be int type"
self.precision = precision
def get_db_type_args(self):
@ -322,11 +331,10 @@ class DateTime64Field(DateTimeField):
Returns a string in 0000000000.000000 format, where the number of fractional digits equals the precision
"""
return escape(
'{timestamp:0{width}.{precision}f}'.format(
timestamp=value.timestamp(),
width=11 + self.precision,
precision=self.precision),
quote
"{timestamp:0{width}.{precision}f}".format(
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
),
quote,
)
def to_python(self, value, timezone_in_use) -> datetime.datetime:
@ -336,8 +344,8 @@ class DateTime64Field(DateTimeField):
if isinstance(value, (int, float)):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, str):
left_part = value.split('.')[0]
if left_part == '0000-00-00 00:00:00':
left_part = value.split(".")[0]
if left_part == "0000-00-00 00:00:00":
return self.class_default
if len(left_part) == 10:
try:
@ -357,7 +365,7 @@ class BaseIntField(Field):
try:
return int(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True) -> str:
# There's no need to call escape since numbers do not contain
@ -370,50 +378,50 @@ class BaseIntField(Field):
class UInt8Field(BaseIntField):
min_value = 0
max_value = 2 ** 8 - 1
db_type = 'UInt8'
max_value = 2**8 - 1
db_type = "UInt8"
class UInt16Field(BaseIntField):
min_value = 0
max_value = 2 ** 16 - 1
db_type = 'UInt16'
max_value = 2**16 - 1
db_type = "UInt16"
class UInt32Field(BaseIntField):
min_value = 0
max_value = 2 ** 32 - 1
db_type = 'UInt32'
max_value = 2**32 - 1
db_type = "UInt32"
class UInt64Field(BaseIntField):
min_value = 0
max_value = 2 ** 64 - 1
db_type = 'UInt64'
max_value = 2**64 - 1
db_type = "UInt64"
class Int8Field(BaseIntField):
min_value = -2 ** 7
max_value = 2 ** 7 - 1
db_type = 'Int8'
min_value = -(2**7)
max_value = 2**7 - 1
db_type = "Int8"
class Int16Field(BaseIntField):
min_value = -2 ** 15
max_value = 2 ** 15 - 1
db_type = 'Int16'
min_value = -(2**15)
max_value = 2**15 - 1
db_type = "Int16"
class Int32Field(BaseIntField):
min_value = -2 ** 31
max_value = 2 ** 31 - 1
db_type = 'Int32'
min_value = -(2**31)
max_value = 2**31 - 1
db_type = "Int32"
class Int64Field(BaseIntField):
min_value = -2 ** 63
max_value = 2 ** 63 - 1
db_type = 'Int64'
min_value = -(2**63)
max_value = 2**63 - 1
db_type = "Int64"
class BaseFloatField(Field):
@ -425,7 +433,7 @@ class BaseFloatField(Field):
try:
return float(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True) -> str:
# There's no need to call escape since numbers do not contain
@ -434,11 +442,11 @@ class BaseFloatField(Field):
class Float32Field(BaseFloatField):
db_type = 'Float32'
db_type = "Float32"
class Float64Field(BaseFloatField):
db_type = 'Float64'
db_type = "Float64"
class DecimalField(Field):
@ -454,13 +462,13 @@ class DecimalField(Field):
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
self.precision = precision
self.scale = scale
self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale)
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
with localcontext() as ctx:
ctx.prec = 38
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
@ -473,9 +481,9 @@ class DecimalField(Field):
try:
value = Decimal(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
if not value.is_finite():
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
return self._round(value)
def to_db_string(self, value, quote=True) -> str:
@ -498,14 +506,13 @@ class Decimal32Field(DecimalField):
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
super().__init__(9, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal32(%d)' % scale
self.db_type = "Decimal32(%d)" % scale
class Decimal64Field(DecimalField):
def __init__(
self,
scale: int,
@ -513,14 +520,13 @@ class Decimal64Field(DecimalField):
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
super().__init__(18, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal64(%d)' % scale
self.db_type = "Decimal64(%d)" % scale
class Decimal128Field(DecimalField):
def __init__(
self,
scale: int,
@ -528,10 +534,10 @@ class Decimal128Field(DecimalField):
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
super().__init__(38, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal128(%d)' % scale
self.db_type = "Decimal128(%d)" % scale
class BaseEnumField(Field):
@ -547,7 +553,7 @@ class BaseEnumField(Field):
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
self.enum_cls = enum_cls
if default is None:
@ -564,7 +570,7 @@ class BaseEnumField(Field):
except Exception:
return self.enum_cls(value)
if isinstance(value, bytes):
decoded = value.decode('UTF-8')
decoded = value.decode("UTF-8")
try:
return self.enum_cls[decoded]
except Exception:
@ -573,13 +579,13 @@ class BaseEnumField(Field):
return self.enum_cls(value)
except (KeyError, ValueError):
pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))
def to_db_string(self, value, quote=True) -> str:
return escape(value.name, quote)
def get_db_type_args(self):
return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]
@classmethod
def create_ad_hoc_field(cls, db_type) -> BaseEnumField:
@ -590,17 +596,17 @@ class BaseEnumField(Field):
members = {}
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
members[match.group(1)] = int(match.group(2))
enum_cls = Enum('AdHocEnum', members)
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
enum_cls = Enum("AdHocEnum", members)
field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
return field_class(enum_cls)
class Enum8Field(BaseEnumField):
db_type = 'Enum8'
db_type = "Enum8"
class Enum16Field(BaseEnumField):
db_type = 'Enum16'
db_type = "Enum16"
class ArrayField(Field):
@ -614,12 +620,14 @@ class ArrayField(Field):
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \
"Multidimensional array fields are not supported by the ORM"
assert isinstance(
inner_field, Field
), "The first argument of ArrayField must be a Field instance"
assert not isinstance(
inner_field, ArrayField
), "Multidimensional array fields are not supported by the ORM"
self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec, db_column)
@ -627,9 +635,9 @@ class ArrayField(Field):
if isinstance(value, str):
value = parse_array(value)
elif isinstance(value, bytes):
value = parse_array(value.decode('UTF-8'))
value = parse_array(value.decode("UTF-8"))
elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
def validate(self, value):
@ -638,12 +646,12 @@ class ArrayField(Field):
def to_db_string(self, value, quote=True) -> str:
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
return '[' + comma_join(array) + ']'
return "[" + comma_join(array) + "]"
def get_sql(self, with_default_expression=True, db=None) -> str:
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression and self.codec and db and db.has_codec_support:
sql += ' CODEC(%s)' % self.codec
sql += " CODEC(%s)" % self.codec
return sql
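The round-trip these methods implement, sketched (clickhouse_orm.fields path assumed):

from clickhouse_orm.fields import ArrayField, UInt8Field

f = ArrayField(UInt8Field())
print(f.to_python("[1,2,3]", None))              # [1, 2, 3] - parsed, then coerced per item
print(f.to_db_string([1, 2, 3], quote=False))    # [1, 2, 3]
print(f.get_sql(with_default_expression=False))  # Array(UInt8)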
@ -658,17 +666,19 @@ class TupleField(Field):
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
self.names = {}
self.inner_fields = []
for (name, field) in name_fields:
if name in self.names:
raise ValueError('The Field name conflict')
assert isinstance(field, Field), \
"The first argument of TupleField must be a Field instance"
assert not isinstance(field, (ArrayField, TupleField)), \
"Multidimensional array fields are not supported by the ORM"
raise ValueError("The Field name conflict")
assert isinstance(
field, Field
), "The first argument of TupleField must be a Field instance"
assert not isinstance(
field, (ArrayField, TupleField)
), "Multidimensional array fields are not supported by the ORM"
self.names[name] = field
self.inner_fields.append(field)
self.class_default = tuple(field.class_default for field in self.inner_fields)
@ -677,16 +687,19 @@ class TupleField(Field):
def to_python(self, value, timezone_in_use) -> tuple:
if isinstance(value, str):
value = parse_array(value)
value = (self.inner_fields[i].to_python(v, timezone_in_use)
for i, v in enumerate(value))
value = (
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
)
elif isinstance(value, bytes):
value = parse_array(value.decode('UTF-8'))
value = (self.inner_fields[i].to_python(v, timezone_in_use)
for i, v in enumerate(value))
value = parse_array(value.decode("UTF-8"))
value = (
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
)
elif not isinstance(value, (list, tuple)):
raise ValueError('TupleField expects list or tuple, not %s' % type(value))
return tuple(self.inner_fields[i].to_python(v, timezone_in_use)
for i, v in enumerate(value))
raise ValueError("TupleField expects list or tuple, not %s" % type(value))
return tuple(
self.inner_fields[i].to_python(v, timezone_in_use) for i, v in enumerate(value)
)
def validate(self, value):
for i, v in enumerate(value):
@ -694,21 +707,22 @@ class TupleField(Field):
def to_db_string(self, value, quote=True) -> str:
array = [self.inner_fields[i].to_db_string(v, quote=True) for i, v in enumerate(value)]
return '(' + comma_join(array) + ')'
return "(" + comma_join(array) + ")"
def get_sql(self, with_default_expression=True, db=None) -> str:
inner_sql = ', '.join('%s %s' % (name, field.get_sql(False))
for name, field in self.names.items())
inner_sql = ", ".join(
"%s %s" % (name, field.get_sql(False)) for name, field in self.names.items()
)
sql = 'Tuple(%s)' % inner_sql
sql = "Tuple(%s)" % inner_sql
if with_default_expression and self.codec and db and db.has_codec_support:
sql += ' CODEC(%s)' % self.codec
sql += " CODEC(%s)" % self.codec
return sql
class UUIDField(Field):
class_default = UUID(int=0)
db_type = 'UUID'
db_type = "UUID"
def to_python(self, value, timezone_in_use) -> UUID:
if isinstance(value, UUID):
@ -722,7 +736,7 @@ class UUIDField(Field):
elif isinstance(value, tuple):
return UUID(fields=value)
else:
raise ValueError('Invalid value for UUIDField: %r' % value)
raise ValueError("Invalid value for UUIDField: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
@ -730,7 +744,7 @@ class UUIDField(Field):
class IPv4Field(Field):
class_default = 0
db_type = 'IPv4'
db_type = "IPv4"
def to_python(self, value, timezone_in_use) -> IPv4Address:
if isinstance(value, IPv4Address):
@ -738,7 +752,7 @@ class IPv4Field(Field):
elif isinstance(value, (bytes, str, int)):
return IPv4Address(value)
else:
raise ValueError('Invalid value for IPv4Address: %r' % value)
raise ValueError("Invalid value for IPv4Address: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
@ -746,7 +760,7 @@ class IPv4Field(Field):
class IPv6Field(Field):
class_default = 0
db_type = 'IPv6'
db_type = "IPv6"
def to_python(self, value, timezone_in_use) -> IPv6Address:
if isinstance(value, IPv6Address):
@ -754,7 +768,7 @@ class IPv6Field(Field):
elif isinstance(value, (bytes, str, int)):
return IPv6Address(value)
else:
raise ValueError('Invalid value for IPv6Address: %r' % value)
raise ValueError("Invalid value for IPv6Address: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
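These three field types all coerce the common Python representations; a compact sketch (clickhouse_orm.fields path assumed):

from ipaddress import IPv4Address
from clickhouse_orm.fields import UUIDField, IPv4Field, IPv6Field

print(UUIDField().to_python("12345678-1234-5678-1234-567812345678", None))
print(UUIDField().to_python(1, None))           # UUID(int=1)
print(IPv4Field().to_python("10.0.0.1", None))  # IPv4Address('10.0.0.1')
print(IPv6Field().to_python("::1", None))       # IPv6Address('::1')
print(IPv4Field().to_db_string(IPv4Address("10.0.0.1")))  # quoted: '10.0.0.1'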
@ -771,11 +785,13 @@ class NullableField(Field):
materialized: Optional[Union[F, str]] = None,
extra_null_values: Optional[Iterable] = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
assert isinstance(inner_field, Field), \
"The first argument of NullableField must be a Field instance." \
" Not: {}".format(inner_field)
assert isinstance(
inner_field, Field
), "The first argument of NullableField must be a Field instance." " Not: {}".format(
inner_field
)
self.inner_field = inner_field
self._null_values = [None]
if extra_null_values:
@ -785,7 +801,7 @@ class NullableField(Field):
)
def to_python(self, value, timezone_in_use):
if value == '\\N' or value in self._null_values:
if value == "\\N" or value in self._null_values:
return None
return self.inner_field.to_python(value, timezone_in_use)
@ -794,18 +810,17 @@ class NullableField(Field):
def to_db_string(self, value, quote=True):
if value in self._null_values:
return '\\N'
return "\\N"
return self.inner_field.to_db_string(value, quote=quote)
def get_sql(self, with_default_expression=True, db=None):
sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression:
sql += self._extra_params(db)
return sql
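How the null handling above looks from the caller's side; a sketch (clickhouse_orm.fields path assumed):

from clickhouse_orm.fields import NullableField, Int32Field

f = NullableField(Int32Field(), extra_null_values=[-1])
print(f.to_python("\\N", None))   # None - ClickHouse's TSV null marker
print(f.to_python(-1, None))      # None - registered as an extra null value
print(f.to_python("7", None))     # 7   - delegated to the inner Int32Field
print(f.to_db_string(None))       # \N
print(f.get_sql(with_default_expression=False))  # Nullable(Int32)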
class LowCardinalityField(Field):
def __init__(
self,
inner_field: Field,
@ -814,16 +829,20 @@ class LowCardinalityField(Field):
materialized: Optional[Union[F, str]] = None,
readonly: Optional[bool] = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
db_column: Optional[str] = None,
):
assert isinstance(inner_field, Field), \
"The first argument of LowCardinalityField must be a Field instance." \
" Not: {}".format(inner_field)
assert not isinstance(inner_field, LowCardinalityField), \
"LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), \
"Array field inside LowCardinality are not supported by the ORM." \
assert isinstance(
inner_field, Field
), "The first argument of LowCardinalityField must be a Field instance." " Not: {}".format(
inner_field
)
assert not isinstance(
inner_field, LowCardinalityField
), "LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), (
"Array field inside LowCardinality are not supported by the ORM."
" Use Array(LowCardinality) instead"
)
self.inner_field = inner_field
self.class_default = self.inner_field.class_default
super().__init__(default, alias, materialized, readonly, codec, db_column)
@ -839,12 +858,12 @@ class LowCardinalityField(Field):
def get_sql(self, with_default_expression=True, db=None):
if db and db.has_low_cardinality_support:
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
else:
sql = self.inner_field.get_sql(with_default_expression=False)
logger.warning(
'LowCardinalityField not supported on clickhouse-server version < 19.0'
' using {} as fallback'.format(self.inner_field.__class__.__name__)
"LowCardinalityField not supported on clickhouse-server version < 19.0"
" using {} as fallback".format(self.inner_field.__class__.__name__)
)
if with_default_expression:
sql += self._extra_params(db)

File diff suppressed because it is too large


@ -5,7 +5,7 @@ from .utils import get_subclass_names
import logging
logger = logging.getLogger('migrations')
logger = logging.getLogger("migrations")
class Operation:
@ -14,7 +14,7 @@ class Operation:
"""
def apply(self, database):
raise NotImplementedError() # pragma: no cover
raise NotImplementedError() # pragma: no cover
class ModelOperation(Operation):
@ -30,9 +30,9 @@ class ModelOperation(Operation):
self.table_name = model_class.table_name()
def _alter_table(self, database, cmd):
'''
"""
Utility for running ALTER TABLE commands.
'''
"""
cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
logger.debug(cmd)
database.raw(cmd)
@ -44,7 +44,7 @@ class CreateTable(ModelOperation):
"""
def apply(self, database):
logger.info(' Create table %s', self.table_name)
logger.info(" Create table %s", self.table_name)
if issubclass(self.model_class, BufferModel):
database.create_table(self.model_class.engine.main_model)
database.create_table(self.model_class)
@ -65,7 +65,7 @@ class AlterTable(ModelOperation):
return [(row.name, row.type) for row in database.select(query)]
def apply(self, database):
logger.info(' Alter table %s', self.table_name)
logger.info(" Alter table %s", self.table_name)
# Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
# ADD COLUMN ... AFTER doesn't affect it
@ -74,8 +74,8 @@ class AlterTable(ModelOperation):
# Identify fields that were deleted from the model
deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
for name in deleted_fields:
logger.info(' Drop column %s', name)
self._alter_table(database, 'DROP COLUMN %s' % name)
logger.info(" Drop column %s", name)
self._alter_table(database, "DROP COLUMN %s" % name)
del table_fields[name]
# Identify fields that were added to the model
@ -83,13 +83,13 @@ class AlterTable(ModelOperation):
for name, field in self.model_class.fields().items():
is_regular_field = not (field.materialized or field.alias)
if name not in table_fields:
logger.info(' Add column %s', name)
cmd = 'ADD COLUMN %s %s' % (name, field.get_sql(db=database))
logger.info(" Add column %s", name)
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
if is_regular_field:
if prev_name:
cmd += ' AFTER %s' % prev_name
cmd += " AFTER %s" % prev_name
else:
cmd += ' FIRST'
cmd += " FIRST"
self._alter_table(database, cmd)
if is_regular_field:
@ -101,16 +101,24 @@ class AlterTable(ModelOperation):
# The order of class attributes can be changed any time, so we can't count on it
# Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to preserve
# attribute position. Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47
model_fields = {name: field.get_sql(with_default_expression=False, db=database)
for name, field in self.model_class.fields().items()}
model_fields = {
name: field.get_sql(with_default_expression=False, db=database)
for name, field in self.model_class.fields().items()
}
for field_name, field_sql in self._get_table_fields(database):
# All fields must have been created and dropped by this moment
assert field_name in model_fields, 'Model fields and table columns in disagreement'
assert field_name in model_fields, "Model fields and table columns in disagreement"
if field_sql != model_fields[field_name]:
logger.info(' Change type of column %s from %s to %s', field_name, field_sql,
model_fields[field_name])
self._alter_table(database, 'MODIFY COLUMN %s %s' % (field_name, model_fields[field_name]))
logger.info(
" Change type of column %s from %s to %s",
field_name,
field_sql,
model_fields[field_name],
)
self._alter_table(
database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name])
)
class AlterTableWithBuffer(ModelOperation):
@ -135,7 +143,7 @@ class DropTable(ModelOperation):
"""
def apply(self, database):
logger.info(' Drop table %s', self.table_name)
logger.info(" Drop table %s", self.table_name)
database.drop_table(self.model_class)
@ -148,28 +156,29 @@ class AlterConstraints(ModelOperation):
"""
def apply(self, database):
logger.info(' Alter constraints for %s', self.table_name)
logger.info(" Alter constraints for %s", self.table_name)
existing = self._get_constraint_names(database)
# Go over constraints in the model
for constraint in self.model_class._constraints.values():
# Check if it's a new constraint
if constraint.name not in existing:
logger.info(' Add constraint %s', constraint.name)
self._alter_table(database, 'ADD %s' % constraint.create_table_sql())
logger.info(" Add constraint %s", constraint.name)
self._alter_table(database, "ADD %s" % constraint.create_table_sql())
else:
existing.remove(constraint.name)
# Remaining constraints in `existing` are obsolete
for name in existing:
logger.info(' Drop constraint %s', name)
self._alter_table(database, 'DROP CONSTRAINT `%s`' % name)
logger.info(" Drop constraint %s", name)
self._alter_table(database, "DROP CONSTRAINT `%s`" % name)
def _get_constraint_names(self, database):
"""
Returns a set containing the names of existing constraints in the table.
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def)
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
return set(matches)
@ -191,33 +200,34 @@ class AlterIndexes(ModelOperation):
self.reindex = reindex
def apply(self, database):
logger.info(' Alter indexes for %s', self.table_name)
logger.info(" Alter indexes for %s", self.table_name)
existing = self._get_index_names(database)
logger.info(existing)
# Go over indexes in the model
for index in self.model_class._indexes.values():
# Check if it's a new index
if index.name not in existing:
logger.info(' Add index %s', index.name)
self._alter_table(database, 'ADD %s' % index.create_table_sql())
logger.info(" Add index %s", index.name)
self._alter_table(database, "ADD %s" % index.create_table_sql())
else:
existing.remove(index.name)
# Remaining indexes in `existing` are obsolete
for name in existing:
logger.info(' Drop index %s', name)
self._alter_table(database, 'DROP INDEX `%s`' % name)
logger.info(" Drop index %s", name)
self._alter_table(database, "DROP INDEX `%s`" % name)
# Reindex
if self.reindex:
logger.info(' Build indexes on table')
database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name)
logger.info(" Build indexes on table")
database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)
def _get_index_names(self, database):
"""
Returns a set containing the names of existing indexes in the table.
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def)
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
return set(matches)
@ -225,16 +235,17 @@ class RunPython(Operation):
"""
A migration operation that executes a Python function.
"""
def __init__(self, func):
'''
"""
Initializer. The given Python function will be called with a single
argument - the Database instance to apply the migration to.
'''
"""
assert callable(func), "'func' argument must be function"
self._func = func
def apply(self, database):
logger.info(' Executing python operation %s', self._func.__name__)
logger.info(" Executing python operation %s", self._func.__name__)
self._func(database)
@ -244,17 +255,17 @@ class RunSQL(Operation):
"""
def __init__(self, sql):
'''
"""
Initializer. The given sql argument must be a valid SQL statement or
list of statements.
'''
"""
if isinstance(sql, str):
sql = [sql]
assert isinstance(sql, list), "'sql' argument must be string or list of strings"
self._sql = sql
def apply(self, database):
logger.info(' Executing raw SQL operations')
logger.info(" Executing raw SQL operations")
for item in self._sql:
database.raw(item)
@ -268,11 +279,11 @@ class MigrationHistory(Model):
module_name = StringField()
applied = DateField()
engine = MergeTree('applied', ('package_name', 'module_name'))
engine = MergeTree("applied", ("package_name", "module_name"))
@classmethod
def table_name(cls):
return 'infi_clickhouse_orm_migrations'
return "infi_clickhouse_orm_migrations"
# Expose only relevant classes in import *
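Putting the operations above together, a migration module might look like this; the package path, file name and model are hypothetical:

# mymigrations/0002_sync_schema.py
from clickhouse_orm import migrations   # import path assumed
from myapp.models import MyModel        # hypothetical model

def rebuild(database):
    database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % MyModel.table_name())

operations = [
    migrations.AlterTable(MyModel),          # add/drop/modify columns to match the model
    migrations.AlterConstraints(MyModel),
    migrations.AlterIndexes(MyModel, reindex=True),
    migrations.RunPython(rebuild),
    migrations.RunSQL(["DROP TABLE IF EXISTS $db.`obsolete_table`"]),
]

Applying the package is then a single Database.migrate() call, which records each applied module in the infi_clickhouse_orm_migrations table defined above.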


@ -17,7 +17,7 @@ from .engines import Merge, Distributed, Memory
if TYPE_CHECKING:
from clickhouse_orm.database import Database
logger = getLogger('clickhouse_orm')
logger = getLogger("clickhouse_orm")
class Constraint:
@ -38,7 +38,7 @@ class Constraint:
"""
Returns the SQL statement for defining this constraint during table creation.
"""
return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))
class Index:
@ -66,8 +66,11 @@ class Index:
"""
Returns the SQL statement for defining this index during table creation.
"""
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (
self.name, arg_to_sql(self.expr), self.type, self.granularity
return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (
self.name,
arg_to_sql(self.expr),
self.type,
self.granularity,
)
@staticmethod
@ -76,7 +79,7 @@ class Index:
An index that stores extremes of the specified expression (if the expression is a tuple, it stores
extremes for each element of the tuple). The stored info is used for skipping blocks of data, like the primary key.
"""
return 'minmax'
return "minmax"
@staticmethod
def set(max_rows: int) -> str:
@ -85,11 +88,12 @@ class Index:
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
on a block of data.
"""
return 'set(%d)' % max_rows
return "set(%d)" % max_rows
@staticmethod
def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int,
number_of_hash_functions: int, random_seed: int) -> str:
def ngrambf_v1(
n: int, size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
) -> str:
"""
An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -100,13 +104,17 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
"""
return 'ngrambf_v1(%d, %d, %d, %d)' % (
n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
return "ngrambf_v1(%d, %d, %d, %d)" % (
n,
size_of_bloom_filter_in_bytes,
number_of_hash_functions,
random_seed,
)
@staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int,
random_seed: int) -> str:
def tokenbf_v1(
size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, random_seed: int
) -> str:
"""
An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters.
@ -116,8 +124,10 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
"""
return 'tokenbf_v1(%d, %d, %d)' % (
size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
return "tokenbf_v1(%d, %d, %d)" % (
size_of_bloom_filter_in_bytes,
number_of_hash_functions,
random_seed,
)
@staticmethod
@ -128,7 +138,7 @@ class Index:
- `false_positive` - the probability (between 0 and 1) of receiving a false positive
response from the filter
"""
return 'bloom_filter(%f)' % false_positive
return "bloom_filter(%f)" % false_positive
class ModelBase(type):
@ -183,23 +193,23 @@ class ModelBase(type):
_indexes=indexes,
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
_defaults=defaults,
_has_funcs_as_defaults=has_funcs_as_defaults
_has_funcs_as_defaults=has_funcs_as_defaults,
)
model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs)
# Let each field, constraint and index know its parent and its own name
for n, obj in chain(fields, constraints.items(), indexes.items()):
setattr(obj, 'parent', model)
setattr(obj, 'name', n)
setattr(obj, "parent", model)
setattr(obj, "name", n)
return model
@classmethod
def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):
def create_ad_hoc_model(cls, fields, model_name="AdHocModel"):
# fields is a list of tuples (name, db_type)
# Check if model exists in cache
fields = list(fields)
cache_key = model_name + ' ' + str(fields)
cache_key = model_name + " " + str(fields)
if cache_key in cls.ad_hoc_model_cache:
return cls.ad_hoc_model_cache[cache_key]
# Create an ad hoc model class
@ -217,28 +227,25 @@ class ModelBase(type):
import clickhouse_orm.contrib.geo.fields as geo_fields
# Enums
if db_type.startswith('Enum'):
if db_type.startswith("Enum"):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
# DateTime with timezone
if db_type.startswith('DateTime('):
if db_type.startswith("DateTime("):
timezone = db_type[9:-1]
return orm_fields.DateTimeField(
timezone=timezone[1:-1] if timezone else None
)
return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
# DateTime64
if db_type.startswith('DateTime64('):
precision, *timezone = [s.strip() for s in db_type[11:-1].split(',')]
if db_type.startswith("DateTime64("):
precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
return orm_fields.DateTime64Field(
precision=int(precision),
timezone=timezone[0][1:-1] if timezone else None
precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
)
# Arrays
if db_type.startswith('Array'):
if db_type.startswith("Array"):
inner_field = cls.create_ad_hoc_field(db_type[6:-1])
return orm_fields.ArrayField(inner_field)
# Tuples
if db_type.startswith('Tuple'):
types = [s.strip().split(' ') for s in db_type[6:-1].split(',')]
if db_type.startswith("Tuple"):
types = [s.strip().split(" ") for s in db_type[6:-1].split(",")]
name_fields = []
for i, tp in enumerate(types):
if len(tp) == 2:
@ -247,27 +254,27 @@ class ModelBase(type):
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
return orm_fields.TupleField(name_fields=name_fields)
# FixedString
if db_type.startswith('FixedString'):
if db_type.startswith("FixedString"):
length = int(db_type[12:-1])
return orm_fields.FixedStringField(length)
# Decimal / Decimal32 / Decimal64 / Decimal128
if db_type.startswith('Decimal'):
p = db_type.index('(')
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(',')]
field_class = getattr(orm_fields, db_type[:p] + 'Field')
if db_type.startswith("Decimal"):
p = db_type.index("(")
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
field_class = getattr(orm_fields, db_type[:p] + "Field")
return field_class(*args)
# Nullable
if db_type.startswith('Nullable'):
inner_field = cls.create_ad_hoc_field(db_type[9 : -1])
if db_type.startswith("Nullable"):
inner_field = cls.create_ad_hoc_field(db_type[9:-1])
return orm_fields.NullableField(inner_field)
# LowCardinality
if db_type.startswith('LowCardinality'):
inner_field = cls.create_ad_hoc_field(db_type[15 : -1])
if db_type.startswith("LowCardinality"):
inner_field = cls.create_ad_hoc_field(db_type[15:-1])
return orm_fields.LowCardinalityField(inner_field)
# Simple fields
name = db_type + 'Field'
name = db_type + "Field"
if not (hasattr(orm_fields, name) or hasattr(geo_fields, name)):
raise NotImplementedError('No field class for %s' % db_type)
raise NotImplementedError("No field class for %s" % db_type)
field_class = getattr(orm_fields, name, None) or getattr(geo_fields, name, None)
return field_class()
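The type parser above can be exercised directly; a tiny sketch (clickhouse_orm.models path assumed):

from clickhouse_orm.models import ModelBase

f = ModelBase.create_ad_hoc_field("Array(Nullable(String))")
print(type(f).__name__)                          # ArrayField
print(type(f.inner_field).__name__)              # NullableField
print(type(f.inner_field.inner_field).__name__)  # StringField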
@ -282,6 +289,7 @@ class Model(metaclass=ModelBase):
cpu_percent = Float32Field()
engine = Memory()
"""
_has_funcs_as_defaults: bool
_constraints: dict[str, Constraint]
_indexes: dict[str, Index]
@ -318,7 +326,7 @@ class Model(metaclass=ModelBase):
setattr(self, name, value)
else:
raise AttributeError(
'%s does not have a field called %s' % (self.__class__.__name__, name)
"%s does not have a field called %s" % (self.__class__.__name__, name)
)
def __setattr__(self, name, value):
@ -383,29 +391,29 @@ class Model(metaclass=ModelBase):
"""
Returns the SQL statement for creating a table for this model.
"""
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
# Fields
items = []
for name, field in cls.fields().items():
items.append(' %s %s' % (name, field.get_sql(db=db)))
items.append(" %s %s" % (name, field.get_sql(db=db)))
# Constraints
for c in cls._constraints.values():
items.append(' %s' % c.create_table_sql())
items.append(" %s" % c.create_table_sql())
# Indexes
for i in cls._indexes.values():
items.append(' %s' % i.create_table_sql())
parts.append(',\n'.join(items))
items.append(" %s" % i.create_table_sql())
parts.append(",\n".join(items))
# Engine
parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts)
parts.append(")")
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return "\n".join(parts)
@classmethod
def drop_table_sql(cls, db: Database) -> str:
"""
Returns the SQL command for deleting this model's table.
"""
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
@classmethod
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
@ -422,7 +430,7 @@ class Model(metaclass=ModelBase):
kwargs = {}
for name in field_names:
field = getattr(cls, name)
field_timezone = getattr(field, 'timezone', None) or timezone_in_use
field_timezone = getattr(field, "timezone", None) or timezone_in_use
kwargs[name] = field.to_python(next(values), field_timezone)
obj = cls(**kwargs)
@ -439,7 +447,9 @@ class Model(metaclass=ModelBase):
"""
data = self.__dict__
fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
return "\t".join(
field.to_db_string(data[name], quote=False) for name, field in fields.items()
)
def to_tskv(self, include_readonly=True):
"""
@ -453,16 +463,16 @@ class Model(metaclass=ModelBase):
parts = []
for name, field in fields.items():
if data[name] != NO_VALUE:
parts.append(name + '=' + field.to_db_string(data[name], quote=False))
return '\t'.join(parts)
parts.append(name + "=" + field.to_db_string(data[name], quote=False))
return "\t".join(parts)
def to_db_string(self) -> bytes:
"""
Returns the instance as a bytestring ready to be inserted into the database.
"""
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
s += '\n'
return s.encode('utf-8')
s += "\n"
return s.encode("utf-8")
def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]:
"""
@ -519,19 +529,18 @@ class Model(metaclass=ModelBase):
class BufferModel(Model):
@classmethod
def create_table_sql(cls, db: Database) -> str:
"""
Returns the SQL statement for creating a table for this model.
"""
parts = [
'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (
db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
"CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
% (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
]
engine_str = cls.engine.create_table_sql(db)
parts.append(engine_str)
return ' '.join(parts)
return " ".join(parts)
class MergeModel(Model):
@ -540,6 +549,7 @@ class MergeModel(Model):
Predefines the virtual _table column and ensures that rows can't be inserted into this table type
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
"""
readonly = True
# Virtual fields can't be inserted into database
@ -551,15 +561,16 @@ class MergeModel(Model):
Returns the SQL statement for creating a table for this model.
"""
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
cols = []
for name, field in cls.fields().items():
if name != '_table':
cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols))
parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts)
if name != "_table":
cols.append(" %s %s" % (name, field.get_sql(db=db)))
parts.append(",\n".join(cols))
parts.append(")")
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return "\n".join(parts)
# TODO: base class for models that require specific engine
@ -574,8 +585,9 @@ class DistributedModel(Model):
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
"""
assert isinstance(self.engine, Distributed),\
"engine must be an instance of engines.Distributed"
assert isinstance(
self.engine, Distributed
), "engine must be an instance of engines.Distributed"
super().set_database(db)
@classmethod
@ -616,15 +628,20 @@ class DistributedModel(Model):
return
# find out all the superclasses of the Model that store any data
storage_models = [b for b in cls.__bases__ if issubclass(b, Model)
and not issubclass(b, DistributedModel)]
storage_models = [
b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)
]
if not storage_models:
raise TypeError("When defining Distributed engine without the table_name "
"ensure that your model has a parent model")
raise TypeError(
"When defining Distributed engine without the table_name "
"ensure that your model has a parent model"
)
if len(storage_models) > 1:
raise TypeError("When defining Distributed engine without the table_name "
"ensure that your model has exactly one non-distributed superclass")
raise TypeError(
"When defining Distributed engine without the table_name "
"ensure that your model has exactly one non-distributed superclass"
)
# enable correct SQL for engine
cls.engine.table = storage_models[0]
@ -637,10 +654,12 @@ class DistributedModel(Model):
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
parts = [
'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format(
db.db_name, cls.table_name(), cls.engine.table_name),
'ENGINE = ' + cls.engine.create_table_sql(db)]
return '\n'.join(parts)
"CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
db.db_name, cls.table_name(), cls.engine.table_name
),
"ENGINE = " + cls.engine.create_table_sql(db),
]
return "\n".join(parts)
class TemporaryModel(Model):
@ -657,30 +676,31 @@ class TemporaryModel(Model):
https://clickhouse.com/docs/en/sql-reference/statements/create/table/#temporary-tables
"""
_temporary = True
@classmethod
def create_table_sql(cls, db: Database) -> str:
assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance"
parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()]
parts = ["CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (" % cls.table_name()]
# Fields
items = []
for name, field in cls.fields().items():
items.append(' %s %s' % (name, field.get_sql(db=db)))
items.append(" %s %s" % (name, field.get_sql(db=db)))
# Constraints
for c in cls._constraints.values():
items.append(' %s' % c.create_table_sql())
items.append(" %s" % c.create_table_sql())
# Indexes
for i in cls._indexes.values():
items.append(' %s' % i.create_table_sql())
parts.append(',\n'.join(items))
items.append(" %s" % i.create_table_sql())
parts.append(",\n".join(items))
# Engine
parts.append(')')
parts.append('ENGINE = Memory')
return '\n'.join(parts)
parts.append(")")
parts.append("ENGINE = Memory")
return "\n".join(parts)
# Expose only relevant classes in import *
MODEL = TypeVar('MODEL', bound=Model)
MODEL = TypeVar("MODEL", bound=Model)
__all__ = get_subclass_names(locals(), (Model, Constraint, Index))


@ -11,7 +11,7 @@ from typing import (
Generic,
TypeVar,
AsyncIterator,
Iterator
Iterator,
)
import pytz
@ -24,7 +24,7 @@ if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.database import Database, Page
MODEL = TypeVar('MODEL', bound='Model')
MODEL = TypeVar("MODEL", bound="Model")
class Operator:
@ -59,9 +59,9 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null])
return ' '.join([field.name, self._sql_operator, value])
if value == "\\N" and self._sql_for_null is not None:
return " ".join([field_name, self._sql_for_null])
return " ".join([field.name, self._sql_operator, value])
class InOperator(Operator):
@ -81,7 +81,7 @@ class InOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field.name, value)
return "%s IN (%s)" % (field.name, value)
class GlobalInOperator(Operator):
@ -95,7 +95,7 @@ class GlobalInOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s GLOBAL IN (%s)' % (field.name, value)
return "%s GLOBAL IN (%s)" % (field.name, value)
class LikeOperator(Operator):
@ -111,11 +111,11 @@ class LikeOperator(Operator):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
pattern = self._pattern.format(value)
if self._case_sensitive:
return '%s LIKE \'%s\'' % (field.name, pattern)
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
return "%s LIKE '%s'" % (field.name, pattern)
return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field.name, pattern)
class IExactOperator(Operator):
@ -126,7 +126,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
return "lowerUTF8(%s) = lowerUTF8(%s)" % (field.name, value)
class NotOperator(Operator):
@ -139,7 +139,7 @@ class NotOperator(Operator):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
# Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
class BetweenOperator(Operator):
@ -154,16 +154,22 @@ class BetweenOperator(Operator):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(
str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
str(value[1])) > 0 else None
value0 = (
self._value_to_sql(field, value[0])
if value[0] is not None or len(str(value[0])) > 0
else None
)
value1 = (
self._value_to_sql(field, value[1])
if value[1] is not None or len(str(value[1])) > 0
else None
)
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field.name, value0, value1)
return "%s BETWEEN %s AND %s" % (field.name, value0, value1)
if value0 and not value1:
return ' '.join([field.name, '>=', value0])
return " ".join([field.name, ">=", value0])
if value1 and not value0:
return ' '.join([field.name, '<=', value1])
return " ".join([field.name, "<=", value1])
# Define the set of builtin operators
@ -175,24 +181,24 @@ def register_operator(name: str, sql: Operator):
_operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
register_operator('gt', SimpleOperator('>'))
register_operator('gte', SimpleOperator('>='))
register_operator('lt', SimpleOperator('<'))
register_operator('lte', SimpleOperator('<='))
register_operator('between', BetweenOperator())
register_operator('in', InOperator())
register_operator('gin', GlobalInOperator())
register_operator('not_in', NotOperator(InOperator()))
register_operator('not_gin', NotOperator(GlobalInOperator()))
register_operator('contains', LikeOperator('%{}%'))
register_operator('startswith', LikeOperator('{}%'))
register_operator('endswith', LikeOperator('%{}'))
register_operator('icontains', LikeOperator('%{}%', False))
register_operator('istartswith', LikeOperator('{}%', False))
register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator())
register_operator("eq", SimpleOperator("=", "IS NULL"))
register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
register_operator("gt", SimpleOperator(">"))
register_operator("gte", SimpleOperator(">="))
register_operator("lt", SimpleOperator("<"))
register_operator("lte", SimpleOperator("<="))
register_operator("between", BetweenOperator())
register_operator("in", InOperator())
register_operator("gin", GlobalInOperator())
register_operator("not_in", NotOperator(InOperator()))
register_operator("not_gin", NotOperator(GlobalInOperator()))
register_operator("contains", LikeOperator("%{}%"))
register_operator("startswith", LikeOperator("{}%"))
register_operator("endswith", LikeOperator("%{}"))
register_operator("icontains", LikeOperator("%{}%", False))
register_operator("istartswith", LikeOperator("{}%", False))
register_operator("iendswith", LikeOperator("%{}", False))
register_operator("iexact", IExactOperator())
class Cond:
@ -214,8 +220,8 @@ class FieldCond(Cond):
self._operator = _operators.get(operator)
if self._operator is None:
# The field name contains __ like my__field
self._field_name = field_name + '__' + operator
self._operator = _operators['eq']
self._field_name = field_name + "__" + operator
self._operator = _operators["eq"]
self._value = value
def to_sql(self, model_cls: type[Model]) -> str:
@ -228,12 +234,13 @@ class FieldCond(Cond):
class Q:
AND_MODE = 'AND'
OR_MODE = 'OR'
AND_MODE = "AND"
OR_MODE = "OR"
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in
filter_fields.items()]
self._conds = list(filter_funcs) + [
self._build_cond(k, v) for k, v in filter_fields.items()
]
self._children = []
self._negate = False
self._mode = self.AND_MODE
@ -263,10 +270,10 @@ class Q:
return q
def _build_cond(self, key, value):
if '__' in key:
field_name, operator = key.rsplit('__', 1)
if "__" in key:
field_name, operator = key.rsplit("__", 1)
else:
field_name, operator = key, 'eq'
field_name, operator = key, "eq"
return FieldCond(field_name, operator, value)
def to_sql(self, model_cls: type[Model]) -> str:
@ -280,16 +287,16 @@ class Q:
if not condition_sql:
# Empty Q() object returns everything
sql = '1'
sql = "1"
elif len(condition_sql) == 1:
# Skip not needed brackets over single condition
sql = condition_sql[0]
else:
# Each condition must be enclosed in brackets, or order of operations may be wrong
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql)
sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
if self._negate:
sql = 'NOT (%s)' % sql
sql = "NOT (%s)" % sql
return sql
@ -400,16 +407,16 @@ class QuerySet(Generic[MODEL]):
def __getitem__(self, s):
if isinstance(s, int):
# Single index
assert s >= 0, 'negative indexes are not supported'
assert s >= 0, "negative indexes are not supported"
queryset = self._clone()
queryset._limits = (s, 1)
return next(iter(queryset))
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
assert s.step in (None, 1), "step is not supported in slices"
start = s.start or 0
stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
assert start <= stop, 'start of slice cannot be smaller than its end'
stop = s.stop or 2**63 - 1
assert start >= 0 and stop >= 0, "negative indexes are not supported"
assert start <= stop, "start of slice cannot be smaller than its end"
queryset = self._clone()
queryset._limits = (start, stop - start)
return queryset
@ -425,7 +432,7 @@ class QuerySet(Generic[MODEL]):
offset_limit = (0, offset_limit)
offset = offset_limit[0]
limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
assert offset >= 0 and limit >= 0, "negative limits are not supported"
queryset = self._clone()
queryset._limit_by = (offset, limit)
queryset._limit_by_fields = fields_or_expr
@ -435,44 +442,44 @@ class QuerySet(Generic[MODEL]):
"""
Returns the selected fields or expressions as a SQL string.
"""
fields = '*'
fields = "*"
if self._fields:
fields = comma_join('`%s`' % field for field in self._fields)
fields = comma_join("`%s`" % field for field in self._fields)
return fields
def as_sql(self) -> str:
"""
Returns the whole query as a SQL string.
"""
distinct = 'DISTINCT ' if self._distinct else ''
final = ' FINAL' if self._final else ''
table_name = '`%s`' % self._model_cls.table_name()
distinct = "DISTINCT " if self._distinct else ""
final = " FINAL" if self._final else ""
table_name = "`%s`" % self._model_cls.table_name()
if self._model_cls.is_system_model():
table_name = '`system`.' + table_name
table_name = "`system`." + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = 'SELECT %s%s\nFROM %s%s' % params
sql = "SELECT %s%s\nFROM %s%s" % params
if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
if self._where_q and not self._where_q.is_empty:
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields)
sql += "\nGROUP BY %s" % comma_join("%s" % field for field in self._grouping_fields)
if self._grouping_with_totals:
sql += ' WITH TOTALS'
sql += " WITH TOTALS"
if self._order_by:
sql += '\nORDER BY ' + self.order_by_as_sql()
sql += "\nORDER BY " + self.order_by_as_sql()
if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
sql += "\nLIMIT %d, %d" % self._limit_by
sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits:
sql += '\nLIMIT %d, %d' % self._limits
sql += "\nLIMIT %d, %d" % self._limits
return sql
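Slicing and ordering feed straight into as_sql(); a sketch using the same kind of hypothetical model (paths assumed):

from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import StringField, DateField
from clickhouse_orm.engines import MergeTree

class Person(Model):
    first_name = StringField()
    birthday = DateField()
    engine = MergeTree("birthday", ("first_name",))

qs = Person.objects_in(Database("default"))
qs = qs.filter(birthday__gte="2000-01-01").order_by("-birthday")[:10]
print(qs.as_sql())
# SELECT *
# FROM `person`
# WHERE birthday >= '2000-01-01'
# ORDER BY birthday DESC
# LIMIT 0, 10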
@ -480,10 +487,12 @@ class QuerySet(Generic[MODEL]):
"""
Returns the contents of the query's `ORDER BY` clause as a string.
"""
return comma_join([
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field)
for field in self._order_by
])
return comma_join(
[
"%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
for field in self._order_by
]
)
def conditions_as_sql(self, prewhere=False) -> str:
"""
@ -498,7 +507,7 @@ class QuerySet(Generic[MODEL]):
"""
if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate
sql = 'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -527,8 +536,8 @@ class QuerySet(Generic[MODEL]):
def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
from clickhouse_orm.funcs import F
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
inverse = kwargs.pop("_inverse", False)
prewhere = kwargs.pop("prewhere", False)
queryset = self._clone()
@ -588,14 +597,14 @@ class QuerySet(Generic[MODEL]):
if page_num == -1:
page_num = pages_total
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
return Page(
objects=list(self[offset: offset + page_size]),
objects=list(self[offset : offset + page_size]),
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
def distinct(self) -> "QuerySet[MODEL]":
@ -616,8 +625,8 @@ class QuerySet(Generic[MODEL]):
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError(
'final() method can be used only with the CollapsingMergeTree'
' and ReplacingMergeTree engines'
"final() method can be used only with the CollapsingMergeTree"
" and ReplacingMergeTree engines"
)
queryset = self._clone()
@ -631,7 +640,7 @@ class QuerySet(Generic[MODEL]):
"""
self._verify_mutation_allowed()
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions)
sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
self._database.raw(sql)
return self
@ -641,12 +650,14 @@ class QuerySet(Generic[MODEL]):
Keyword arguments specify the field names and expressions to use for the update.
Note that ClickHouse performs updates in the background, so they are not immediate.
"""
assert kwargs, 'No fields specified for update'
assert kwargs, "No fields specified for update"
self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
self._model_cls.table_name(), fields, conditions
sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (
self._model_cls.table_name(),
fields,
conditions,
)
self._database.raw(sql)
return self
@ -655,10 +666,10 @@ class QuerySet(Generic[MODEL]):
"""
Checks that the queryset's state allows mutations. Raises an AssertionError if not.
"""
assert not self._limits, 'Mutations are not allowed after slicing the queryset'
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
assert not self._final, 'Mutations are not allowed after calling final()'
assert not self._limits, "Mutations are not allowed after slicing the queryset"
assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
assert not self._distinct, "Mutations are not allowed after calling distinct()"
assert not self._final, "Mutations are not allowed after calling final()"
def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
"""
@ -687,7 +698,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
self,
base_queryset: QuerySet,
grouping_fields: tuple[Any],
calculated_fields: dict[str, str]
calculated_fields: dict[str, str],
):
"""
Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`.
@ -705,7 +716,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
At least one calculated field is required.
"""
super().__init__(base_queryset._model_cls, base_queryset._database)
assert calculated_fields, 'No calculated fields specified for aggregation'
assert calculated_fields, "No calculated fields specified for aggregation"
self._fields = grouping_fields
self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields
@ -734,8 +745,9 @@ class AggregateQuerySet(QuerySet[MODEL]):
created with.
"""
for name in args:
assert name in self._fields or name in self._calculated_fields, \
'Cannot group by `%s` since it is not included in the query' % name
assert name in self._fields or name in self._calculated_fields, (
"Cannot group by `%s` since it is not included in the query" % name
)
queryset = copy(self)
queryset._grouping_fields = args
return queryset
@ -750,14 +762,16 @@ class AggregateQuerySet(QuerySet[MODEL]):
"""
This method is not supported on `AggregateQuerySet`.
"""
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
def select_fields_as_sql(self) -> str:
"""
Returns the selected fields or expressions as a SQL string.
"""
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
self._calculated_fields.items()])
return comma_join(
[str(f) for f in self._fields]
+ ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
)
def __iter__(self) -> Iterator[Model]:
"""
@ -778,7 +792,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
"""
Returns the number of rows after aggregation.
"""
sql = 'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
if isinstance(raw, CoroutineType):
return raw
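That is, `count()` wraps the whole aggregate in a `SELECT count() FROM (...)` subquery; under an async database `raw()` returns a coroutine, which the `CoroutineType` check passes straight through for the caller to await. A hedged sketch of the synchronous case:

    n = Person.objects_in(db).aggregate("first_name", count="count()").count()
    # number of distinct first names, as an int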
@ -795,7 +809,7 @@ class AggregateQuerySet(QuerySet[MODEL]):
return queryset
def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet')
raise AssertionError("Cannot mutate an AggregateQuerySet")
# Expose only relevant classes in import *


@ -2,8 +2,8 @@ import uuid
from typing import Optional
from contextvars import ContextVar, Token
ctx_session_id: ContextVar[str] = ContextVar('ck.session_id')
ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout')
ctx_session_id: ContextVar[str] = ContextVar("ck.session_id")
ctx_session_timeout: ContextVar[float] = ContextVar("ck.session_timeout")
class SessionContext:


@ -16,12 +16,15 @@ class SystemPart(Model):
This model includes only the fields described in the reference; other fields are ignored.
https://clickhouse.tech/docs/en/system_tables/system.parts/
"""
OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'})
OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})
_readonly = True
_system = True
database = StringField() # Name of the database where the table that this part belongs to is located.
database = (
StringField()
) # Name of the database where the table that this part belongs to is located.
table = StringField() # Name of the table that this part belongs to.
engine = StringField() # Name of the table engine, without parameters.
partition = StringField() # Name of the partition, in the format YYYYMM.
@ -43,7 +46,9 @@ class SystemPart(Model):
# Time the directory with the part was modified. Usually corresponds to the part's creation time.
modification_time = DateTimeField()
remove_time = DateTimeField() # For inactive parts only - the time when the part became inactive.
remove_time = (
DateTimeField()
) # For inactive parts only - the time when the part became inactive.
# The number of places where the part is used. A value greater than 2 indicates
# that this part participates in queries or merges.
@ -51,12 +56,13 @@ class SystemPart(Model):
@classmethod
def table_name(cls):
return 'parts'
return "parts"
"""
The following methods build and execute SQL for operations that can be performed on partitions
https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
"""
def _partition_operation_sql(self, operation, settings=None, from_part=None):
"""
Performs an operation on the partition
@ -68,9 +74,16 @@ class SystemPart(Model):
Returns: Operation execution result
"""
operation = operation.upper()
assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(self.OPERATIONS)
assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(
self.OPERATIONS
)
sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (self._database.db_name, self.table, operation, self.partition)
sql = "ALTER TABLE `%s`.`%s` %s PARTITION %s" % (
self._database.db_name,
self.table,
operation,
self.partition,
)
if from_part is not None:
sql += " FROM %s" % from_part
self._database.raw(sql, settings=settings, stream=False)
@ -83,7 +96,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('DETACH', settings=settings)
return self._partition_operation_sql("DETACH", settings=settings)
def drop(self, settings=None):
"""
@ -93,7 +106,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('DROP', settings=settings)
return self._partition_operation_sql("DROP", settings=settings)
def attach(self, settings=None):
"""
@ -103,7 +116,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('ATTACH', settings=settings)
return self._partition_operation_sql("ATTACH", settings=settings)
def freeze(self, settings=None):
"""
@ -113,7 +126,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('FREEZE', settings=settings)
return self._partition_operation_sql("FREEZE", settings=settings)
def fetch(self, zookeeper_path, settings=None):
"""
@ -124,7 +137,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path)
return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)
@classmethod
def get(cls, database, conditions=""):
@ -140,9 +153,12 @@ class SystemPart(Model):
assert isinstance(conditions, str), "conditions must be a string"
if conditions:
conditions += " AND"
field_names = ','.join(cls.fields())
return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" %
(field_names, cls.table_name(), conditions, database.db_name), model_class=cls)
field_names = ",".join(cls.fields())
return database.select(
"SELECT %s FROM `system`.%s WHERE %s database='%s'"
% (field_names, cls.table_name(), conditions, database.db_name),
model_class=cls,
)
@classmethod
def get_active(cls, database, conditions=""):
@ -155,8 +171,8 @@ class SystemPart(Model):
Returns: A list of SystemPart objects
"""
if conditions:
conditions += ' AND '
conditions += 'active'
conditions += " AND "
conditions += "active"
return SystemPart.get(database, conditions=conditions)
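For instance, to inspect the active parts of a single table (the table name is hypothetical):

    for part in SystemPart.get_active(db, conditions="table='person'"):
        print(part.table, part.partition, part.engine)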


@ -10,10 +10,10 @@ SPECIAL_CHARS = {
"\t": "\\t",
"\0": "\\0",
"\\": "\\\\",
"'": "\\'"
"'": "\\'",
}
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
@ -36,11 +36,11 @@ def escape(value, quote=True):
def unescape(value):
return codecs.escape_decode(value)[0].decode('utf-8')
return codecs.escape_decode(value)[0].decode("utf-8")
def string_or_func(obj):
return obj.to_sql() if hasattr(obj, 'to_sql') else obj
return obj.to_sql() if hasattr(obj, "to_sql") else obj
def arg_to_sql(arg):
@ -50,6 +50,7 @@ def arg_to_sql(arg):
None, numbers, timezones, arrays/iterables.
"""
from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
if isinstance(arg, F):
return arg.to_sql()
if isinstance(arg, Field):
@ -67,22 +68,22 @@ def arg_to_sql(arg):
if isinstance(arg, tzinfo):
return StringField().to_db_string(arg.tzname(None))
if arg is None:
return 'NULL'
return "NULL"
if isinstance(arg, QuerySet):
return "(%s)" % arg
if isinstance(arg, tuple):
return '(' + comma_join(arg_to_sql(x) for x in arg) + ')'
return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
if is_iterable(arg):
return '[' + comma_join(arg_to_sql(x) for x in arg) + ']'
return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
return str(arg)
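Based only on the branches visible above, a few illustrative conversions (strings, fields, and query sets are handled by the elided branches):

    arg_to_sql(None)        # 'NULL'
    arg_to_sql((1, 2, 3))   # '(1, 2, 3)'
    arg_to_sql([1, 2, 3])   # '[1, 2, 3]'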
def parse_tsv(line):
if isinstance(line, bytes):
line = line.decode()
if line and line[-1] == '\n':
if line and line[-1] == "\n":
line = line[:-1]
return [unescape(value) for value in line.split(str('\t'))]
return [unescape(value) for value in line.split(str("\t"))]
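A quick sanity check of `parse_tsv()` on one line of TSV output; escape sequences are undone by `unescape()`:

    parse_tsv(b"hello\tworld\n")   # ['hello', 'world']
    parse_tsv(b"a\\tb\tc\n")       # ['a\tb', 'c'] -- the escaped tab is decoded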
def parse_array(array_string):
@ -92,17 +93,17 @@ def parse_array(array_string):
"(1,2,3)" ==> [1, 2, 3]
"""
# Sanity check
if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])':
if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
raise ValueError('Invalid array string: "%s"' % array_string)
# Drop opening brace
array_string = array_string[1:]
# Go over the string, lopping off each value at the beginning until nothing is left
values = []
while True:
if array_string in '])':
if array_string in "])":
# End of array
return values
elif array_string[0] in ', ':
elif array_string[0] in ", ":
# In between values
array_string = array_string[1:]
elif array_string[0] == "'":
@ -110,13 +111,13 @@ def parse_array(array_string):
match = re.search(r"[^\\]'", array_string)
if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1: match.start() + 1])
array_string = array_string[match.end():]
values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end() :]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]|\)", array_string)
values.append(array_string[0: match.start()])
array_string = array_string[match.end() - 1:]
values.append(array_string[0 : match.start()])
array_string = array_string[match.end() - 1 :]
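Two illustrative calls; note that values come back as strings, and conversion to Python types is left to the caller (typically a field class):

    parse_array("[1,2,3]")     # ['1', '2', '3']
    parse_array("('a','b')")   # ['a', 'b']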
def import_submodules(package_name):
@ -124,9 +125,10 @@ def import_submodules(package_name):
Import all submodules of a module.
"""
import importlib, pkgutil
package = importlib.import_module(package_name)
return {
name: importlib.import_module(package_name + '.' + name)
name: importlib.import_module(package_name + "." + name)
for _, name, _ in pkgutil.iter_modules(package.__path__)
}
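For example, a migrations package could be loaded like this (the package name is hypothetical), mapping each submodule name to the imported module:

    modules = import_submodules("myapp.migrations")
    # {'0001_initial': <module>, '0002_add_field': <module>, ...}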
@ -136,9 +138,9 @@ def comma_join(items, stringify=False):
Joins an iterable of strings with commas.
"""
if stringify:
return ', '.join(str(item) for item in items)
return ", ".join(str(item) for item in items)
else:
return ', '.join(items)
return ", ".join(items)
def is_iterable(obj):
@ -154,6 +156,7 @@ def is_iterable(obj):
def get_subclass_names(locals, base_class):
from inspect import isclass
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
@ -164,7 +167,7 @@ class NoValue:
"""
def __repr__(self):
return 'NO_VALUE'
return "NO_VALUE"
NO_VALUE = NoValue()