add typing hint

This commit is contained in:
sswest 2022-05-26 17:02:32 +08:00
parent 17a5c30bfd
commit 7138dfe8c3
6 changed files with 262 additions and 219 deletions

View File

@ -74,6 +74,9 @@ class PointField(Field):
def to_db_string(self, value, quote=True):
return value.to_db_string()
def __getitem__(self, item):
return
class RingField(Field):
class_default = [Point(0, 0)]

View File

@ -1,26 +1,28 @@
from __future__ import unicode_literals
import re
import requests
from collections import namedtuple
from .models import ModelBase
from .utils import escape, parse_tsv, import_submodules
from math import ceil
import datetime
from string import Template
import pytz
import logging
import datetime
from math import ceil
from string import Template
from collections import namedtuple
from typing import Type, Optional, Generator, Union, Any
import pytz
import requests
from .models import ModelBase, MODEL
from .utils import parse_tsv, import_submodules
from .query import Q
logger = logging.getLogger('clickhouse_orm')
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
class DatabaseException(Exception):
'''
"""
Raised when a database operation fails.
'''
"""
pass
@ -80,15 +82,15 @@ class ServerError(DatabaseException):
class Database(object):
'''
"""
Database instances connect to a specific ClickHouse database for running queries,
inserting data and other operations.
'''
"""
def __init__(self, db_name, db_url='http://localhost:8123/',
username=None, password=None, readonly=False, autocreate=True,
timeout=60, verify_ssl_cert=True, log_statements=False):
'''
"""
Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist.
@ -101,7 +103,7 @@ class Database(object):
- `timeout`: the connection timeout in seconds.
- `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
- `log_statements`: when True, all database statements are logged.
'''
"""
self.db_name = db_name
self.db_url = db_url
self.readonly = False
@ -130,55 +132,59 @@ class Database(object):
self.has_low_cardinality_support = self.server_version >= (19, 0)
def create_database(self):
'''
"""
Creates the database on the ClickHouse server if it does not already exist.
'''
"""
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
self.db_exists = True
def drop_database(self):
'''
"""
Deletes the database on the ClickHouse server.
'''
"""
self._send('DROP DATABASE `%s`' % self.db_name)
self.db_exists = False
def create_table(self, model_class):
'''
def create_table(self, model_class: Type[MODEL]) -> None:
"""
Creates a table for the given model class, if it does not exist already.
'''
"""
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if getattr(model_class, 'engine') is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
self._send(model_class.create_table_sql(self))
def drop_table(self, model_class):
'''
def drop_table(self, model_class: Type[MODEL]) -> None:
"""
Drops the database table of the given model class, if it exists.
'''
"""
if model_class.is_system_model():
raise DatabaseException("You can't drop system table")
self._send(model_class.drop_table_sql(self))
def does_table_exist(self, model_class):
'''
def does_table_exist(self, model_class: Type[MODEL]) -> bool:
"""
Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name.
'''
"""
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1'
def get_model_for_table(self, table_name, system_table=False):
'''
def get_model_for_table(
self,
table_name: str,
system_table: bool = False
):
"""
Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class,
for example system tables.
- `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database
'''
"""
db_name = 'system' if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = self._send(sql).iter_lines()
@ -188,14 +194,14 @@ class Database(object):
model._system = model._readonly = True
return model
def add_setting(self, name, value):
'''
def add_setting(self, name: str, value: Any):
"""
Adds a database setting that will be sent with every request.
For example, `db.add_setting("max_execution_time", 10)` will
limit query execution time to 10 seconds.
The name must be string, and the value is converted to string in case
it isn't. To remove a setting, pass `None` as the value.
'''
"""
assert isinstance(name, str), 'Setting name must be a string'
if value is None:
self.settings.pop(name, None)
@ -203,12 +209,12 @@ class Database(object):
self.settings[name] = str(value)
def insert(self, model_instances, batch_size=1000):
'''
"""
Insert records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
'''
"""
from io import BytesIO
i = iter(model_instances)
try:
@ -247,13 +253,17 @@ class Database(object):
yield buf.getvalue()
self._send(gen())
def count(self, model_class, conditions=None):
'''
def count(
self,
model_class: Optional[Type[MODEL]],
conditions: Optional[Union[str, Q]] = None
) -> int:
"""
Counts the number of records in the model's table.
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
'''
"""
from clickhouse_orm.query import Q
query = 'SELECT count() FROM $table'
if conditions:
@ -264,15 +274,20 @@ class Database(object):
r = self._send(query)
return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None):
'''
def select(
self,
query: str,
model_class: Optional[Type[MODEL]] = None,
settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
"""
Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters
'''
"""
query += ' FORMAT TabSeparatedWithNamesAndTypes'
query = self._substitute(query, model_class)
r = self._send(query, settings, True)
@ -285,19 +300,27 @@ class Database(object):
if line:
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
def raw(self, query, settings=None, stream=False):
'''
def raw(self, query: str, settings: Optional[dict] = None, stream: bool = False) -> str:
"""
Performs a query and returns its output as text.
- `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
'''
"""
query = self._substitute(query, None)
return self._send(query, settings=settings, stream=stream).text
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
'''
def paginate(
self,
model_class: Type[MODEL],
order_by: str,
page_num: int = 1,
page_size: int = 100,
conditions=None,
settings: Optional[dict] = None
):
"""
Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table,
@ -310,7 +333,7 @@ class Database(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
'''
"""
from clickhouse_orm.query import Q
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
@ -336,13 +359,13 @@ class Database(object):
)
def migrate(self, migrations_package_name, up_to=9999):
'''
"""
Executes schema migrations.
- `migrations_package_name` - fully qualified name of the Python package
containing the migrations.
- `up_to` - number of the last migration to apply.
'''
"""
from .migrations import MigrationHistory
logger = logging.getLogger('migrations')
applied_migrations = self._get_applied_migrations(migrations_package_name)

View File

@ -17,9 +17,9 @@ logger = getLogger('clickhouse_orm')
class Field(FunctionOperatorsMixin):
'''
"""
Abstract base class for all field types.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered
@ -31,9 +31,11 @@ class Field(FunctionOperatorsMixin):
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
"Alias parameter must be a string or function object, if given"
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\
assert materialized is None or isinstance(materialized, F) or isinstance(materialized,
str) and materialized != "", \
"Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
assert readonly is None or type(
readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given"
@ -52,42 +54,43 @@ class Field(FunctionOperatorsMixin):
return '<%s>' % self.__class__.__name__
def to_python(self, value, timezone_in_use):
'''
"""
Converts the input value into the expected Python data type, raising ValueError if the
data can't be converted. Returns the converted value. Subclasses should override this.
The timezone_in_use parameter should be consulted when parsing datetime fields.
'''
"""
return value # pragma: no cover
def validate(self, value):
'''
"""
Called after to_python to validate that the value is suitable for the field's database type.
Subclasses should override this.
'''
"""
pass
def _range_check(self, value, min_value, max_value):
'''
"""
Utility method to check that the given value is between min_value and max_value.
'''
"""
if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
raise ValueError('%s out of range - %s is not between %s and %s' % (
self.__class__.__name__, value, min_value, max_value))
def to_db_string(self, value, quote=True):
'''
"""
Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes.
'''
"""
return escape(value, quote)
def get_sql(self, with_default_expression=True, db=None):
'''
"""
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
- `with_default_expression`: If True, adds default value to sql.
It doesn't affect fields with alias and materialized values.
- `db`: Database, used for checking supported features.
'''
"""
sql = self.db_type
args = self.get_db_type_args()
if args:
@ -135,7 +138,6 @@ class Field(FunctionOperatorsMixin):
class StringField(Field):
class_default = ''
db_type = 'String'
@ -162,11 +164,11 @@ class FixedStringField(StringField):
if isinstance(value, str):
value = value.encode('UTF-8')
if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length))
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (
len(value), self._length))
class DateField(Field):
min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31)
class_default = min_value
@ -193,7 +195,6 @@ class DateField(Field):
class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime'
@ -292,9 +293,10 @@ class DateTime64Field(DateTimeField):
class BaseIntField(Field):
'''
"""
Abstract base class for all integer-type fields.
'''
"""
def to_python(self, value, timezone_in_use):
try:
return int(value)
@ -311,56 +313,48 @@ class BaseIntField(Field):
class UInt8Field(BaseIntField):
min_value = 0
max_value = 2 ** 8 - 1
db_type = 'UInt8'
class UInt16Field(BaseIntField):
min_value = 0
max_value = 2 ** 16 - 1
db_type = 'UInt16'
class UInt32Field(BaseIntField):
min_value = 0
max_value = 2 ** 32 - 1
db_type = 'UInt32'
class UInt64Field(BaseIntField):
min_value = 0
max_value = 2 ** 64 - 1
db_type = 'UInt64'
class Int8Field(BaseIntField):
min_value = -2 ** 7
max_value = 2 ** 7 - 1
db_type = 'Int8'
class Int16Field(BaseIntField):
min_value = -2 ** 15
max_value = 2 ** 15 - 1
db_type = 'Int16'
class Int32Field(BaseIntField):
min_value = -2 ** 31
max_value = 2 ** 31 - 1
db_type = 'Int32'
class Int64Field(BaseIntField):
min_value = -2 ** 63
max_value = 2 ** 63 - 1
db_type = 'Int64'
@ -384,21 +378,20 @@ class BaseFloatField(Field):
class Float32Field(BaseFloatField):
db_type = 'Float32'
class Float64Field(BaseFloatField):
db_type = 'Float64'
class DecimalField(Field):
'''
"""
Base class for all decimal fields. Can also be used directly.
'''
"""
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
def __init__(self, precision, scale, default=None, alias=None, materialized=None,
readonly=None):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
self.precision = precision
@ -455,11 +448,12 @@ class Decimal128Field(DecimalField):
class BaseEnumField(Field):
'''
"""
Abstract base class for all enum-type fields.
'''
"""
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None,
codec=None):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
@ -494,10 +488,10 @@ class BaseEnumField(Field):
@classmethod
def create_ad_hoc_field(cls, db_type):
'''
"""
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field.
'''
"""
import re
from enum import Enum
members = {}
@ -509,22 +503,22 @@ class BaseEnumField(Field):
class Enum8Field(BaseEnumField):
db_type = 'Enum8'
class Enum16Field(BaseEnumField):
db_type = 'Enum16'
class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), "Multidimensional array fields are not supported by the ORM"
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None):
assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \
"Multidimensional array fields are not supported by the ORM"
self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
@ -553,7 +547,6 @@ class ArrayField(Field):
class UUIDField(Field):
class_default = UUID(int=0)
db_type = 'UUID'
@ -576,7 +569,6 @@ class UUIDField(Field):
class IPv4Field(Field):
class_default = 0
db_type = 'IPv4'
@ -593,7 +585,6 @@ class IPv4Field(Field):
class IPv6Field(Field):
class_default = 0
db_type = 'IPv6'
@ -610,17 +601,19 @@ class IPv6Field(Field):
class NullableField(Field):
class_default = None
def __init__(self, inner_field, default=None, alias=None, materialized=None,
extra_null_values=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
assert isinstance(inner_field, Field), \
"The first argument of NullableField must be a Field instance." \
" Not: {}".format(inner_field)
self.inner_field = inner_field
self._null_values = [None]
if extra_null_values:
self._null_values.extend(extra_null_values)
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
super(NullableField, self).__init__(default, alias, materialized, readonly=None,
codec=codec)
def to_python(self, value, timezone_in_use):
if value == '\\N' or value in self._null_values:
@ -644,10 +637,16 @@ class NullableField(Field):
class LowCardinalityField(Field):
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None):
assert isinstance(inner_field, Field), \
"The first argument of LowCardinalityField must be a Field instance." \
" Not: {}".format(inner_field)
assert not isinstance(inner_field, LowCardinalityField), \
"LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), \
"Array field inside LowCardinality are not supported by the ORM." \
" Use Array(LowCardinality) instead"
self.inner_field = inner_field
self.class_default = self.inner_field.class_default
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
@ -666,7 +665,10 @@ class LowCardinalityField(Field):
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
else:
sql = self.inner_field.get_sql(with_default_expression=False)
logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__))
logger.warning(
'LowCardinalityField not supported on clickhouse-server version < 19.0'
' using {} as fallback'.format(self.inner_field.__class__.__name__)
)
if with_default_expression:
sql += self._extra_params(db)
return sql

View File

@ -1121,6 +1121,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def arrayElement(arr, n):
return F('arrayElement', arr, n)
@staticmethod
def tupleElement(arr, n):
return F('tupleElement', arr, n)
@staticmethod
def has(arr, x):
return F('has', arr, x)
@ -1133,6 +1137,10 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def hasAny(arr, x):
return F('hasAny', arr, x)
@staticmethod
def geohashEncode(x, y, precision=12):
return F('geohashEncode', x, y, precision)
@staticmethod
def indexOf(arr, x):
return F('indexOf', arr, x)

View File

@ -3,11 +3,12 @@ import sys
from collections import OrderedDict
from itertools import chain
from logging import getLogger
from typing import TypeVar
import pytz
from .fields import Field, StringField
from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql, unescape
from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql
from .query import QuerySet
from .funcs import F
from .engines import Merge, Distributed
@ -15,75 +16,74 @@ from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm')
class Constraint:
'''
"""
Defines a model constraint.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
def __init__(self, expr):
'''
"""
Initializer. Expects an expression that ClickHouse will verify when inserting data.
'''
"""
self.expr = expr
def create_table_sql(self):
'''
"""
Returns the SQL statement for defining this constraint during table creation.
'''
"""
return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
class Index:
'''
"""
Defines a data-skipping index.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
def __init__(self, expr, type, granularity):
'''
"""
Initializer.
- `expr` - a column, expression, or tuple of columns and expressions to index.
- `type` - the index type. Use one of the following methods to specify the type:
`Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`.
- `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine).
'''
"""
self.expr = expr
self.type = type
self.granularity = granularity
def create_table_sql(self):
'''
"""
Returns the SQL statement for defining this index during table creation.
'''
"""
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
@staticmethod
def minmax():
'''
"""
An index that stores extremes of the specified expression (if the expression is tuple, then it stores
extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key.
'''
"""
return 'minmax'
@staticmethod
def set(max_rows):
'''
"""
An index that stores unique values of the specified expression (no more than max_rows rows,
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
on a block of data.
'''
"""
return 'set(%d)' % max_rows
@staticmethod
def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
'''
"""
An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -92,12 +92,12 @@ class Index:
for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
'''
"""
return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
'''
"""
An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters.
@ -105,7 +105,7 @@ class Index:
for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
'''
"""
return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod
@ -253,7 +253,7 @@ class ModelBase(type):
class Model(metaclass=ModelBase):
'''
"""
A base class for ORM models. Each model class represent a ClickHouse table. For example:
class CPUStats(Model):
@ -261,7 +261,7 @@ class Model(metaclass=ModelBase):
cpu_id = UInt16Field()
cpu_percent = Float32Field()
engine = Memory()
'''
"""
engine = None
@ -274,12 +274,12 @@ class Model(metaclass=ModelBase):
_database = None
def __init__(self, **kwargs):
'''
"""
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
'''
"""
super(Model, self).__init__()
# Assign default values
self.__dict__.update(self._defaults)
@ -292,10 +292,10 @@ class Model(metaclass=ModelBase):
raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
def __setattr__(self, name, value):
'''
"""
When setting a field value, converts the value to its Pythonic type and validates it.
This may raise a `ValueError`.
'''
"""
field = self.get_field(name)
if field and (value != NO_VALUE):
try:
@ -308,50 +308,50 @@ class Model(metaclass=ModelBase):
super(Model, self).__setattr__(name, value)
def set_database(self, db):
'''
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
'''
"""
# This can not be imported globally due to circular import
from .database import Database
assert isinstance(db, Database), "database must be database.Database instance"
self._database = db
def get_database(self):
'''
"""
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
'''
"""
return self._database
def get_field(self, name):
'''
"""
Gets a `Field` instance given its name, or `None` if not found.
'''
"""
return self._fields.get(name)
@classmethod
def table_name(cls):
'''
"""
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
'''
"""
return cls.__name__.lower()
@classmethod
def has_funcs_as_defaults(cls):
'''
"""
Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances.
'''
"""
return cls._has_funcs_as_defaults
@classmethod
def create_table_sql(cls, db):
'''
"""
Returns the SQL statement for creating a table for this model.
'''
"""
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
# Fields
items = []
@ -371,14 +371,14 @@ class Model(metaclass=ModelBase):
@classmethod
def drop_table_sql(cls, db):
'''
"""
Returns the SQL command for deleting this model's table.
'''
"""
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
@classmethod
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
'''
"""
Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
@ -386,7 +386,7 @@ class Model(metaclass=ModelBase):
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to.
'''
"""
values = iter(parse_tsv(line))
kwargs = {}
for name in field_names:
@ -401,22 +401,22 @@ class Model(metaclass=ModelBase):
return obj
def to_tsv(self, include_readonly=True):
'''
"""
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database.
'''
"""
data = self.__dict__
fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
def to_tskv(self, include_readonly=True):
'''
"""
Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database.
'''
"""
data = self.__dict__
fields = self.fields(writable=not include_readonly)
parts = []
@ -426,20 +426,20 @@ class Model(metaclass=ModelBase):
return '\t'.join(parts)
def to_db_string(self):
'''
"""
Returns the instance as a bytestring ready to be inserted into the database.
'''
"""
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
s += '\n'
return s.encode('utf-8')
def to_dict(self, include_readonly=True, field_names=None):
'''
"""
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional)
'''
"""
fields = self.fields(writable=not include_readonly)
if field_names is not None:
@ -450,66 +450,68 @@ class Model(metaclass=ModelBase):
@classmethod
def objects_in(cls, database):
'''
"""
Returns a `QuerySet` for selecting instances of this model class.
'''
"""
return QuerySet(cls, database)
@classmethod
def fields(cls, writable=False):
'''
"""
Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included.
Callers should not modify the dictionary.
'''
"""
# noinspection PyProtectedMember,PyUnresolvedReferences
return cls._writable_fields if writable else cls._fields
@classmethod
def is_read_only(cls):
'''
"""
Returns true if the model is marked as read only.
'''
"""
return cls._readonly
@classmethod
def is_system_model(cls):
'''
"""
Returns true if the model represents a system table.
'''
"""
return cls._system
class BufferModel(Model):
@classmethod
def create_table_sql(cls, db):
'''
def create_table_sql(cls, db) -> str:
"""
Returns the SQL statement for creating a table for this model.
'''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name,
cls.engine.main_model.table_name())]
"""
parts = [
'CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (
db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
]
engine_str = cls.engine.create_table_sql(db)
parts.append(engine_str)
return ' '.join(parts)
class MergeModel(Model):
'''
"""
Model for Merge engine
Predefines virtual _table column an controls that rows can't be inserted to this table type
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
'''
"""
readonly = True
# Virtual fields can't be inserted into database
_table = StringField(readonly=True)
@classmethod
def create_table_sql(cls, db):
'''
def create_table_sql(cls, db) -> str:
"""
Returns the SQL statement for creating a table for this model.
'''
"""
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = []
@ -530,11 +532,12 @@ class DistributedModel(Model):
"""
def set_database(self, db):
'''
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
'''
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
"""
assert isinstance(self.engine, Distributed),\
"engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db)
return res
@ -590,10 +593,10 @@ class DistributedModel(Model):
cls.engine.table = storage_models[0]
@classmethod
def create_table_sql(cls, db):
'''
def create_table_sql(cls, db) -> str:
"""
Returns the SQL statement for creating a table for this model.
'''
"""
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
cls.fix_engine_table()
@ -606,4 +609,5 @@ class DistributedModel(Model):
# Expose only relevant classes in import *
MODEL = TypeVar('MODEL', bound=Model)
__all__ = get_subclass_names(locals(), (Model, Constraint, Index))

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals
from math import ceil
from copy import copy, deepcopy
import pytz
from copy import copy, deepcopy
from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql
@ -393,7 +393,7 @@ class QuerySet(object):
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields)
sql += '\nGROUP BY %s' % comma_join('%s' % field for field in self._grouping_fields)
if self._grouping_with_totals:
sql += ' WITH TOTALS'
@ -548,7 +548,9 @@ class QuerySet(object):
from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError(
'final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines')
'final() method can be used only with the CollapsingMergeTree'
' and ReplacingMergeTree engines'
)
qs = copy(self)
qs._final = True
@ -576,14 +578,15 @@ class QuerySet(object):
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
self._model_cls.table_name(), fields, conditions)
self._model_cls.table_name(), fields, conditions
)
self._database.raw(sql)
return self
def _verify_mutation_allowed(self):
'''
"""
Checks that the queryset's state allows mutations. Raises an AssertionError if not.
'''
"""
assert not self._limits, 'Mutations are not allowed after slicing the queryset'
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
assert not self._distinct, 'Mutations are not allowed after calling distinct()'