From 7138dfe8c3eca8294568f9ba5a503b9d847872d3 Mon Sep 17 00:00:00 2001 From: sswest Date: Thu, 26 May 2022 17:02:32 +0800 Subject: [PATCH] add typing hint --- src/clickhouse_orm/contrib/geo/fields.py | 3 + src/clickhouse_orm/database.py | 129 ++++++++++-------- src/clickhouse_orm/fields.py | 158 +++++++++++----------- src/clickhouse_orm/funcs.py | 8 ++ src/clickhouse_orm/models.py | 164 ++++++++++++----------- src/clickhouse_orm/query.py | 19 +-- 6 files changed, 262 insertions(+), 219 deletions(-) diff --git a/src/clickhouse_orm/contrib/geo/fields.py b/src/clickhouse_orm/contrib/geo/fields.py index 6c1bb87..1c6a3d1 100644 --- a/src/clickhouse_orm/contrib/geo/fields.py +++ b/src/clickhouse_orm/contrib/geo/fields.py @@ -74,6 +74,9 @@ class PointField(Field): def to_db_string(self, value, quote=True): return value.to_db_string() + def __getitem__(self, item): + return + class RingField(Field): class_default = [Point(0, 0)] diff --git a/src/clickhouse_orm/database.py b/src/clickhouse_orm/database.py index e4396cc..783a7e6 100644 --- a/src/clickhouse_orm/database.py +++ b/src/clickhouse_orm/database.py @@ -1,26 +1,28 @@ from __future__ import unicode_literals - import re -import requests -from collections import namedtuple -from .models import ModelBase -from .utils import escape, parse_tsv, import_submodules -from math import ceil -import datetime -from string import Template -import pytz - import logging +import datetime +from math import ceil +from string import Template +from collections import namedtuple +from typing import Type, Optional, Generator, Union, Any + +import pytz +import requests + +from .models import ModelBase, MODEL +from .utils import parse_tsv, import_submodules +from .query import Q + + logger = logging.getLogger('clickhouse_orm') - - Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') class DatabaseException(Exception): - ''' + """ Raised when a database operation fails. - ''' + """ pass @@ -80,15 +82,15 @@ class ServerError(DatabaseException): class Database(object): - ''' + """ Database instances connect to a specific ClickHouse database for running queries, inserting data and other operations. - ''' + """ def __init__(self, db_name, db_url='http://localhost:8123/', username=None, password=None, readonly=False, autocreate=True, timeout=60, verify_ssl_cert=True, log_statements=False): - ''' + """ Initializes a database instance. Unless it's readonly, the database will be created on the ClickHouse server if it does not already exist. @@ -101,7 +103,7 @@ class Database(object): - `timeout`: the connection timeout in seconds. - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS. - `log_statements`: when True, all database statements are logged. - ''' + """ self.db_name = db_name self.db_url = db_url self.readonly = False @@ -130,55 +132,59 @@ class Database(object): self.has_low_cardinality_support = self.server_version >= (19, 0) def create_database(self): - ''' + """ Creates the database on the ClickHouse server if it does not already exist. - ''' + """ self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) self.db_exists = True def drop_database(self): - ''' + """ Deletes the database on the ClickHouse server. - ''' + """ self._send('DROP DATABASE `%s`' % self.db_name) self.db_exists = False - def create_table(self, model_class): - ''' + def create_table(self, model_class: Type[MODEL]) -> None: + """ Creates a table for the given model class, if it does not exist already. - ''' + """ if model_class.is_system_model(): raise DatabaseException("You can't create system table") if getattr(model_class, 'engine') is None: raise DatabaseException("%s class must define an engine" % model_class.__name__) self._send(model_class.create_table_sql(self)) - def drop_table(self, model_class): - ''' + def drop_table(self, model_class: Type[MODEL]) -> None: + """ Drops the database table of the given model class, if it exists. - ''' + """ if model_class.is_system_model(): raise DatabaseException("You can't drop system table") self._send(model_class.drop_table_sql(self)) - def does_table_exist(self, model_class): - ''' + def does_table_exist(self, model_class: Type[MODEL]) -> bool: + """ Checks whether a table for the given model class already exists. Note that this only checks for existence of a table with the expected name. - ''' + """ sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'" r = self._send(sql % (self.db_name, model_class.table_name())) return r.text.strip() == '1' - def get_model_for_table(self, table_name, system_table=False): - ''' + def get_model_for_table( + self, + table_name: str, + system_table: bool = False + ): + """ Generates a model class from an existing table in the database. This can be used for querying tables which don't have a corresponding model class, for example system tables. - `table_name`: the table to create a model for - `system_table`: whether the table is a system table, or belongs to the current database - ''' + """ db_name = 'system' if system_table else self.db_name sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name) lines = self._send(sql).iter_lines() @@ -188,14 +194,14 @@ class Database(object): model._system = model._readonly = True return model - def add_setting(self, name, value): - ''' + def add_setting(self, name: str, value: Any): + """ Adds a database setting that will be sent with every request. For example, `db.add_setting("max_execution_time", 10)` will limit query execution time to 10 seconds. The name must be string, and the value is converted to string in case it isn't. To remove a setting, pass `None` as the value. - ''' + """ assert isinstance(name, str), 'Setting name must be a string' if value is None: self.settings.pop(name, None) @@ -203,12 +209,12 @@ class Database(object): self.settings[name] = str(value) def insert(self, model_instances, batch_size=1000): - ''' + """ Insert records into the database. - `model_instances`: any iterable containing instances of a single model class. - `batch_size`: number of records to send per chunk (use a lower number if your records are very large). - ''' + """ from io import BytesIO i = iter(model_instances) try: @@ -247,13 +253,17 @@ class Database(object): yield buf.getvalue() self._send(gen()) - def count(self, model_class, conditions=None): - ''' + def count( + self, + model_class: Optional[Type[MODEL]], + conditions: Optional[Union[str, Q]] = None + ) -> int: + """ Counts the number of records in the model's table. - `model_class`: the model to count. - `conditions`: optional SQL conditions (contents of the WHERE clause). - ''' + """ from clickhouse_orm.query import Q query = 'SELECT count() FROM $table' if conditions: @@ -264,15 +274,20 @@ class Database(object): r = self._send(query) return int(r.text) if r.text else 0 - def select(self, query, model_class=None, settings=None): - ''' + def select( + self, + query: str, + model_class: Optional[Type[MODEL]] = None, + settings: Optional[dict] = None + ) -> Generator[MODEL, None, None]: + """ Performs a query and returns a generator of model instances. - `query`: the SQL query to execute. - `model_class`: the model class matching the query's table, or `None` for getting back instances of an ad-hoc model. - `settings`: query settings to send as HTTP GET parameters - ''' + """ query += ' FORMAT TabSeparatedWithNamesAndTypes' query = self._substitute(query, model_class) r = self._send(query, settings, True) @@ -285,19 +300,27 @@ class Database(object): if line: yield model_class.from_tsv(line, field_names, self.server_timezone, self) - def raw(self, query, settings=None, stream=False): - ''' + def raw(self, query: str, settings: Optional[dict] = None, stream: bool = False) -> str: + """ Performs a query and returns its output as text. - `query`: the SQL query to execute. - `settings`: query settings to send as HTTP GET parameters - `stream`: if true, the HTTP response from ClickHouse will be streamed. - ''' + """ query = self._substitute(query, None) return self._send(query, settings=settings, stream=stream).text - def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None): - ''' + def paginate( + self, + model_class: Type[MODEL], + order_by: str, + page_num: int = 1, + page_size: int = 100, + conditions=None, + settings: Optional[dict] = None + ): + """ Selects records and returns a single page of model instances. - `model_class`: the model class matching the query's table, @@ -310,7 +333,7 @@ class Database(object): The result is a namedtuple containing `objects` (list), `number_of_objects`, `pages_total`, `number` (of the current page), and `page_size`. - ''' + """ from clickhouse_orm.query import Q count = self.count(model_class, conditions) pages_total = int(ceil(count / float(page_size))) @@ -336,13 +359,13 @@ class Database(object): ) def migrate(self, migrations_package_name, up_to=9999): - ''' + """ Executes schema migrations. - `migrations_package_name` - fully qualified name of the Python package containing the migrations. - `up_to` - number of the last migration to apply. - ''' + """ from .migrations import MigrationHistory logger = logging.getLogger('migrations') applied_migrations = self._get_applied_migrations(migrations_package_name) diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py index 2807e2c..c2c9194 100644 --- a/src/clickhouse_orm/fields.py +++ b/src/clickhouse_orm/fields.py @@ -17,23 +17,25 @@ logger = getLogger('clickhouse_orm') class Field(FunctionOperatorsMixin): - ''' + """ Abstract base class for all field types. - ''' - name = None # this is set by the parent model - parent = None # this is set by the parent model - creation_counter = 0 # used for keeping the model fields ordered - class_default = 0 # should be overridden by concrete subclasses - db_type = None # should be overridden by concrete subclasses + """ + name = None # this is set by the parent model + parent = None # this is set by the parent model + creation_counter = 0 # used for keeping the model fields ordered + class_default = 0 # should be overridden by concrete subclasses + db_type = None # should be overridden by concrete subclasses def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): assert [default, alias, materialized].count(None) >= 2, \ "Only one of default, alias and materialized parameters can be given" - assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\ + assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \ "Alias parameter must be a string or function object, if given" - assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\ + assert materialized is None or isinstance(materialized, F) or isinstance(materialized, + str) and materialized != "", \ "Materialized parameter must be a string or function object, if given" - assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given" + assert readonly is None or type( + readonly) is bool, "readonly parameter must be bool if given" assert codec is None or isinstance(codec, str) and codec != "", \ "Codec field must be string, if given" @@ -52,42 +54,43 @@ class Field(FunctionOperatorsMixin): return '<%s>' % self.__class__.__name__ def to_python(self, value, timezone_in_use): - ''' + """ Converts the input value into the expected Python data type, raising ValueError if the data can't be converted. Returns the converted value. Subclasses should override this. The timezone_in_use parameter should be consulted when parsing datetime fields. - ''' - return value # pragma: no cover + """ + return value # pragma: no cover def validate(self, value): - ''' + """ Called after to_python to validate that the value is suitable for the field's database type. Subclasses should override this. - ''' + """ pass def _range_check(self, value, min_value, max_value): - ''' + """ Utility method to check that the given value is between min_value and max_value. - ''' + """ if value < min_value or value > max_value: - raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) + raise ValueError('%s out of range - %s is not between %s and %s' % ( + self.__class__.__name__, value, min_value, max_value)) def to_db_string(self, value, quote=True): - ''' + """ Returns the field's value prepared for writing to the database. When quote is true, strings are surrounded by single quotes. - ''' + """ return escape(value, quote) def get_sql(self, with_default_expression=True, db=None): - ''' + """ Returns an SQL expression describing the field (e.g. for CREATE TABLE). - `with_default_expression`: If True, adds default value to sql. It doesn't affect fields with alias and materialized values. - `db`: Database, used for checking supported features. - ''' + """ sql = self.db_type args = self.get_db_type_args() if args: @@ -135,7 +138,6 @@ class Field(FunctionOperatorsMixin): class StringField(Field): - class_default = '' db_type = 'String' @@ -162,11 +164,11 @@ class FixedStringField(StringField): if isinstance(value, str): value = value.encode('UTF-8') if len(value) > self._length: - raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length)) + raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % ( + len(value), self._length)) class DateField(Field): - min_value = datetime.date(1970, 1, 1) max_value = datetime.date(2105, 12, 31) class_default = min_value @@ -193,7 +195,6 @@ class DateField(Field): class DateTimeField(Field): - class_default = datetime.datetime.fromtimestamp(0, pytz.utc) db_type = 'DateTime' @@ -292,9 +293,10 @@ class DateTime64Field(DateTimeField): class BaseIntField(Field): - ''' + """ Abstract base class for all integer-type fields. - ''' + """ + def to_python(self, value, timezone_in_use): try: return int(value) @@ -311,58 +313,50 @@ class BaseIntField(Field): class UInt8Field(BaseIntField): - min_value = 0 - max_value = 2**8 - 1 + max_value = 2 ** 8 - 1 db_type = 'UInt8' class UInt16Field(BaseIntField): - min_value = 0 - max_value = 2**16 - 1 + max_value = 2 ** 16 - 1 db_type = 'UInt16' class UInt32Field(BaseIntField): - min_value = 0 - max_value = 2**32 - 1 + max_value = 2 ** 32 - 1 db_type = 'UInt32' class UInt64Field(BaseIntField): - min_value = 0 - max_value = 2**64 - 1 + max_value = 2 ** 64 - 1 db_type = 'UInt64' class Int8Field(BaseIntField): - - min_value = -2**7 - max_value = 2**7 - 1 + min_value = -2 ** 7 + max_value = 2 ** 7 - 1 db_type = 'Int8' class Int16Field(BaseIntField): - - min_value = -2**15 - max_value = 2**15 - 1 + min_value = -2 ** 15 + max_value = 2 ** 15 - 1 db_type = 'Int16' class Int32Field(BaseIntField): - - min_value = -2**31 - max_value = 2**31 - 1 + min_value = -2 ** 31 + max_value = 2 ** 31 - 1 db_type = 'Int32' class Int64Field(BaseIntField): - - min_value = -2**63 - max_value = 2**63 - 1 + min_value = -2 ** 63 + max_value = 2 ** 63 - 1 db_type = 'Int64' @@ -384,21 +378,20 @@ class BaseFloatField(Field): class Float32Field(BaseFloatField): - db_type = 'Float32' class Float64Field(BaseFloatField): - db_type = 'Float64' class DecimalField(Field): - ''' + """ Base class for all decimal fields. Can also be used directly. - ''' + """ - def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None): + def __init__(self, precision, scale, default=None, alias=None, materialized=None, + readonly=None): assert 1 <= precision <= 38, 'Precision must be between 1 and 38' assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' self.precision = precision @@ -406,7 +399,7 @@ class DecimalField(Field): self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale) with localcontext() as ctx: ctx.prec = 38 - self.exp = Decimal(10) ** -self.scale # for rounding to the required scale + self.exp = Decimal(10) ** -self.scale # for rounding to the required scale self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp self.min_value = -self.max_value super(DecimalField, self).__init__(default, alias, materialized, readonly) @@ -418,7 +411,7 @@ class DecimalField(Field): except: raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) if not value.is_finite(): - raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) + raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) return self._round(value) def to_db_string(self, value, quote=True): @@ -455,11 +448,12 @@ class Decimal128Field(DecimalField): class BaseEnumField(Field): - ''' + """ Abstract base class for all enum-type fields. - ''' + """ - def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None): + def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, + codec=None): self.enum_cls = enum_cls if default is None: default = list(enum_cls)[0] @@ -494,10 +488,10 @@ class BaseEnumField(Field): @classmethod def create_ad_hoc_field(cls, db_type): - ''' + """ Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)" this method returns a matching enum field. - ''' + """ import re from enum import Enum members = {} @@ -509,22 +503,22 @@ class BaseEnumField(Field): class Enum8Field(BaseEnumField): - db_type = 'Enum8' class Enum16Field(BaseEnumField): - db_type = 'Enum16' class ArrayField(Field): - class_default = [] - def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None): - assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance" - assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM" + def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, + codec=None): + assert isinstance(inner_field, Field), \ + "The first argument of ArrayField must be a Field instance" + assert not isinstance(inner_field, ArrayField), \ + "Multidimensional array fields are not supported by the ORM" self.inner_field = inner_field super(ArrayField, self).__init__(default, alias, materialized, readonly, codec) @@ -548,12 +542,11 @@ class ArrayField(Field): def get_sql(self, with_default_expression=True, db=None): sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression and self.codec and db and db.has_codec_support: - sql+= ' CODEC(%s)' % self.codec + sql += ' CODEC(%s)' % self.codec return sql class UUIDField(Field): - class_default = UUID(int=0) db_type = 'UUID' @@ -576,7 +569,6 @@ class UUIDField(Field): class IPv4Field(Field): - class_default = 0 db_type = 'IPv4' @@ -593,7 +585,6 @@ class IPv4Field(Field): class IPv6Field(Field): - class_default = 0 db_type = 'IPv6' @@ -610,17 +601,19 @@ class IPv6Field(Field): class NullableField(Field): - class_default = None def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None): - assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field) + assert isinstance(inner_field, Field), \ + "The first argument of NullableField must be a Field instance." \ + " Not: {}".format(inner_field) self.inner_field = inner_field self._null_values = [None] if extra_null_values: self._null_values.extend(extra_null_values) - super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec) + super(NullableField, self).__init__(default, alias, materialized, readonly=None, + codec=codec) def to_python(self, value, timezone_in_use): if value == '\\N' or value in self._null_values: @@ -644,10 +637,16 @@ class NullableField(Field): class LowCardinalityField(Field): - def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None): - assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field) - assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM" - assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead" + def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, + codec=None): + assert isinstance(inner_field, Field), \ + "The first argument of LowCardinalityField must be a Field instance." \ + " Not: {}".format(inner_field) + assert not isinstance(inner_field, LowCardinalityField), \ + "LowCardinality inner fields are not supported by the ORM" + assert not isinstance(inner_field, ArrayField), \ + "Array field inside LowCardinality are not supported by the ORM." \ + " Use Array(LowCardinality) instead" self.inner_field = inner_field self.class_default = self.inner_field.class_default super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec) @@ -666,7 +665,10 @@ class LowCardinalityField(Field): sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False) else: sql = self.inner_field.get_sql(with_default_expression=False) - logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__)) + logger.warning( + 'LowCardinalityField not supported on clickhouse-server version < 19.0' + ' using {} as fallback'.format(self.inner_field.__class__.__name__) + ) if with_default_expression: sql += self._extra_params(db) return sql diff --git a/src/clickhouse_orm/funcs.py b/src/clickhouse_orm/funcs.py index 4f6cd4a..2d0d1d6 100644 --- a/src/clickhouse_orm/funcs.py +++ b/src/clickhouse_orm/funcs.py @@ -1121,6 +1121,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): def arrayElement(arr, n): return F('arrayElement', arr, n) + @staticmethod + def tupleElement(arr, n): + return F('tupleElement', arr, n) + @staticmethod def has(arr, x): return F('has', arr, x) @@ -1133,6 +1137,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): def hasAny(arr, x): return F('hasAny', arr, x) + @staticmethod + def geohashEncode(x, y, precision=12): + return F('geohashEncode', x, y, precision) + @staticmethod def indexOf(arr, x): return F('indexOf', arr, x) diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py index 0fc5e7f..a9f7eb8 100644 --- a/src/clickhouse_orm/models.py +++ b/src/clickhouse_orm/models.py @@ -3,11 +3,12 @@ import sys from collections import OrderedDict from itertools import chain from logging import getLogger +from typing import TypeVar import pytz from .fields import Field, StringField -from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql, unescape +from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql from .query import QuerySet from .funcs import F from .engines import Merge, Distributed @@ -15,75 +16,74 @@ from .engines import Merge, Distributed logger = getLogger('clickhouse_orm') - class Constraint: - ''' + """ Defines a model constraint. - ''' + """ - name = None # this is set by the parent model - parent = None # this is set by the parent model + name = None # this is set by the parent model + parent = None # this is set by the parent model def __init__(self, expr): - ''' + """ Initializer. Expects an expression that ClickHouse will verify when inserting data. - ''' + """ self.expr = expr def create_table_sql(self): - ''' + """ Returns the SQL statement for defining this constraint during table creation. - ''' + """ return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr)) class Index: - ''' + """ Defines a data-skipping index. - ''' + """ - name = None # this is set by the parent model - parent = None # this is set by the parent model + name = None # this is set by the parent model + parent = None # this is set by the parent model def __init__(self, expr, type, granularity): - ''' + """ Initializer. - `expr` - a column, expression, or tuple of columns and expressions to index. - `type` - the index type. Use one of the following methods to specify the type: `Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`. - `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine). - ''' + """ self.expr = expr self.type = type self.granularity = granularity def create_table_sql(self): - ''' + """ Returns the SQL statement for defining this index during table creation. - ''' + """ return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity) @staticmethod def minmax(): - ''' + """ An index that stores extremes of the specified expression (if the expression is tuple, then it stores extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key. - ''' + """ return 'minmax' @staticmethod def set(max_rows): - ''' + """ An index that stores unique values of the specified expression (no more than max_rows rows, or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable on a block of data. - ''' + """ return 'set(%d)' % max_rows @staticmethod def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): - ''' + """ An index that stores a Bloom filter containing all ngrams from a block of data. Works only with strings. Can be used for optimization of equals, like and in expressions. @@ -92,12 +92,12 @@ class Index: for example 256 or 512, because it can be compressed well). - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. - ''' + """ return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) @staticmethod def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): - ''' + """ An index that stores a Bloom filter containing string tokens. Tokens are sequences separated by non-alphanumeric characters. @@ -105,7 +105,7 @@ class Index: for example 256 or 512, because it can be compressed well). - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. - ''' + """ return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) @staticmethod @@ -253,7 +253,7 @@ class ModelBase(type): class Model(metaclass=ModelBase): - ''' + """ A base class for ORM models. Each model class represent a ClickHouse table. For example: class CPUStats(Model): @@ -261,7 +261,7 @@ class Model(metaclass=ModelBase): cpu_id = UInt16Field() cpu_percent = Float32Field() engine = Memory() - ''' + """ engine = None @@ -274,12 +274,12 @@ class Model(metaclass=ModelBase): _database = None def __init__(self, **kwargs): - ''' + """ Creates a model instance, using keyword arguments as field values. Since values are immediately converted to their Pythonic type, invalid values will cause a `ValueError` to be raised. Unrecognized field names will cause an `AttributeError`. - ''' + """ super(Model, self).__init__() # Assign default values self.__dict__.update(self._defaults) @@ -292,10 +292,10 @@ class Model(metaclass=ModelBase): raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name)) def __setattr__(self, name, value): - ''' + """ When setting a field value, converts the value to its Pythonic type and validates it. This may raise a `ValueError`. - ''' + """ field = self.get_field(name) if field and (value != NO_VALUE): try: @@ -308,50 +308,50 @@ class Model(metaclass=ModelBase): super(Model, self).__setattr__(name, value) def set_database(self, db): - ''' + """ Sets the `Database` that this model instance belongs to. This is done automatically when the instance is read from the database or written to it. - ''' + """ # This can not be imported globally due to circular import from .database import Database assert isinstance(db, Database), "database must be database.Database instance" self._database = db def get_database(self): - ''' + """ Gets the `Database` that this model instance belongs to. Returns `None` unless the instance was read from the database or written to it. - ''' + """ return self._database def get_field(self, name): - ''' + """ Gets a `Field` instance given its name, or `None` if not found. - ''' + """ return self._fields.get(name) @classmethod def table_name(cls): - ''' + """ Returns the model's database table name. By default this is the class name converted to lowercase. Override this if you want to use a different table name. - ''' + """ return cls.__name__.lower() @classmethod def has_funcs_as_defaults(cls): - ''' + """ Return True if some of the model's fields use a function expression as a default value. This requires special handling when inserting instances. - ''' + """ return cls._has_funcs_as_defaults @classmethod def create_table_sql(cls, db): - ''' + """ Returns the SQL statement for creating a table for this model. - ''' + """ parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] # Fields items = [] @@ -371,14 +371,14 @@ class Model(metaclass=ModelBase): @classmethod def drop_table_sql(cls, db): - ''' + """ Returns the SQL command for deleting this model's table. - ''' + """ return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name()) @classmethod def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None): - ''' + """ Create a model instance from a tab-separated line. The line may or may not include a newline. The `field_names` list must match the fields defined in the model, but does not have to include all of them. @@ -386,7 +386,7 @@ class Model(metaclass=ModelBase): - `field_names`: names of the model fields in the data. - `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones. - `database`: if given, sets the database that this instance belongs to. - ''' + """ values = iter(parse_tsv(line)) kwargs = {} for name in field_names: @@ -401,22 +401,22 @@ class Model(metaclass=ModelBase): return obj def to_tsv(self, include_readonly=True): - ''' + """ Returns the instance's column values as a tab-separated line. A newline is not included. - `include_readonly`: if false, returns only fields that can be inserted into database. - ''' + """ data = self.__dict__ fields = self.fields(writable=not include_readonly) return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items()) def to_tskv(self, include_readonly=True): - ''' + """ Returns the instance's column keys and values as a tab-separated line. A newline is not included. Fields that were not assigned a value are omitted. - `include_readonly`: if false, returns only fields that can be inserted into database. - ''' + """ data = self.__dict__ fields = self.fields(writable=not include_readonly) parts = [] @@ -426,20 +426,20 @@ class Model(metaclass=ModelBase): return '\t'.join(parts) def to_db_string(self): - ''' + """ Returns the instance as a bytestring ready to be inserted into the database. - ''' + """ s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False) s += '\n' return s.encode('utf-8') def to_dict(self, include_readonly=True, field_names=None): - ''' + """ Returns the instance's column values as a dict. - `include_readonly`: if false, returns only fields that can be inserted into database. - `field_names`: an iterable of field names to return (optional) - ''' + """ fields = self.fields(writable=not include_readonly) if field_names is not None: @@ -450,66 +450,68 @@ class Model(metaclass=ModelBase): @classmethod def objects_in(cls, database): - ''' + """ Returns a `QuerySet` for selecting instances of this model class. - ''' + """ return QuerySet(cls, database) @classmethod def fields(cls, writable=False): - ''' + """ Returns an `OrderedDict` of the model's fields (from name to `Field` instance). If `writable` is true, only writable fields are included. Callers should not modify the dictionary. - ''' + """ # noinspection PyProtectedMember,PyUnresolvedReferences return cls._writable_fields if writable else cls._fields @classmethod def is_read_only(cls): - ''' + """ Returns true if the model is marked as read only. - ''' + """ return cls._readonly @classmethod def is_system_model(cls): - ''' + """ Returns true if the model represents a system table. - ''' + """ return cls._system class BufferModel(Model): @classmethod - def create_table_sql(cls, db): - ''' + def create_table_sql(cls, db) -> str: + """ Returns the SQL statement for creating a table for this model. - ''' - parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name, - cls.engine.main_model.table_name())] + """ + parts = [ + 'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % ( + db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name()) + ] engine_str = cls.engine.create_table_sql(db) parts.append(engine_str) return ' '.join(parts) class MergeModel(Model): - ''' + """ Model for Merge engine Predefines virtual _table column an controls that rows can't be inserted to this table type https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge - ''' + """ readonly = True # Virtual fields can't be inserted into database _table = StringField(readonly=True) @classmethod - def create_table_sql(cls, db): - ''' + def create_table_sql(cls, db) -> str: + """ Returns the SQL statement for creating a table for this model. - ''' + """ assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] cols = [] @@ -530,11 +532,12 @@ class DistributedModel(Model): """ def set_database(self, db): - ''' + """ Sets the `Database` that this model instance belongs to. This is done automatically when the instance is read from the database or written to it. - ''' - assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed" + """ + assert isinstance(self.engine, Distributed),\ + "engine must be an instance of engines.Distributed" res = super(DistributedModel, self).set_database(db) return res @@ -590,10 +593,10 @@ class DistributedModel(Model): cls.engine.table = storage_models[0] @classmethod - def create_table_sql(cls, db): - ''' + def create_table_sql(cls, db) -> str: + """ Returns the SQL statement for creating a table for this model. - ''' + """ assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" cls.fix_engine_table() @@ -606,4 +609,5 @@ class DistributedModel(Model): # Expose only relevant classes in import * +MODEL = TypeVar('MODEL', bound=Model) __all__ = get_subclass_names(locals(), (Model, Constraint, Index)) diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py index 3074cb2..1fd5248 100644 --- a/src/clickhouse_orm/query.py +++ b/src/clickhouse_orm/query.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals +from math import ceil +from copy import copy, deepcopy import pytz -from copy import copy, deepcopy -from math import ceil -from datetime import date, datetime + from .utils import comma_join, string_or_func, arg_to_sql @@ -393,7 +393,7 @@ class QuerySet(object): sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) if self._grouping_fields: - sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) + sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields) if self._grouping_with_totals: sql += ' WITH TOTALS' @@ -548,7 +548,9 @@ class QuerySet(object): from .engines import CollapsingMergeTree, ReplacingMergeTree if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): raise TypeError( - 'final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') + 'final() method can be used only with the CollapsingMergeTree' + ' and ReplacingMergeTree engines' + ) qs = copy(self) qs._final = True @@ -576,14 +578,15 @@ class QuerySet(object): fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % ( - self._model_cls.table_name(), fields, conditions) + self._model_cls.table_name(), fields, conditions + ) self._database.raw(sql) return self def _verify_mutation_allowed(self): - ''' + """ Checks that the queryset's state allows mutations. Raises an AssertionError if not. - ''' + """ assert not self._limits, 'Mutations are not allowed after slicing the queryset' assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' assert not self._distinct, 'Mutations are not allowed after calling distinct()'