add typing hint

This commit is contained in:
sswest 2022-05-26 17:02:32 +08:00
parent 17a5c30bfd
commit 7138dfe8c3
6 changed files with 262 additions and 219 deletions

View File

@ -74,6 +74,9 @@ class PointField(Field):
def to_db_string(self, value, quote=True): def to_db_string(self, value, quote=True):
return value.to_db_string() return value.to_db_string()
def __getitem__(self, item):
return
class RingField(Field): class RingField(Field):
class_default = [Point(0, 0)] class_default = [Point(0, 0)]

View File

@ -1,26 +1,28 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
import requests
from collections import namedtuple
from .models import ModelBase
from .utils import escape, parse_tsv, import_submodules
from math import ceil
import datetime
from string import Template
import pytz
import logging import logging
import datetime
from math import ceil
from string import Template
from collections import namedtuple
from typing import Type, Optional, Generator, Union, Any
import pytz
import requests
from .models import ModelBase, MODEL
from .utils import parse_tsv, import_submodules
from .query import Q
logger = logging.getLogger('clickhouse_orm') logger = logging.getLogger('clickhouse_orm')
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
class DatabaseException(Exception): class DatabaseException(Exception):
''' """
Raised when a database operation fails. Raised when a database operation fails.
''' """
pass pass
@ -80,15 +82,15 @@ class ServerError(DatabaseException):
class Database(object): class Database(object):
''' """
Database instances connect to a specific ClickHouse database for running queries, Database instances connect to a specific ClickHouse database for running queries,
inserting data and other operations. inserting data and other operations.
''' """
def __init__(self, db_name, db_url='http://localhost:8123/', def __init__(self, db_name, db_url='http://localhost:8123/',
username=None, password=None, readonly=False, autocreate=True, username=None, password=None, readonly=False, autocreate=True,
timeout=60, verify_ssl_cert=True, log_statements=False): timeout=60, verify_ssl_cert=True, log_statements=False):
''' """
Initializes a database instance. Unless it's readonly, the database will be Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist. created on the ClickHouse server if it does not already exist.
@ -101,7 +103,7 @@ class Database(object):
- `timeout`: the connection timeout in seconds. - `timeout`: the connection timeout in seconds.
- `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS. - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
- `log_statements`: when True, all database statements are logged. - `log_statements`: when True, all database statements are logged.
''' """
self.db_name = db_name self.db_name = db_name
self.db_url = db_url self.db_url = db_url
self.readonly = False self.readonly = False
@ -130,55 +132,59 @@ class Database(object):
self.has_low_cardinality_support = self.server_version >= (19, 0) self.has_low_cardinality_support = self.server_version >= (19, 0)
def create_database(self): def create_database(self):
''' """
Creates the database on the ClickHouse server if it does not already exist. Creates the database on the ClickHouse server if it does not already exist.
''' """
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
self.db_exists = True self.db_exists = True
def drop_database(self): def drop_database(self):
''' """
Deletes the database on the ClickHouse server. Deletes the database on the ClickHouse server.
''' """
self._send('DROP DATABASE `%s`' % self.db_name) self._send('DROP DATABASE `%s`' % self.db_name)
self.db_exists = False self.db_exists = False
def create_table(self, model_class): def create_table(self, model_class: Type[MODEL]) -> None:
''' """
Creates a table for the given model class, if it does not exist already. Creates a table for the given model class, if it does not exist already.
''' """
if model_class.is_system_model(): if model_class.is_system_model():
raise DatabaseException("You can't create system table") raise DatabaseException("You can't create system table")
if getattr(model_class, 'engine') is None: if getattr(model_class, 'engine') is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__) raise DatabaseException("%s class must define an engine" % model_class.__name__)
self._send(model_class.create_table_sql(self)) self._send(model_class.create_table_sql(self))
def drop_table(self, model_class): def drop_table(self, model_class: Type[MODEL]) -> None:
''' """
Drops the database table of the given model class, if it exists. Drops the database table of the given model class, if it exists.
''' """
if model_class.is_system_model(): if model_class.is_system_model():
raise DatabaseException("You can't drop system table") raise DatabaseException("You can't drop system table")
self._send(model_class.drop_table_sql(self)) self._send(model_class.drop_table_sql(self))
def does_table_exist(self, model_class): def does_table_exist(self, model_class: Type[MODEL]) -> bool:
''' """
Checks whether a table for the given model class already exists. Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name. Note that this only checks for existence of a table with the expected name.
''' """
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'" sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = self._send(sql % (self.db_name, model_class.table_name())) r = self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1' return r.text.strip() == '1'
def get_model_for_table(self, table_name, system_table=False): def get_model_for_table(
''' self,
table_name: str,
system_table: bool = False
):
"""
Generates a model class from an existing table in the database. Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class, This can be used for querying tables which don't have a corresponding model class,
for example system tables. for example system tables.
- `table_name`: the table to create a model for - `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database - `system_table`: whether the table is a system table, or belongs to the current database
''' """
db_name = 'system' if system_table else self.db_name db_name = 'system' if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name) sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = self._send(sql).iter_lines() lines = self._send(sql).iter_lines()
@ -188,14 +194,14 @@ class Database(object):
model._system = model._readonly = True model._system = model._readonly = True
return model return model
def add_setting(self, name, value): def add_setting(self, name: str, value: Any):
''' """
Adds a database setting that will be sent with every request. Adds a database setting that will be sent with every request.
For example, `db.add_setting("max_execution_time", 10)` will For example, `db.add_setting("max_execution_time", 10)` will
limit query execution time to 10 seconds. limit query execution time to 10 seconds.
The name must be string, and the value is converted to string in case The name must be string, and the value is converted to string in case
it isn't. To remove a setting, pass `None` as the value. it isn't. To remove a setting, pass `None` as the value.
''' """
assert isinstance(name, str), 'Setting name must be a string' assert isinstance(name, str), 'Setting name must be a string'
if value is None: if value is None:
self.settings.pop(name, None) self.settings.pop(name, None)
@ -203,12 +209,12 @@ class Database(object):
self.settings[name] = str(value) self.settings[name] = str(value)
def insert(self, model_instances, batch_size=1000): def insert(self, model_instances, batch_size=1000):
''' """
Insert records into the database. Insert records into the database.
- `model_instances`: any iterable containing instances of a single model class. - `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large). - `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
''' """
from io import BytesIO from io import BytesIO
i = iter(model_instances) i = iter(model_instances)
try: try:
@ -247,13 +253,17 @@ class Database(object):
yield buf.getvalue() yield buf.getvalue()
self._send(gen()) self._send(gen())
def count(self, model_class, conditions=None): def count(
''' self,
model_class: Optional[Type[MODEL]],
conditions: Optional[Union[str, Q]] = None
) -> int:
"""
Counts the number of records in the model's table. Counts the number of records in the model's table.
- `model_class`: the model to count. - `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause). - `conditions`: optional SQL conditions (contents of the WHERE clause).
''' """
from clickhouse_orm.query import Q from clickhouse_orm.query import Q
query = 'SELECT count() FROM $table' query = 'SELECT count() FROM $table'
if conditions: if conditions:
@ -264,15 +274,20 @@ class Database(object):
r = self._send(query) r = self._send(query)
return int(r.text) if r.text else 0 return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None): def select(
''' self,
query: str,
model_class: Optional[Type[MODEL]] = None,
settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
"""
Performs a query and returns a generator of model instances. Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute. - `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table, - `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model. or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters - `settings`: query settings to send as HTTP GET parameters
''' """
query += ' FORMAT TabSeparatedWithNamesAndTypes' query += ' FORMAT TabSeparatedWithNamesAndTypes'
query = self._substitute(query, model_class) query = self._substitute(query, model_class)
r = self._send(query, settings, True) r = self._send(query, settings, True)
@ -285,19 +300,27 @@ class Database(object):
if line: if line:
yield model_class.from_tsv(line, field_names, self.server_timezone, self) yield model_class.from_tsv(line, field_names, self.server_timezone, self)
def raw(self, query, settings=None, stream=False): def raw(self, query: str, settings: Optional[dict] = None, stream: bool = False) -> str:
''' """
Performs a query and returns its output as text. Performs a query and returns its output as text.
- `query`: the SQL query to execute. - `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters - `settings`: query settings to send as HTTP GET parameters
- `stream`: if true, the HTTP response from ClickHouse will be streamed. - `stream`: if true, the HTTP response from ClickHouse will be streamed.
''' """
query = self._substitute(query, None) query = self._substitute(query, None)
return self._send(query, settings=settings, stream=stream).text return self._send(query, settings=settings, stream=stream).text
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None): def paginate(
''' self,
model_class: Type[MODEL],
order_by: str,
page_num: int = 1,
page_size: int = 100,
conditions=None,
settings: Optional[dict] = None
):
"""
Selects records and returns a single page of model instances. Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table, - `model_class`: the model class matching the query's table,
@ -310,7 +333,7 @@ class Database(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`, The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`. `pages_total`, `number` (of the current page), and `page_size`.
''' """
from clickhouse_orm.query import Q from clickhouse_orm.query import Q
count = self.count(model_class, conditions) count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size))) pages_total = int(ceil(count / float(page_size)))
@ -336,13 +359,13 @@ class Database(object):
) )
def migrate(self, migrations_package_name, up_to=9999): def migrate(self, migrations_package_name, up_to=9999):
''' """
Executes schema migrations. Executes schema migrations.
- `migrations_package_name` - fully qualified name of the Python package - `migrations_package_name` - fully qualified name of the Python package
containing the migrations. containing the migrations.
- `up_to` - number of the last migration to apply. - `up_to` - number of the last migration to apply.
''' """
from .migrations import MigrationHistory from .migrations import MigrationHistory
logger = logging.getLogger('migrations') logger = logging.getLogger('migrations')
applied_migrations = self._get_applied_migrations(migrations_package_name) applied_migrations = self._get_applied_migrations(migrations_package_name)

View File

@ -17,23 +17,25 @@ logger = getLogger('clickhouse_orm')
class Field(FunctionOperatorsMixin): class Field(FunctionOperatorsMixin):
''' """
Abstract base class for all field types. Abstract base class for all field types.
''' """
name = None # this is set by the parent model name = None # this is set by the parent model
parent = None # this is set by the parent model parent = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered creation_counter = 0 # used for keeping the model fields ordered
class_default = 0 # should be overridden by concrete subclasses class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert [default, alias, materialized].count(None) >= 2, \ assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given" "Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\ assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
"Alias parameter must be a string or function object, if given" "Alias parameter must be a string or function object, if given"
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\ assert materialized is None or isinstance(materialized, F) or isinstance(materialized,
str) and materialized != "", \
"Materialized parameter must be a string or function object, if given" "Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given" assert readonly is None or type(
readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \ assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given" "Codec field must be string, if given"
@ -52,42 +54,43 @@ class Field(FunctionOperatorsMixin):
return '<%s>' % self.__class__.__name__ return '<%s>' % self.__class__.__name__
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
''' """
Converts the input value into the expected Python data type, raising ValueError if the Converts the input value into the expected Python data type, raising ValueError if the
data can't be converted. Returns the converted value. Subclasses should override this. data can't be converted. Returns the converted value. Subclasses should override this.
The timezone_in_use parameter should be consulted when parsing datetime fields. The timezone_in_use parameter should be consulted when parsing datetime fields.
''' """
return value # pragma: no cover return value # pragma: no cover
def validate(self, value): def validate(self, value):
''' """
Called after to_python to validate that the value is suitable for the field's database type. Called after to_python to validate that the value is suitable for the field's database type.
Subclasses should override this. Subclasses should override this.
''' """
pass pass
def _range_check(self, value, min_value, max_value): def _range_check(self, value, min_value, max_value):
''' """
Utility method to check that the given value is between min_value and max_value. Utility method to check that the given value is between min_value and max_value.
''' """
if value < min_value or value > max_value: if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) raise ValueError('%s out of range - %s is not between %s and %s' % (
self.__class__.__name__, value, min_value, max_value))
def to_db_string(self, value, quote=True): def to_db_string(self, value, quote=True):
''' """
Returns the field's value prepared for writing to the database. Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes. When quote is true, strings are surrounded by single quotes.
''' """
return escape(value, quote) return escape(value, quote)
def get_sql(self, with_default_expression=True, db=None): def get_sql(self, with_default_expression=True, db=None):
''' """
Returns an SQL expression describing the field (e.g. for CREATE TABLE). Returns an SQL expression describing the field (e.g. for CREATE TABLE).
- `with_default_expression`: If True, adds default value to sql. - `with_default_expression`: If True, adds default value to sql.
It doesn't affect fields with alias and materialized values. It doesn't affect fields with alias and materialized values.
- `db`: Database, used for checking supported features. - `db`: Database, used for checking supported features.
''' """
sql = self.db_type sql = self.db_type
args = self.get_db_type_args() args = self.get_db_type_args()
if args: if args:
@ -135,7 +138,6 @@ class Field(FunctionOperatorsMixin):
class StringField(Field): class StringField(Field):
class_default = '' class_default = ''
db_type = 'String' db_type = 'String'
@ -162,11 +164,11 @@ class FixedStringField(StringField):
if isinstance(value, str): if isinstance(value, str):
value = value.encode('UTF-8') value = value.encode('UTF-8')
if len(value) > self._length: if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length)) raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (
len(value), self._length))
class DateField(Field): class DateField(Field):
min_value = datetime.date(1970, 1, 1) min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31) max_value = datetime.date(2105, 12, 31)
class_default = min_value class_default = min_value
@ -193,7 +195,6 @@ class DateField(Field):
class DateTimeField(Field): class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc) class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime' db_type = 'DateTime'
@ -292,9 +293,10 @@ class DateTime64Field(DateTimeField):
class BaseIntField(Field): class BaseIntField(Field):
''' """
Abstract base class for all integer-type fields. Abstract base class for all integer-type fields.
''' """
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
try: try:
return int(value) return int(value)
@ -311,58 +313,50 @@ class BaseIntField(Field):
class UInt8Field(BaseIntField): class UInt8Field(BaseIntField):
min_value = 0 min_value = 0
max_value = 2**8 - 1 max_value = 2 ** 8 - 1
db_type = 'UInt8' db_type = 'UInt8'
class UInt16Field(BaseIntField): class UInt16Field(BaseIntField):
min_value = 0 min_value = 0
max_value = 2**16 - 1 max_value = 2 ** 16 - 1
db_type = 'UInt16' db_type = 'UInt16'
class UInt32Field(BaseIntField): class UInt32Field(BaseIntField):
min_value = 0 min_value = 0
max_value = 2**32 - 1 max_value = 2 ** 32 - 1
db_type = 'UInt32' db_type = 'UInt32'
class UInt64Field(BaseIntField): class UInt64Field(BaseIntField):
min_value = 0 min_value = 0
max_value = 2**64 - 1 max_value = 2 ** 64 - 1
db_type = 'UInt64' db_type = 'UInt64'
class Int8Field(BaseIntField): class Int8Field(BaseIntField):
min_value = -2 ** 7
min_value = -2**7 max_value = 2 ** 7 - 1
max_value = 2**7 - 1
db_type = 'Int8' db_type = 'Int8'
class Int16Field(BaseIntField): class Int16Field(BaseIntField):
min_value = -2 ** 15
min_value = -2**15 max_value = 2 ** 15 - 1
max_value = 2**15 - 1
db_type = 'Int16' db_type = 'Int16'
class Int32Field(BaseIntField): class Int32Field(BaseIntField):
min_value = -2 ** 31
min_value = -2**31 max_value = 2 ** 31 - 1
max_value = 2**31 - 1
db_type = 'Int32' db_type = 'Int32'
class Int64Field(BaseIntField): class Int64Field(BaseIntField):
min_value = -2 ** 63
min_value = -2**63 max_value = 2 ** 63 - 1
max_value = 2**63 - 1
db_type = 'Int64' db_type = 'Int64'
@ -384,21 +378,20 @@ class BaseFloatField(Field):
class Float32Field(BaseFloatField): class Float32Field(BaseFloatField):
db_type = 'Float32' db_type = 'Float32'
class Float64Field(BaseFloatField): class Float64Field(BaseFloatField):
db_type = 'Float64' db_type = 'Float64'
class DecimalField(Field): class DecimalField(Field):
''' """
Base class for all decimal fields. Can also be used directly. Base class for all decimal fields. Can also be used directly.
''' """
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None): def __init__(self, precision, scale, default=None, alias=None, materialized=None,
readonly=None):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38' assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
self.precision = precision self.precision = precision
@ -406,7 +399,7 @@ class DecimalField(Field):
self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale) self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale)
with localcontext() as ctx: with localcontext() as ctx:
ctx.prec = 38 ctx.prec = 38
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
self.min_value = -self.max_value self.min_value = -self.max_value
super(DecimalField, self).__init__(default, alias, materialized, readonly) super(DecimalField, self).__init__(default, alias, materialized, readonly)
@ -418,7 +411,7 @@ class DecimalField(Field):
except: except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
if not value.is_finite(): if not value.is_finite():
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
return self._round(value) return self._round(value)
def to_db_string(self, value, quote=True): def to_db_string(self, value, quote=True):
@ -455,11 +448,12 @@ class Decimal128Field(DecimalField):
class BaseEnumField(Field): class BaseEnumField(Field):
''' """
Abstract base class for all enum-type fields. Abstract base class for all enum-type fields.
''' """
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None,
codec=None):
self.enum_cls = enum_cls self.enum_cls = enum_cls
if default is None: if default is None:
default = list(enum_cls)[0] default = list(enum_cls)[0]
@ -494,10 +488,10 @@ class BaseEnumField(Field):
@classmethod @classmethod
def create_ad_hoc_field(cls, db_type): def create_ad_hoc_field(cls, db_type):
''' """
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)" Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field. this method returns a matching enum field.
''' """
import re import re
from enum import Enum from enum import Enum
members = {} members = {}
@ -509,22 +503,22 @@ class BaseEnumField(Field):
class Enum8Field(BaseEnumField): class Enum8Field(BaseEnumField):
db_type = 'Enum8' db_type = 'Enum8'
class Enum16Field(BaseEnumField): class Enum16Field(BaseEnumField):
db_type = 'Enum16' db_type = 'Enum16'
class ArrayField(Field): class ArrayField(Field):
class_default = [] class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance" codec=None):
assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM" assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \
"Multidimensional array fields are not supported by the ORM"
self.inner_field = inner_field self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec) super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
@ -548,12 +542,11 @@ class ArrayField(Field):
def get_sql(self, with_default_expression=True, db=None): def get_sql(self, with_default_expression=True, db=None):
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression and self.codec and db and db.has_codec_support: if with_default_expression and self.codec and db and db.has_codec_support:
sql+= ' CODEC(%s)' % self.codec sql += ' CODEC(%s)' % self.codec
return sql return sql
class UUIDField(Field): class UUIDField(Field):
class_default = UUID(int=0) class_default = UUID(int=0)
db_type = 'UUID' db_type = 'UUID'
@ -576,7 +569,6 @@ class UUIDField(Field):
class IPv4Field(Field): class IPv4Field(Field):
class_default = 0 class_default = 0
db_type = 'IPv4' db_type = 'IPv4'
@ -593,7 +585,6 @@ class IPv4Field(Field):
class IPv6Field(Field): class IPv6Field(Field):
class_default = 0 class_default = 0
db_type = 'IPv6' db_type = 'IPv6'
@ -610,17 +601,19 @@ class IPv6Field(Field):
class NullableField(Field): class NullableField(Field):
class_default = None class_default = None
def __init__(self, inner_field, default=None, alias=None, materialized=None, def __init__(self, inner_field, default=None, alias=None, materialized=None,
extra_null_values=None, codec=None): extra_null_values=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field) assert isinstance(inner_field, Field), \
"The first argument of NullableField must be a Field instance." \
" Not: {}".format(inner_field)
self.inner_field = inner_field self.inner_field = inner_field
self._null_values = [None] self._null_values = [None]
if extra_null_values: if extra_null_values:
self._null_values.extend(extra_null_values) self._null_values.extend(extra_null_values)
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec) super(NullableField, self).__init__(default, alias, materialized, readonly=None,
codec=codec)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if value == '\\N' or value in self._null_values: if value == '\\N' or value in self._null_values:
@ -644,10 +637,16 @@ class NullableField(Field):
class LowCardinalityField(Field): class LowCardinalityField(Field):
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field) codec=None):
assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM" assert isinstance(inner_field, Field), \
assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead" "The first argument of LowCardinalityField must be a Field instance." \
" Not: {}".format(inner_field)
assert not isinstance(inner_field, LowCardinalityField), \
"LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), \
"Array field inside LowCardinality are not supported by the ORM." \
" Use Array(LowCardinality) instead"
self.inner_field = inner_field self.inner_field = inner_field
self.class_default = self.inner_field.class_default self.class_default = self.inner_field.class_default
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec) super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
@ -666,7 +665,10 @@ class LowCardinalityField(Field):
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False) sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
else: else:
sql = self.inner_field.get_sql(with_default_expression=False) sql = self.inner_field.get_sql(with_default_expression=False)
logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__)) logger.warning(
'LowCardinalityField not supported on clickhouse-server version < 19.0'
' using {} as fallback'.format(self.inner_field.__class__.__name__)
)
if with_default_expression: if with_default_expression:
sql += self._extra_params(db) sql += self._extra_params(db)
return sql return sql

View File

@ -1121,6 +1121,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def arrayElement(arr, n): def arrayElement(arr, n):
return F('arrayElement', arr, n) return F('arrayElement', arr, n)
@staticmethod
def tupleElement(arr, n):
return F('tupleElement', arr, n)
@staticmethod @staticmethod
def has(arr, x): def has(arr, x):
return F('has', arr, x) return F('has', arr, x)
@ -1133,6 +1137,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def hasAny(arr, x): def hasAny(arr, x):
return F('hasAny', arr, x) return F('hasAny', arr, x)
@staticmethod
def geohashEncode(x, y, precision=12):
return F('geohashEncode', x, y, precision)
@staticmethod @staticmethod
def indexOf(arr, x): def indexOf(arr, x):
return F('indexOf', arr, x) return F('indexOf', arr, x)

View File

@ -3,11 +3,12 @@ import sys
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
from logging import getLogger from logging import getLogger
from typing import TypeVar
import pytz import pytz
from .fields import Field, StringField from .fields import Field, StringField
from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql, unescape from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql
from .query import QuerySet from .query import QuerySet
from .funcs import F from .funcs import F
from .engines import Merge, Distributed from .engines import Merge, Distributed
@ -15,75 +16,74 @@ from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm') logger = getLogger('clickhouse_orm')
class Constraint: class Constraint:
''' """
Defines a model constraint. Defines a model constraint.
''' """
name = None # this is set by the parent model name = None # this is set by the parent model
parent = None # this is set by the parent model parent = None # this is set by the parent model
def __init__(self, expr): def __init__(self, expr):
''' """
Initializer. Expects an expression that ClickHouse will verify when inserting data. Initializer. Expects an expression that ClickHouse will verify when inserting data.
''' """
self.expr = expr self.expr = expr
def create_table_sql(self): def create_table_sql(self):
''' """
Returns the SQL statement for defining this constraint during table creation. Returns the SQL statement for defining this constraint during table creation.
''' """
return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr)) return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
class Index: class Index:
''' """
Defines a data-skipping index. Defines a data-skipping index.
''' """
name = None # this is set by the parent model name = None # this is set by the parent model
parent = None # this is set by the parent model parent = None # this is set by the parent model
def __init__(self, expr, type, granularity): def __init__(self, expr, type, granularity):
''' """
Initializer. Initializer.
- `expr` - a column, expression, or tuple of columns and expressions to index. - `expr` - a column, expression, or tuple of columns and expressions to index.
- `type` - the index type. Use one of the following methods to specify the type: - `type` - the index type. Use one of the following methods to specify the type:
`Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`. `Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`.
- `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine). - `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine).
''' """
self.expr = expr self.expr = expr
self.type = type self.type = type
self.granularity = granularity self.granularity = granularity
def create_table_sql(self): def create_table_sql(self):
''' """
Returns the SQL statement for defining this index during table creation. Returns the SQL statement for defining this index during table creation.
''' """
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity) return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
@staticmethod @staticmethod
def minmax(): def minmax():
''' """
An index that stores extremes of the specified expression (if the expression is tuple, then it stores An index that stores extremes of the specified expression (if the expression is tuple, then it stores
extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key. extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key.
''' """
return 'minmax' return 'minmax'
@staticmethod @staticmethod
def set(max_rows): def set(max_rows):
''' """
An index that stores unique values of the specified expression (no more than max_rows rows, An index that stores unique values of the specified expression (no more than max_rows rows,
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
on a block of data. on a block of data.
''' """
return 'set(%d)' % max_rows return 'set(%d)' % max_rows
@staticmethod @staticmethod
def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
''' """
An index that stores a Bloom filter containing all ngrams from a block of data. An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions. Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -92,12 +92,12 @@ class Index:
for example 256 or 512, because it can be compressed well). for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter. - `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions. - `random_seed` The seed for Bloom filter hash functions.
''' """
return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod @staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
''' """
An index that stores a Bloom filter containing string tokens. Tokens are sequences An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters. separated by non-alphanumeric characters.
@ -105,7 +105,7 @@ class Index:
for example 256 or 512, because it can be compressed well). for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter. - `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions. - `random_seed` The seed for Bloom filter hash functions.
''' """
return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod @staticmethod
@ -253,7 +253,7 @@ class ModelBase(type):
class Model(metaclass=ModelBase): class Model(metaclass=ModelBase):
''' """
A base class for ORM models. Each model class represent a ClickHouse table. For example: A base class for ORM models. Each model class represent a ClickHouse table. For example:
class CPUStats(Model): class CPUStats(Model):
@ -261,7 +261,7 @@ class Model(metaclass=ModelBase):
cpu_id = UInt16Field() cpu_id = UInt16Field()
cpu_percent = Float32Field() cpu_percent = Float32Field()
engine = Memory() engine = Memory()
''' """
engine = None engine = None
@ -274,12 +274,12 @@ class Model(metaclass=ModelBase):
_database = None _database = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
''' """
Creates a model instance, using keyword arguments as field values. Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type, Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised. invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`. Unrecognized field names will cause an `AttributeError`.
''' """
super(Model, self).__init__() super(Model, self).__init__()
# Assign default values # Assign default values
self.__dict__.update(self._defaults) self.__dict__.update(self._defaults)
@ -292,10 +292,10 @@ class Model(metaclass=ModelBase):
raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name)) raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
def __setattr__(self, name, value): def __setattr__(self, name, value):
''' """
When setting a field value, converts the value to its Pythonic type and validates it. When setting a field value, converts the value to its Pythonic type and validates it.
This may raise a `ValueError`. This may raise a `ValueError`.
''' """
field = self.get_field(name) field = self.get_field(name)
if field and (value != NO_VALUE): if field and (value != NO_VALUE):
try: try:
@ -308,50 +308,50 @@ class Model(metaclass=ModelBase):
super(Model, self).__setattr__(name, value) super(Model, self).__setattr__(name, value)
def set_database(self, db): def set_database(self, db):
''' """
Sets the `Database` that this model instance belongs to. Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it. This is done automatically when the instance is read from the database or written to it.
''' """
# This can not be imported globally due to circular import # This can not be imported globally due to circular import
from .database import Database from .database import Database
assert isinstance(db, Database), "database must be database.Database instance" assert isinstance(db, Database), "database must be database.Database instance"
self._database = db self._database = db
def get_database(self): def get_database(self):
''' """
Gets the `Database` that this model instance belongs to. Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it. Returns `None` unless the instance was read from the database or written to it.
''' """
return self._database return self._database
def get_field(self, name): def get_field(self, name):
''' """
Gets a `Field` instance given its name, or `None` if not found. Gets a `Field` instance given its name, or `None` if not found.
''' """
return self._fields.get(name) return self._fields.get(name)
@classmethod @classmethod
def table_name(cls): def table_name(cls):
''' """
Returns the model's database table name. By default this is the Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use class name converted to lowercase. Override this if you want to use
a different table name. a different table name.
''' """
return cls.__name__.lower() return cls.__name__.lower()
@classmethod @classmethod
def has_funcs_as_defaults(cls): def has_funcs_as_defaults(cls):
''' """
Return True if some of the model's fields use a function expression Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances. as a default value. This requires special handling when inserting instances.
''' """
return cls._has_funcs_as_defaults return cls._has_funcs_as_defaults
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
# Fields # Fields
items = [] items = []
@ -371,14 +371,14 @@ class Model(metaclass=ModelBase):
@classmethod @classmethod
def drop_table_sql(cls, db): def drop_table_sql(cls, db):
''' """
Returns the SQL command for deleting this model's table. Returns the SQL command for deleting this model's table.
''' """
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name()) return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
@classmethod @classmethod
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None): def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
''' """
Create a model instance from a tab-separated line. The line may or may not include a newline. Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them. The `field_names` list must match the fields defined in the model, but does not have to include all of them.
@ -386,7 +386,7 @@ class Model(metaclass=ModelBase):
- `field_names`: names of the model fields in the data. - `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones. - `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to. - `database`: if given, sets the database that this instance belongs to.
''' """
values = iter(parse_tsv(line)) values = iter(parse_tsv(line))
kwargs = {} kwargs = {}
for name in field_names: for name in field_names:
@ -401,22 +401,22 @@ class Model(metaclass=ModelBase):
return obj return obj
def to_tsv(self, include_readonly=True): def to_tsv(self, include_readonly=True):
''' """
Returns the instance's column values as a tab-separated line. A newline is not included. Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
''' """
data = self.__dict__ data = self.__dict__
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items()) return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
def to_tskv(self, include_readonly=True): def to_tskv(self, include_readonly=True):
''' """
Returns the instance's column keys and values as a tab-separated line. A newline is not included. Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted. Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
''' """
data = self.__dict__ data = self.__dict__
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
parts = [] parts = []
@ -426,20 +426,20 @@ class Model(metaclass=ModelBase):
return '\t'.join(parts) return '\t'.join(parts)
def to_db_string(self): def to_db_string(self):
''' """
Returns the instance as a bytestring ready to be inserted into the database. Returns the instance as a bytestring ready to be inserted into the database.
''' """
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False) s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
s += '\n' s += '\n'
return s.encode('utf-8') return s.encode('utf-8')
def to_dict(self, include_readonly=True, field_names=None): def to_dict(self, include_readonly=True, field_names=None):
''' """
Returns the instance's column values as a dict. Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional) - `field_names`: an iterable of field names to return (optional)
''' """
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
if field_names is not None: if field_names is not None:
@ -450,66 +450,68 @@ class Model(metaclass=ModelBase):
@classmethod @classmethod
def objects_in(cls, database): def objects_in(cls, database):
''' """
Returns a `QuerySet` for selecting instances of this model class. Returns a `QuerySet` for selecting instances of this model class.
''' """
return QuerySet(cls, database) return QuerySet(cls, database)
@classmethod @classmethod
def fields(cls, writable=False): def fields(cls, writable=False):
''' """
Returns an `OrderedDict` of the model's fields (from name to `Field` instance). Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included. If `writable` is true, only writable fields are included.
Callers should not modify the dictionary. Callers should not modify the dictionary.
''' """
# noinspection PyProtectedMember,PyUnresolvedReferences # noinspection PyProtectedMember,PyUnresolvedReferences
return cls._writable_fields if writable else cls._fields return cls._writable_fields if writable else cls._fields
@classmethod @classmethod
def is_read_only(cls): def is_read_only(cls):
''' """
Returns true if the model is marked as read only. Returns true if the model is marked as read only.
''' """
return cls._readonly return cls._readonly
@classmethod @classmethod
def is_system_model(cls): def is_system_model(cls):
''' """
Returns true if the model represents a system table. Returns true if the model represents a system table.
''' """
return cls._system return cls._system
class BufferModel(Model): class BufferModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db) -> str:
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name, parts = [
cls.engine.main_model.table_name())] 'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (
db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
]
engine_str = cls.engine.create_table_sql(db) engine_str = cls.engine.create_table_sql(db)
parts.append(engine_str) parts.append(engine_str)
return ' '.join(parts) return ' '.join(parts)
class MergeModel(Model): class MergeModel(Model):
''' """
Model for Merge engine Model for Merge engine
Predefines virtual _table column an controls that rows can't be inserted to this table type Predefines virtual _table column an controls that rows can't be inserted to this table type
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
''' """
readonly = True readonly = True
# Virtual fields can't be inserted into database # Virtual fields can't be inserted into database
_table = StringField(readonly=True) _table = StringField(readonly=True)
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db) -> str:
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = [] cols = []
@ -530,11 +532,12 @@ class DistributedModel(Model):
""" """
def set_database(self, db): def set_database(self, db):
''' """
Sets the `Database` that this model instance belongs to. Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it. This is done automatically when the instance is read from the database or written to it.
''' """
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed" assert isinstance(self.engine, Distributed),\
"engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db) res = super(DistributedModel, self).set_database(db)
return res return res
@ -590,10 +593,10 @@ class DistributedModel(Model):
cls.engine.table = storage_models[0] cls.engine.table = storage_models[0]
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db) -> str:
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
cls.fix_engine_table() cls.fix_engine_table()
@ -606,4 +609,5 @@ class DistributedModel(Model):
# Expose only relevant classes in import * # Expose only relevant classes in import *
MODEL = TypeVar('MODEL', bound=Model)
__all__ = get_subclass_names(locals(), (Model, Constraint, Index)) __all__ = get_subclass_names(locals(), (Model, Constraint, Index))

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from math import ceil
from copy import copy, deepcopy
import pytz import pytz
from copy import copy, deepcopy
from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql from .utils import comma_join, string_or_func, arg_to_sql
@ -393,7 +393,7 @@ class QuerySet(object):
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
if self._grouping_fields: if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields)
if self._grouping_with_totals: if self._grouping_with_totals:
sql += ' WITH TOTALS' sql += ' WITH TOTALS'
@ -548,7 +548,9 @@ class QuerySet(object):
from .engines import CollapsingMergeTree, ReplacingMergeTree from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError( raise TypeError(
'final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') 'final() method can be used only with the CollapsingMergeTree'
' and ReplacingMergeTree engines'
)
qs = copy(self) qs = copy(self)
qs._final = True qs._final = True
@ -576,14 +578,15 @@ class QuerySet(object):
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % ( sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
self._model_cls.table_name(), fields, conditions) self._model_cls.table_name(), fields, conditions
)
self._database.raw(sql) self._database.raw(sql)
return self return self
def _verify_mutation_allowed(self): def _verify_mutation_allowed(self):
''' """
Checks that the queryset's state allows mutations. Raises an AssertionError if not. Checks that the queryset's state allows mutations. Raises an AssertionError if not.
''' """
assert not self._limits, 'Mutations are not allowed after slicing the queryset' assert not self._limits, 'Mutations are not allowed after slicing the queryset'
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
assert not self._distinct, 'Mutations are not allowed after calling distinct()' assert not self._distinct, 'Mutations are not allowed after calling distinct()'