feat: support TupleField

This commit is contained in:
sswest 2022-06-02 19:15:48 +08:00
parent 6596517b25
commit 7002912300
14 changed files with 1159 additions and 333 deletions

558
.pylintrc Normal file
View File

@ -0,0 +1,558 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold to be exceeded before program exits with error.
fail-under=10.0
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regex patterns to the ignore-list. The
# regex matches against paths.
ignore-paths=protobuf
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
init-hook='from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=no
# Min Python version to use for version dependent checks. Will default to the
# version used to run pylint.
py-version=3.9
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
consider-using-f-string
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'error', 'warning', 'refactor', and 'convention'
# which contain the number of messages in each category, as well as 'statement'
# which is the total number of statements analyzed. This score is used by the
# global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=colorized
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the 'python-enchant' package.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take in consideration.
#notes-rgx=
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=no
# Signatures are removed from the similarity computation
ignore-signatures=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,ex,Run,_,r,s,n,x,y,z,f,v,db,c,id,t,ap,pk,qs,q,Q
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=mcs
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
# Maximum number of arguments allowed for a function or method.
max-args=10
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=18
# Maximum number of locals for function / method body.
max-locals=25
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
# Maximum number of return/yield statements allowed in a function or method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception

View File

@ -20,7 +20,7 @@ dependencies = [
"iso8601 >= 0.1.12",
"setuptools"
]
version = "0.0.6"
version = "0.1.0"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,4 +1,4 @@
__import__("pkg_resources").declare_namespace(__name__)
from inspect import isclass
from clickhouse_orm.database import *
from clickhouse_orm.engines import *
@ -9,5 +9,4 @@ from clickhouse_orm.models import *
from clickhouse_orm.query import *
from clickhouse_orm.system_models import *
from inspect import isclass
__all__ = [c.__name__ for c in locals().values() if isclass(c)]

View File

@ -1,7 +1,8 @@
import datetime
import logging
from io import BytesIO
from math import ceil
from typing import Type, Optional, Generator
from typing import Optional, AsyncGenerator
import httpx
import pytz
@ -11,29 +12,13 @@ from clickhouse_orm.utils import parse_tsv, import_submodules
from clickhouse_orm.database import Database, ServerError, DatabaseException, logger, Page
# pylint: disable=C0116
class AioDatabase(Database):
_client_class = httpx.AsyncClient
def __init__(
self, db_name, db_url='http://localhost:18123/', username=None,
password=None, readonly=False, auto_create=True, timeout=60,
verify_ssl_cert=True, log_statements=False
):
self.db_name = db_name
self.db_url = db_url
self.readonly = False
self._readonly = readonly
self.auto_create = auto_create
self.timeout = timeout
self.request_session = httpx.AsyncClient(verify=verify_ssl_cert, timeout=timeout)
if username:
self.request_session.auth = (username, password or '')
self.log_statements = log_statements
self.settings = {}
self._db_check = False
self.db_exists = False
async def db_check(self):
if self._db_check:
async def init(self):
if self._init:
return
self.db_exists = await self._is_existing_database()
if self._readonly:
@ -52,7 +37,7 @@ class AioDatabase(Database):
self.server_timezone = pytz.utc
self.has_codec_support = self.server_version >= (19, 1, 16)
self.has_low_cardinality_support = self.server_version >= (19, 0)
self._db_check = True
self._init = True
async def close(self):
await self.request_session.aclose()
@ -76,9 +61,9 @@ class AioDatabase(Database):
"""
from clickhouse_orm.query import Q
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
query = 'SELECT count() FROM $table'
@ -94,9 +79,9 @@ class AioDatabase(Database):
"""
Creates the database on the ClickHouse server if it does not already exist.
"""
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
@ -106,50 +91,65 @@ class AioDatabase(Database):
"""
Deletes the database on the ClickHouse server.
"""
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
await self._send('DROP DATABASE `%s`' % self.db_name)
self.db_exists = False
async def create_table(self, model_class: Type[MODEL]) -> None:
async def create_table(self, model_class: type[MODEL]) -> None:
"""
Creates a table for the given model class, if it does not exist already.
"""
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if model_class.is_temporary_model() and self.session_id is None:
raise DatabaseException(
"Creating a temporary table must be within the lifetime of a session "
)
if getattr(model_class, 'engine') is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
raise DatabaseException(f"%s class must define an engine" % model_class.__name__)
await self._send(model_class.create_table_sql(self))
async def drop_table(self, model_class: Type[MODEL]) -> None:
async def create_temporary_table(self, model_class: type[MODEL], table_name: Optional[str] = None) -> None:
    """
    Creates a temporary table for the given model class, if it does not already exist.

    - `model_class`: the model whose schema defines the temporary table.
    - `table_name`: optional explicit name for the temporary table; when omitted,
      the name produced by `create_temporary_table_sql` is used.

    Raises `DatabaseException` if the database connection has not been
    initialized with `init()` yet.
    """
    if not self._init:
        raise DatabaseException(
            'The AioDatabase object must execute the init method before it can be used'
        )
    # Delegate SQL generation to the model; the optional table_name is passed through.
    await self._send(model_class.create_temporary_table_sql(self, table_name))
async def drop_table(self, model_class: type[MODEL]) -> None:
"""
Drops the database table of the given model class, if it exists.
"""
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
if model_class.is_system_model():
raise DatabaseException("You can't drop system table")
await self._send(model_class.drop_table_sql(self))
async def does_table_exist(self, model_class: Type[MODEL]) -> bool:
async def does_table_exist(self, model_class: type[MODEL]) -> bool:
"""
Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name.
"""
if not self._db_check:
if not self._init:
raise DatabaseException(
'The AioDatabase object must execute the `db_check` method before it can be used'
'The AioDatabase object must execute the init method before it can be used'
)
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
@ -185,8 +185,6 @@ class AioDatabase(Database):
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
"""
from io import BytesIO
i = iter(model_instances)
try:
first_instance = next(i)
@ -202,7 +200,7 @@ class AioDatabase(Database):
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
def gen():
async def gen():
buf = BytesIO()
buf.write(self._substitute(query, model_class).encode('utf-8'))
first_instance.set_database(self)
@ -227,9 +225,9 @@ class AioDatabase(Database):
async def select(
self,
query: str,
model_class: Optional[Type[MODEL]] = None,
model_class: Optional[type[MODEL]] = None,
settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
) -> AsyncGenerator[MODEL, None]:
"""
Performs a query and returns a generator of model instances.
@ -269,7 +267,7 @@ class AioDatabase(Database):
async def paginate(
self,
model_class: Type[MODEL],
model_class: type[MODEL],
order_by: str,
page_num: int = 1,
page_size: int = 100,
@ -355,16 +353,16 @@ class AioDatabase(Database):
try:
r = await self._send('SELECT timezone()')
return pytz.timezone(r.text.strip())
except ServerError as e:
logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
except ServerError as err:
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
return pytz.utc
async def _get_server_version(self, as_tuple=True):
try:
r = await self._send('SELECT version();')
ver = r.text
except ServerError as e:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
except ServerError as err:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver
@ -375,3 +373,6 @@ class AioDatabase(Database):
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
return set(obj.module_name async for obj in self.select(query))
# __all__ entries must be strings: listing the class object itself would make
# `from clickhouse_orm.aio.database import *` raise a TypeError.
__all__ = ["AioDatabase"]

View File

@ -1,3 +1,6 @@
from typing import Any, Optional, Union
from clickhouse_orm import F
from clickhouse_orm.fields import Field, Float64Field
from clickhouse_orm.utils import POINT_REGEX, RING_VALID_REGEX
@ -53,8 +56,15 @@ class PointField(Field):
class_default = Point(0, 0)
db_type = 'Point'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None):
def __init__(
    self,
    default: Any = None,
    alias: Optional[Union[F, str]] = None,
    materialized: Optional[Union[F, str]] = None,
    readonly: Optional[bool] = None,
    codec: Optional[str] = None,
    db_column: Optional[str] = None
):
    """
    Initialize a PointField.

    All parameters are forwarded unchanged to the base `Field` initializer.
    Coordinate values are handled by an inner `Float64Field`.
    """
    super().__init__(default, alias, materialized, readonly, codec, db_column)
    # Each Point coordinate is stored/converted as a Float64.
    self.inner_field = Float64Field()

View File

@ -1,18 +1,18 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, annotations
import re
import logging
import datetime
from io import BytesIO
from math import ceil
from string import Template
from collections import namedtuple
from typing import Type, Optional, Generator, Union, Any
from typing import Optional, Generator, Union, Any
import pytz
import httpx
from .models import ModelBase, MODEL
from .utils import parse_tsv, import_submodules
from .query import Q
from .session import ctx_session_id, ctx_session_timeout
@ -24,7 +24,6 @@ class DatabaseException(Exception):
"""
Raised when a database operation fails.
"""
pass
class ServerError(DatabaseException):
@ -40,7 +39,7 @@ class ServerError(DatabaseException):
# just skip custom init
# if non-standard message format
self.message = message
super(ServerError, self).__init__(message)
super().__init__(message)
ERROR_PATTERNS = (
# ClickHouse prior to v19.3.3
@ -82,14 +81,15 @@ class ServerError(DatabaseException):
return "{} ({})".format(self.message, self.code)
class Database(object):
class Database:
"""
Database instances connect to a specific ClickHouse database for running queries,
inserting data and other operations.
"""
_client_class = httpx.Client
def __init__(self, db_name, db_url='http://localhost:8123/',
username=None, password=None, readonly=False, autocreate=True,
username=None, password=None, readonly=False, auto_create=True,
timeout=60, verify_ssl_cert=True, log_statements=False):
"""
Initializes a database instance. Unless it's readonly, the database will be
@ -100,7 +100,8 @@ class Database(object):
- `username`: optional connection credentials.
- `password`: optional connection credentials.
- `readonly`: use a read-only connection.
- `autocreate`: automatically create the database if it does not exist (unless in readonly mode).
- `auto_create`: automatically create the database
if it does not exist (unless in readonly mode).
- `timeout`: the connection timeout in seconds.
- `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
- `log_statements`: when True, all database statements are logged.
@ -108,26 +109,43 @@ class Database(object):
self.db_name = db_name
self.db_url = db_url
self.readonly = False
self._readonly = readonly
self.auto_create = auto_create
self.timeout = timeout
self.request_session = httpx.Client(verify=verify_ssl_cert, timeout=timeout)
self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout)
if username:
self.request_session.auth = (username, password or '')
self.log_statements = log_statements
self.settings = {}
self.db_exists = False # this is required before running _is_existing_database
self.connection_readonly = False
self.server_version = None
self.server_timezone = None
self.has_codec_support = None
self.has_low_cardinality_support = None
self._init = False
if self._client_class is httpx.Client:
self.init()
def init(self):
if self._init:
return
self.db_exists = self._is_existing_database()
if readonly:
if self._readonly:
if not self.db_exists:
raise DatabaseException(
'Database does not exist, and cannot be created under readonly connection'
)
self.connection_readonly = self._is_connection_readonly()
self.readonly = True
elif autocreate and not self.db_exists:
elif self.auto_create and not self.db_exists:
self.create_database()
self.server_version = self._get_server_version()
# Versions 1.1.53981 and below don't have timezone function
self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc
if self.server_version > (1, 1, 53981):
self.server_timezone = self._get_server_timezone()
else:
self.server_timezone = pytz.utc
# Versions 19.1.16 and above support codec compression
self.has_codec_support = self.server_version >= (19, 1, 16)
# Version 19.0 and above support LowCardinality
@ -147,7 +165,7 @@ class Database(object):
self._send('DROP DATABASE `%s`' % self.db_name)
self.db_exists = False
def create_table(self, model_class: Type[MODEL]) -> None:
def create_table(self, model_class: type[MODEL]) -> None:
"""
Creates a table for the given model class, if it does not exist already.
"""
@ -157,7 +175,7 @@ class Database(object):
raise DatabaseException("%s class must define an engine" % model_class.__name__)
self._send(model_class.create_table_sql(self))
def drop_table(self, model_class: Type[MODEL]) -> None:
def drop_table(self, model_class: type[MODEL]) -> None:
"""
Drops the database table of the given model class, if it exists.
"""
@ -165,7 +183,7 @@ class Database(object):
raise DatabaseException("You can't drop system table")
self._send(model_class.drop_table_sql(self))
def does_table_exist(self, model_class: Type[MODEL]) -> bool:
def does_table_exist(self, model_class: type[MODEL]) -> bool:
"""
Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name.
@ -215,9 +233,9 @@ class Database(object):
Insert records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
- `batch_size`: number of records to send per chunk
(use a lower number if your records are very large).
"""
from io import BytesIO
i = iter(model_instances)
try:
first_instance = next(i)
@ -257,8 +275,8 @@ class Database(object):
def count(
self,
model_class: Optional[Type[MODEL]],
conditions: Optional[Union[str, Q]] = None
model_class: Optional[type[MODEL]],
conditions: Optional[Union[str, 'Q']] = None
) -> int:
"""
Counts the number of records in the model's table.
@ -267,6 +285,7 @@ class Database(object):
- `conditions`: optional SQL conditions (contents of the WHERE clause).
"""
from clickhouse_orm.query import Q
query = 'SELECT count() FROM $table'
if conditions:
if isinstance(conditions, Q):
@ -279,7 +298,7 @@ class Database(object):
def select(
self,
query: str,
model_class: Optional[Type[MODEL]] = None,
model_class: Optional[type[MODEL]] = None,
settings: Optional[dict] = None
) -> Generator[MODEL, None, None]:
"""
@ -297,7 +316,8 @@ class Database(object):
lines = r.iter_lines()
field_names = parse_tsv(next(lines))
field_types = parse_tsv(next(lines))
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
if not model_class:
model_class = ModelBase.create_ad_hoc_model(zip(field_names, field_types))
for line in lines:
# skip blank line left by WITH TOTALS modifier
if line:
@ -318,7 +338,7 @@ class Database(object):
def paginate(
self,
model_class: Type[MODEL],
model_class: type[MODEL],
order_by: str,
page_num: int = 1,
page_size: int = 100,
@ -371,7 +391,8 @@ class Database(object):
containing the migrations.
- `up_to` - number of the last migration to apply.
"""
from .migrations import MigrationHistory
from .migrations import MigrationHistory # pylint: disable=C0415
logger = logging.getLogger('migrations')
applied_migrations = self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name)
@ -380,7 +401,11 @@ class Database(object):
logger.info('Applying migration %s...', name)
for operation in modules[name].operations:
operation.apply(self)
self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
self.insert([MigrationHistory(
package_name=migrations_package_name,
module_name=name,
applied=datetime.date.today())
])
if int(name[:4]) >= up_to:
break
@ -398,7 +423,8 @@ class Database(object):
return params
def _get_applied_migrations(self, migrations_package_name):
from .migrations import MigrationHistory
from .migrations import MigrationHistory # pylint: disable=C0415
self.create_table(MigrationHistory)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
@ -450,16 +476,16 @@ class Database(object):
try:
r = self._send('SELECT timezone()')
return pytz.timezone(r.text.strip())
except ServerError as e:
logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
except ServerError as err:
logger.exception('Cannot determine server timezone (%s), assuming UTC', err)
return pytz.utc
def _get_server_version(self, as_tuple=True):
try:
r = self._send('SELECT version();')
ver = r.text
except ServerError as e:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
except ServerError as err:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', err)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver

View File

@ -1,15 +1,22 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, annotations
import logging
from typing import TYPE_CHECKING, Optional, Union
from .utils import comma_join, get_subclass_names
if TYPE_CHECKING:
from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.funcs import F
logger = logging.getLogger('clickhouse_orm')
class Engine(object):
class Engine:
def create_table_sql(self, db):
def create_table_sql(self, db: Database) -> str:
raise NotImplementedError() # pragma: no cover
@ -34,9 +41,15 @@ class Memory(Engine):
class MergeTree(Engine):
def __init__(
self, date_col=None, order_by=(), sampling_expr=None,
index_granularity=8192, replica_table_path=None,
replica_name=None, partition_key=None, primary_key=None
self,
date_col: Optional[str] = None,
order_by: Union[list, tuple] = (),
sampling_expr: Optional[F] = None,
index_granularity: int = 8192,
replica_table_path: Optional[str] = None,
replica_name: Optional[str] = None,
partition_key: Optional[Union[list, tuple]] = None,
primary_key: Optional[Union[list, tuple]] = None
):
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
@ -73,7 +86,7 @@ class MergeTree(Engine):
'Use `order_by` attribute instead')
self.order_by = value
def create_table_sql(self, db):
def create_table_sql(self, db: Database) -> str:
name = self.__class__.__name__
if self.replica_name:
name = 'Replicated' + name
@ -108,7 +121,7 @@ class MergeTree(Engine):
params = self._build_sql_params(db)
return '%s(%s) %s' % (name, comma_join(params), partition_sql)
def _build_sql_params(self, db):
def _build_sql_params(self, db: Database) -> list[str]:
params = []
if self.replica_name:
params += ["'%s'" % self.replica_table_path, "'%s'" % self.replica_name]
@ -140,8 +153,8 @@ class CollapsingMergeTree(MergeTree):
)
self.sign_col = sign_col
def _build_sql_params(self, db):
params = super(CollapsingMergeTree, self)._build_sql_params(db)
def _build_sql_params(self, db: Database) -> list[str]:
params = super()._build_sql_params(db)
params.append(self.sign_col)
return params
@ -161,7 +174,7 @@ class SummingMergeTree(MergeTree):
'summing_cols must be a list or tuple'
self.summing_cols = summing_cols
def _build_sql_params(self, db):
def _build_sql_params(self, db: Database) -> list[str]:
params = super(SummingMergeTree, self)._build_sql_params(db)
if self.summing_cols:
params.append('(%s)' % comma_join(self.summing_cols))
@ -181,7 +194,7 @@ class ReplacingMergeTree(MergeTree):
)
self.ver_col = ver_col
def _build_sql_params(self, db):
def _build_sql_params(self, db: Database) -> list[str]:
params = super(ReplacingMergeTree, self)._build_sql_params(db)
if self.ver_col:
params.append(self.ver_col)
@ -195,9 +208,17 @@ class Buffer(Engine):
Read more [here](https://clickhouse.tech/docs/en/engines/table-engines/special/buffer/).
"""
#Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000,
max_rows=1000000, min_bytes=10000000, max_bytes=100000000):
def __init__(
self,
main_model: type[Model],
num_layers: int = 16,
min_time: int = 10,
max_time: int = 100,
min_rows: int = 10000,
max_rows: int = 1000000,
min_bytes: int = 10000000,
max_bytes: int = 100000000
):
self.main_model = main_model
self.num_layers = num_layers
self.min_time = min_time
@ -207,7 +228,7 @@ class Buffer(Engine):
self.min_bytes = min_bytes
self.max_bytes = max_bytes
def create_table_sql(self, db):
def create_table_sql(self, db: Database) -> str:
# Overriden create_table_sql example:
# sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
@ -226,11 +247,11 @@ class Merge(Engine):
https://clickhouse.tech/docs/en/engines/table-engines/special/merge/
"""
def __init__(self, table_regex):
def __init__(self, table_regex: str):
assert isinstance(table_regex, str), "'table_regex' parameter must be string"
self.table_regex = table_regex
def create_table_sql(self, db):
def create_table_sql(self, db: Database) -> str:
return "Merge(`%s`, '%s')" % (db.db_name, self.table_regex)
@ -258,23 +279,22 @@ class Distributed(Engine):
self.sharding_key = sharding_key
@property
def table_name(self):
# TODO: circular import is bad
from .models import ModelBase
def table_name(self) -> str:
from clickhouse_orm.models import Model
table = self.table
if isinstance(table, ModelBase):
if isinstance(table, Model):
return table.table_name()
return table
def create_table_sql(self, db):
def create_table_sql(self, db: Database) -> str:
name = self.__class__.__name__
params = self._build_sql_params(db)
return '%s(%s)' % (name, ', '.join(params))
def _build_sql_params(self, db):
def _build_sql_params(self, db: Database) -> list[str]:
if self.table_name is None:
raise ValueError("Cannot create {} engine: specify an underlying table".format(
self.__class__.__name__))

View File

@ -1,10 +1,14 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, annotations
import re
from enum import Enum
from uuid import UUID
from calendar import timegm
import datetime
from decimal import Decimal, localcontext
from logging import getLogger
from ipaddress import IPv4Address, IPv6Address
from uuid import UUID
from typing import TYPE_CHECKING, Any, Optional, Union, Iterable
import iso8601
import pytz
@ -13,6 +17,10 @@ from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .funcs import F, FunctionOperatorsMixin
if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.database import Database
logger = getLogger('clickhouse_orm')
@ -20,20 +28,27 @@ class Field(FunctionOperatorsMixin):
"""
Abstract base class for all field types.
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered
class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
name: str = None # this is set by the parent model
parent: type["Model"] = None # this is set by the parent model
creation_counter: int = 0 # used for keeping the model fields ordered
class_default: Any = 0 # should be overridden by concrete subclasses
db_type: str # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None):
def __init__(
self,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
):
assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \
"Alias parameter must be a string or function object, if given"
assert materialized is None or isinstance(materialized, F) or isinstance(materialized,
str) and materialized != "", \
assert (materialized is None or isinstance(materialized, F) or
isinstance(materialized, str) and materialized != ""), \
"Materialized parameter must be a string or function object, if given"
assert readonly is None or type(
readonly) is bool, "readonly parameter must be bool if given"
@ -78,7 +93,8 @@ class Field(FunctionOperatorsMixin):
"""
if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (
self.__class__.__name__, value, min_value, max_value))
self.__class__.__name__, value, min_value, max_value
))
def to_db_string(self, value, quote=True):
"""
@ -87,7 +103,7 @@ class Field(FunctionOperatorsMixin):
"""
return escape(value, quote)
def get_sql(self, with_default_expression=True, db=None):
def get_sql(self, with_default_expression=True, db=None) -> str:
"""
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
@ -107,7 +123,7 @@ class Field(FunctionOperatorsMixin):
"""Returns field type arguments"""
return []
def _extra_params(self, db):
def _extra_params(self, db: Database) -> str:
sql = ''
if self.alias:
sql += ' ALIAS %s' % string_or_func(self.alias)
@ -122,7 +138,7 @@ class Field(FunctionOperatorsMixin):
sql += ' CODEC(%s)' % self.codec
return sql
def isinstance(self, types):
def isinstance(self, types) -> bool:
"""
Checks if the instance if one of the types provided or if any of the inner_field child is one of the types
provided, returns True if field or any inner_field is one of ths provided, False otherwise
@ -145,7 +161,7 @@ class StringField(Field):
class_default = ''
db_type = 'String'
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> str:
if isinstance(value, str):
return value
if isinstance(value, bytes):
@ -155,13 +171,20 @@ class StringField(Field):
class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
def __init__(
self,
length: int,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: Optional[bool] = None,
db_column: Optional[str] = None
):
self._length = length
self.db_type = 'FixedString(%d)' % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> str:
value = super(FixedStringField, self).to_python(value, timezone_in_use)
return value.rstrip('\0')
@ -169,8 +192,9 @@ class FixedStringField(StringField):
if isinstance(value, str):
value = value.encode('UTF-8')
if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (
len(value), self._length))
raise ValueError(
f'Value of {len(value)} bytes is too long for FixedStringField({self._length})'
)
class DateField(Field):
@ -179,7 +203,7 @@ class DateField(Field):
class_default = min_value
db_type = 'Date'
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> datetime.date:
if isinstance(value, datetime.datetime):
return value.astimezone(pytz.utc).date() if value.tzinfo else value.date()
if isinstance(value, datetime.date):
@ -195,7 +219,7 @@ class DateField(Field):
def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value)
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
return escape(value.isoformat(), quote)
@ -203,8 +227,16 @@ class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None, timezone=None):
def __init__(
self,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None,
timezone: Optional[Union[BaseTzInfo, str]] = None
):
super().__init__(default, alias, materialized, readonly, codec, db_column)
# assert not timezone, 'Temporarily field timezone is not supported'
if timezone:
@ -217,7 +249,7 @@ class DateTimeField(Field):
args.append(escape(self.timezone.zone))
return args
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> datetime.datetime:
if isinstance(value, datetime.datetime):
return value if value.tzinfo else value.replace(tzinfo=pytz.utc)
if isinstance(value, datetime.date):
@ -245,15 +277,34 @@ class DateTimeField(Field):
return dt
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
return escape('%010d' % timegm(value.utctimetuple()), quote)
class DateTime64Field(DateTimeField):
db_type = 'DateTime64'
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
db_column=None, timezone=None, precision=6):
"""
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
"""
def __init__(
self,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None,
timezone: Optional[Union[BaseTzInfo, str]] = None,
precision: int = 6
):
super().__init__(default, alias, materialized, readonly, codec, db_column, timezone)
assert precision is None or isinstance(precision, int), 'Precision must be int type'
self.precision = precision
@ -264,7 +315,7 @@ class DateTime64Field(DateTimeField):
args.append(escape(self.timezone.zone))
return args
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
"""
Returns the field's value prepared for writing to the database
@ -278,7 +329,7 @@ class DateTime64Field(DateTimeField):
quote
)
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> datetime.datetime:
try:
return super().to_python(value, timezone_in_use)
except ValueError:
@ -302,13 +353,13 @@ class BaseIntField(Field):
Abstract base class for all integer-type fields.
"""
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> int:
try:
return int(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return str(value)
@ -370,13 +421,13 @@ class BaseFloatField(Field):
Abstract base class for all float-type fields.
"""
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> float:
try:
return float(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return str(value)
@ -395,8 +446,16 @@ class DecimalField(Field):
Base class for all decimal fields. Can also be used directly.
"""
def __init__(self, precision, scale, default=None, alias=None, materialized=None,
readonly=None, db_column=None):
def __init__(
self,
precision: int,
scale: int,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
self.precision = precision
@ -409,7 +468,7 @@ class DecimalField(Field):
self.min_value = -self.max_value
super(DecimalField, self).__init__(default, alias, materialized, readonly, db_column)
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> Decimal:
if not isinstance(value, Decimal):
try:
value = Decimal(value)
@ -419,7 +478,7 @@ class DecimalField(Field):
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
return self._round(value)
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return str(value)
@ -432,25 +491,45 @@ class DecimalField(Field):
class Decimal32Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
def __init__(
self,
scale: int,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
):
super().__init__(9, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal32(%d)' % scale
class Decimal64Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None,
db_column=None):
def __init__(
self,
scale: int,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
):
super().__init__(18, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal64(%d)' % scale
class Decimal128Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None,
readonly=None, db_column=None):
def __init__(
self,
scale: int,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
db_column: Optional[str] = None
):
super().__init__(38, scale, default, alias, materialized, readonly, db_column)
self.db_type = 'Decimal128(%d)' % scale
@ -460,8 +539,16 @@ class BaseEnumField(Field):
Abstract base class for all enum-type fields.
"""
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None,
codec=None, db_column=None):
def __init__(
self,
enum_cls: type[Enum],
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
):
self.enum_cls = enum_cls
if default is None:
default = list(enum_cls)[0]
@ -488,20 +575,18 @@ class BaseEnumField(Field):
pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
return escape(value.name, quote)
def get_db_type_args(self):
return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
@classmethod
def create_ad_hoc_field(cls, db_type):
def create_ad_hoc_field(cls, db_type) -> BaseEnumField:
"""
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field.
"""
import re
from enum import Enum
members = {}
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
members[match.group(1)] = int(match.group(2))
@ -521,8 +606,16 @@ class Enum16Field(BaseEnumField):
class ArrayField(Field):
class_default = []
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None, db_column=None):
def __init__(
self,
inner_field: Field,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: bool = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
):
assert isinstance(inner_field, Field), \
"The first argument of ArrayField must be a Field instance"
assert not isinstance(inner_field, ArrayField), \
@ -543,22 +636,81 @@ class ArrayField(Field):
for v in value:
self.inner_field.validate(v)
def to_db_string(self, value, quote=True):
def to_db_string(self, value, quote=True) -> str:
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
return '[' + comma_join(array) + ']'
def get_sql(self, with_default_expression=True, db=None):
def get_sql(self, with_default_expression=True, db=None) -> str:
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression and self.codec and db and db.has_codec_support:
sql += ' CODEC(%s)' % self.codec
return sql
class TupleField(Field):
    """
    A field holding a fixed-length, heterogeneous tuple of named inner fields,
    mapped to the ClickHouse ``Tuple(name1 Type1, name2 Type2, ...)`` type.

    - `name_fields`: list of ``(element_name, Field instance)`` pairs, one per
      tuple element, in order. Element names must be unique.
    """
    class_default = ()

    def __init__(
        self,
        name_fields: list[tuple[str, Field]],
        default: Any = None,
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None
    ):
        # names: element name -> inner field (insertion-ordered);
        # inner_fields: the same fields by position, used for per-element
        # conversion/validation.
        self.names = {}
        self.inner_fields = []
        for name, field in name_fields:
            if name in self.names:
                raise ValueError("Duplicate element name '%s' in TupleField" % name)
            assert isinstance(field, Field), \
                "Each element of TupleField must be a Field instance"
            assert not isinstance(field, (ArrayField, TupleField)), \
                "TupleField does not support nested ArrayField or TupleField elements"
            self.names[name] = field
            self.inner_fields.append(field)
        # The field's default is the tuple of the inner fields' defaults.
        self.class_default = tuple(field.class_default for field in self.inner_fields)
        super().__init__(default, alias, materialized, readonly, codec, db_column)

    def to_python(self, value, timezone_in_use) -> tuple:
        """
        Converts a serialized tuple (str/bytes, e.g. as read from the database)
        or a Python list/tuple into a tuple of per-element Python values.

        Raises ValueError for any other input type.
        """
        if isinstance(value, bytes):
            value = value.decode('UTF-8')
        if isinstance(value, str):
            # parse_array splits the serialized "(a, b, ...)" text into raw
            # element strings — presumably it handles parentheses as well as
            # brackets; TODO confirm against utils.parse_array.
            value = parse_array(value)
        elif not isinstance(value, (list, tuple)):
            raise ValueError('TupleField expects list or tuple, not %s' % type(value))
        # Convert each element exactly once with its positional inner field.
        # (The previous implementation converted parsed str/bytes input twice,
        # which breaks inner fields whose to_python is not idempotent.)
        return tuple(self.inner_fields[i].to_python(v, timezone_in_use)
                     for i, v in enumerate(value))

    def validate(self, value):
        # Validate each element against its corresponding inner field.
        for i, v in enumerate(value):
            self.inner_fields[i].validate(v)

    def to_db_string(self, value, quote=True) -> str:
        # Render as a ClickHouse tuple literal, e.g. ('abc', 1, '2020-01-01').
        elements = [self.inner_fields[i].to_db_string(v, quote=True)
                    for i, v in enumerate(value)]
        return '(' + comma_join(elements) + ')'

    def get_sql(self, with_default_expression=True, db=None) -> str:
        # Column type SQL: Tuple(name1 Type1, ...). Inner fields are rendered
        # without default expressions — defaults apply to the tuple as a whole.
        inner_sql = ', '.join('%s %s' % (name, field.get_sql(False))
                              for name, field in self.names.items())
        sql = 'Tuple(%s)' % inner_sql
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += ' CODEC(%s)' % self.codec
        return sql
class UUIDField(Field):
class_default = UUID(int=0)
db_type = 'UUID'
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> UUID:
if isinstance(value, UUID):
return value
elif isinstance(value, bytes):
@ -580,7 +732,7 @@ class IPv4Field(Field):
class_default = 0
db_type = 'IPv4'
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> IPv4Address:
if isinstance(value, IPv4Address):
return value
elif isinstance(value, (bytes, str, int)):
@ -596,7 +748,7 @@ class IPv6Field(Field):
class_default = 0
db_type = 'IPv6'
def to_python(self, value, timezone_in_use):
def to_python(self, value, timezone_in_use) -> IPv6Address:
if isinstance(value, IPv6Address):
return value
elif isinstance(value, (bytes, str, int)):
@ -611,8 +763,16 @@ class IPv6Field(Field):
class NullableField(Field):
class_default = None
def __init__(self, inner_field, default=None, alias=None, materialized=None,
extra_null_values=None, codec=None):
def __init__(
self,
inner_field: Field,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
extra_null_values: Optional[Iterable] = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
):
assert isinstance(inner_field, Field), \
"The first argument of NullableField must be a Field instance." \
" Not: {}".format(inner_field)
@ -620,8 +780,9 @@ class NullableField(Field):
self._null_values = [None]
if extra_null_values:
self._null_values.extend(extra_null_values)
super(NullableField, self).__init__(default, alias, materialized, readonly=None,
codec=codec)
super().__init__(
default, alias, materialized, readonly=None, codec=codec, db_column=db_column
)
def to_python(self, value, timezone_in_use):
if value == '\\N' or value in self._null_values:
@ -645,8 +806,16 @@ class NullableField(Field):
class LowCardinalityField(Field):
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None,
codec=None):
def __init__(
self,
inner_field: Field,
default: Any = None,
alias: Optional[Union[F, str]] = None,
materialized: Optional[Union[F, str]] = None,
readonly: Optional[bool] = None,
codec: Optional[str] = None,
db_column: Optional[str] = None
):
assert isinstance(inner_field, Field), \
"The first argument of LowCardinalityField must be a Field instance." \
" Not: {}".format(inner_field)
@ -657,7 +826,7 @@ class LowCardinalityField(Field):
" Use Array(LowCardinality) instead"
self.inner_field = inner_field
self.class_default = self.inner_field.class_default
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
super().__init__(default, alias, materialized, readonly, codec, db_column)
def to_python(self, value, timezone_in_use):
return self.inner_field.to_python(value, timezone_in_use)

View File

@ -75,7 +75,7 @@ def parametric(func):
return wrapper
class FunctionOperatorsMixin(object):
class FunctionOperatorsMixin:
"""
A mixin for implementing Python operators using F objects.
"""
@ -248,7 +248,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
def __repr__(self):
return self.to_sql()
def to_sql(self, *args):
def to_sql(self, *args) -> str:
"""
Generates an SQL string for this function and its arguments.
For example if the function name is a symbol of a binary operator:
@ -898,7 +898,6 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
@staticmethod
def replace(haystack, pattern, replacement):
return F('replace', haystack, pattern, replacement)
replaceAll = replace
@staticmethod
def replaceAll(haystack, pattern, replacement):

View File

@ -1,30 +1,31 @@
from .models import Model, BufferModel
from .fields import DateField, StringField
from .engines import MergeTree
from .utils import escape, get_subclass_names
from .utils import get_subclass_names
import logging
logger = logging.getLogger('migrations')
class Operation():
'''
class Operation:
"""
Base class for migration operations.
'''
"""
def apply(self, database):
raise NotImplementedError() # pragma: no cover
class ModelOperation(Operation):
'''
"""
Base class for migration operations that work on a specific model.
'''
"""
def __init__(self, model_class):
'''
"""
Initializer.
'''
"""
self.model_class = model_class
self.table_name = model_class.table_name()
@ -38,9 +39,9 @@ class ModelOperation(Operation):
class CreateTable(ModelOperation):
'''
"""
A migration operation that creates a table for a given model class.
'''
"""
def apply(self, database):
logger.info(' Create table %s', self.table_name)
@ -50,14 +51,14 @@ class CreateTable(ModelOperation):
class AlterTable(ModelOperation):
'''
"""
A migration operation that compares the table of a given model class to
the model's fields, and alters the table to match the model. The operation can:
- add new columns
- drop obsolete columns
- modify column types
Default values are not altered by this operation.
'''
"""
def _get_table_fields(self, database):
query = "DESC `%s`.`%s`" % (database.db_name, self.table_name)
@ -113,11 +114,11 @@ class AlterTable(ModelOperation):
class AlterTableWithBuffer(ModelOperation):
'''
"""
A migration operation for altering a buffer table and its underlying on-disk table.
The buffer table is dropped, the on-disk table is altered, and then the buffer table
is re-created.
'''
"""
def apply(self, database):
if issubclass(self.model_class, BufferModel):
@ -129,9 +130,9 @@ class AlterTableWithBuffer(ModelOperation):
class DropTable(ModelOperation):
'''
"""
A migration operation that drops the table of a given model class.
'''
"""
def apply(self, database):
logger.info(' Drop table %s', self.table_name)
@ -139,12 +140,12 @@ class DropTable(ModelOperation):
class AlterConstraints(ModelOperation):
'''
"""
A migration operation that adds new constraints from the model to the database
table, and drops obsolete ones. Constraints are identified by their names, so
a change in an existing constraint will not be detected unless its name was changed too.
ClickHouse does not check that the constraints hold for existing data in the table.
'''
"""
def apply(self, database):
logger.info(' Alter constraints for %s', self.table_name)
@ -163,9 +164,9 @@ class AlterConstraints(ModelOperation):
self._alter_table(database, 'DROP CONSTRAINT `%s`' % name)
def _get_constraint_names(self, database):
'''
"""
Returns a set containing the names of existing constraints in the table.
'''
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def)
@ -173,19 +174,19 @@ class AlterConstraints(ModelOperation):
class AlterIndexes(ModelOperation):
'''
"""
A migration operation that adds new indexes from the model to the database
table, and drops obsolete ones. Indexes are identified by their names, so
a change in an existing index will not be detected unless its name was changed too.
'''
"""
def __init__(self, model_class, reindex=False):
'''
"""
Initializer.
By default ClickHouse does not build indexes over existing data, only for
new data. Passing `reindex=True` will run `OPTIMIZE TABLE` in order to build
the indexes over the existing data.
'''
"""
super().__init__(model_class)
self.reindex = reindex
@ -211,9 +212,9 @@ class AlterIndexes(ModelOperation):
database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name)
def _get_index_names(self, database):
'''
"""
Returns a set containing the names of existing indexes in the table.
'''
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def)
@ -221,9 +222,9 @@ class AlterIndexes(ModelOperation):
class RunPython(Operation):
'''
"""
A migration operation that executes a Python function.
'''
"""
def __init__(self, func):
'''
Initializer. The given Python function will be called with a single
@ -238,9 +239,9 @@ class RunPython(Operation):
class RunSQL(Operation):
'''
"""
A migration operation that executes arbitrary SQL statements.
'''
"""
def __init__(self, sql):
'''
@ -259,9 +260,9 @@ class RunSQL(Operation):
class MigrationHistory(Model):
'''
"""
A model for storing which migrations were already applied to the containing database.
'''
"""
package_name = StringField()
module_name = StringField()

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, annotations
import sys
from collections import OrderedDict
from itertools import chain
from logging import getLogger
from typing import TypeVar, Dict
from typing import TypeVar, Optional, TYPE_CHECKING, Any
import pytz
@ -13,6 +13,10 @@ from .query import QuerySet
from .funcs import F
from .engines import Merge, Distributed, Memory
if TYPE_CHECKING:
from clickhouse_orm.database import Database
logger = getLogger('clickhouse_orm')
@ -21,16 +25,16 @@ class Constraint:
Defines a model constraint.
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
name: Optional[str] = None # this is set by the parent model
parent: Optional[type["Model"]] = None # this is set by the parent model
def __init__(self, expr):
def __init__(self, expr: F):
"""
Initializer. Expects an expression that ClickHouse will verify when inserting data.
"""
self.expr = expr
def create_table_sql(self):
def create_table_sql(self) -> str:
"""
Returns the SQL statement for defining this constraint during table creation.
"""
@ -42,10 +46,10 @@ class Index:
Defines a data-skipping index.
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
name: Optional[str] = None # this is set by the parent model
parent: Optional[type["Model"]] = None # this is set by the parent model
def __init__(self, expr, type, granularity):
def __init__(self, expr: F, type: str, granularity: int):
"""
Initializer.
@ -58,11 +62,13 @@ class Index:
self.type = type
self.granularity = granularity
def create_table_sql(self):
def create_table_sql(self) -> str:
"""
Returns the SQL statement for defining this index during table creation.
"""
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (
self.name, arg_to_sql(self.expr), self.type, self.granularity
)
@staticmethod
def minmax():
@ -73,7 +79,7 @@ class Index:
return 'minmax'
@staticmethod
def set(max_rows):
def set(max_rows: int) -> str:
"""
An index that stores unique values of the specified expression (no more than max_rows rows,
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
@ -82,7 +88,8 @@ class Index:
return 'set(%d)' % max_rows
@staticmethod
def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int,
number_of_hash_functions: int, random_seed: int) -> str:
"""
An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -93,10 +100,13 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
"""
return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
return 'ngrambf_v1(%d, %d, %d, %d)' % (
n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
)
@staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int,
random_seed: int) -> str:
"""
An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters.
@ -106,16 +116,18 @@ class Index:
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
"""
return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
return 'tokenbf_v1(%d, %d, %d)' % (
size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed
)
@staticmethod
def bloom_filter(false_positive=0.025):
'''
def bloom_filter(false_positive: float = 0.025) -> str:
"""
An index that stores a Bloom filter containing values of the index expression.
- `false_positive` - the probability (between 0 and 1) of receiving a false positive
response from the filter
'''
"""
return 'bloom_filter(%f)' % false_positive
@ -126,7 +138,7 @@ class ModelBase(type):
ad_hoc_model_cache = {}
def __new__(cls, name, bases, attrs):
def __new__(mcs, name, bases, attrs):
# Collect fields, constraints and indexes from parent classes
fields = {}
@ -147,7 +159,8 @@ class ModelBase(type):
elif isinstance(obj, Index):
indexes[n] = obj
# Convert fields to a list of (name, field) tuples, in the order they were listed in the class
# Convert fields to a list of (name, field) tuples
# in the order they were listed in the class
fields = sorted(fields.items(), key=lambda item: item[1].creation_counter)
# Build a dictionary of default values
@ -172,7 +185,7 @@ class ModelBase(type):
_defaults=defaults,
_has_funcs_as_defaults=has_funcs_as_defaults
)
model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs)
# Let each field, constraint and index know its parent and its own name
for n, obj in chain(fields, constraints.items(), indexes.items()):
@ -265,6 +278,10 @@ class Model(metaclass=ModelBase):
cpu_percent = Float32Field()
engine = Memory()
"""
_has_funcs_as_defaults: bool
_constraints: dict[str, Constraint]
_indexes: dict[str, Index]
_writable_fields: dict
engine = None
@ -278,7 +295,7 @@ class Model(metaclass=ModelBase):
_database = None
_fields: Dict[str, Field]
_fields: dict[str, Field]
def __init__(self, **kwargs):
"""
@ -314,31 +331,32 @@ class Model(metaclass=ModelBase):
raise tp.with_traceback(tp(new_msg), tb)
super(Model, self).__setattr__(name, value)
def set_database(self, db):
def set_database(self, db: Database):
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
"""
# This can not be imported globally due to circular import
from .database import Database
from clickhouse_orm.database import Database
assert isinstance(db, Database), "database must be database.Database instance"
self._database = db
def get_database(self):
def get_database(self) -> Database:
"""
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
"""
return self._database
def get_field(self, name):
def get_field(self, name: str) -> Optional[Field]:
"""
Gets a `Field` instance given its name, or `None` if not found.
"""
return self._fields.get(name)
@classmethod
def table_name(cls):
def table_name(cls) -> str:
"""
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
@ -347,7 +365,7 @@ class Model(metaclass=ModelBase):
return cls.__name__.lower()
@classmethod
def has_funcs_as_defaults(cls):
def has_funcs_as_defaults(cls) -> bool:
"""
Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances.
@ -355,7 +373,7 @@ class Model(metaclass=ModelBase):
return cls._has_funcs_as_defaults
@classmethod
def create_table_sql(cls, db):
def create_table_sql(cls, db: Database) -> str:
"""
Returns the SQL statement for creating a table for this model.
"""
@ -377,7 +395,7 @@ class Model(metaclass=ModelBase):
return '\n'.join(parts)
@classmethod
def drop_table_sql(cls, db):
def drop_table_sql(cls, db: Database) -> str:
"""
Returns the SQL command for deleting this model's table.
"""
@ -432,7 +450,7 @@ class Model(metaclass=ModelBase):
parts.append(name + '=' + field.to_db_string(data[name], quote=False))
return '\t'.join(parts)
def to_db_string(self):
def to_db_string(self) -> bytes:
"""
Returns the instance as a bytestring ready to be inserted into the database.
"""
@ -440,7 +458,7 @@ class Model(metaclass=ModelBase):
s += '\n'
return s.encode('utf-8')
def to_dict(self, include_readonly=True, field_names=None):
def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]:
"""
Returns the instance's column values as a dict.
@ -456,14 +474,14 @@ class Model(metaclass=ModelBase):
return {name: data[name] for name in fields}
@classmethod
def objects_in(cls, database):
def objects_in(cls, database: Database) -> QuerySet:
"""
Returns a `QuerySet` for selecting instances of this model class.
"""
return QuerySet(cls, database)
@classmethod
def fields(cls, writable=False):
def fields(cls, writable: bool = False) -> dict[str, Field]:
"""
Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included.
@ -473,21 +491,21 @@ class Model(metaclass=ModelBase):
return cls._writable_fields if writable else cls._fields
@classmethod
def is_read_only(cls):
def is_read_only(cls) -> bool:
"""
Returns true if the model is marked as read only.
"""
return cls._readonly
@classmethod
def is_system_model(cls):
def is_system_model(cls) -> bool:
"""
Returns true if the model represents a system table.
"""
return cls._system
@classmethod
def is_temporary_model(cls):
def is_temporary_model(cls) -> bool:
"""
Returns true if the model represents a temporary table.
"""
@ -497,7 +515,7 @@ class Model(metaclass=ModelBase):
class BufferModel(Model):
@classmethod
def create_table_sql(cls, db) -> str:
def create_table_sql(cls, db: Database) -> str:
"""
Returns the SQL statement for creating a table for this model.
"""
@ -522,7 +540,7 @@ class MergeModel(Model):
_table = StringField(readonly=True)
@classmethod
def create_table_sql(cls, db) -> str:
def create_table_sql(cls, db: Database) -> str:
"""
Returns the SQL statement for creating a table for this model.
"""
@ -545,15 +563,14 @@ class DistributedModel(Model):
Model class for use with a `Distributed` engine.
"""
def set_database(self, db):
def set_database(self, db: Database):
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
"""
assert isinstance(self.engine, Distributed),\
"engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db)
return res
super().set_database(db)
@classmethod
def fix_engine_table(cls):
@ -607,7 +624,7 @@ class DistributedModel(Model):
cls.engine.table = storage_models[0]
@classmethod
def create_table_sql(cls, db) -> str:
def create_table_sql(cls, db: Database) -> str:
"""
Returns the SQL statement for creating a table for this model.
"""
@ -637,7 +654,7 @@ class TemporaryModel(Model):
_temporary = True
@classmethod
def create_table_sql(cls, db) -> str:
def create_table_sql(cls, db: Database) -> str:
assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance"
parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()]

View File

@ -1,21 +1,26 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, annotations
from math import ceil
from copy import copy, deepcopy
from types import CoroutineType
from typing import TYPE_CHECKING, overload, Any, Union, Coroutine, Generic
import pytz
from .utils import comma_join, string_or_func, arg_to_sql
# pylint: disable=R0903, W0212, C0415
# TODO
# - check that field names are valid
if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.database import Database, Page
class Operator(object):
class Operator:
"""
Base class for filtering operators.
"""
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
"""
Subclasses should implement this method. It returns an SQL string
that applies this operator on the given field and value.
@ -24,6 +29,7 @@ class Operator(object):
def _value_to_sql(self, field, value, quote=True):
from clickhouse_orm.funcs import F
if isinstance(value, F):
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -38,7 +44,7 @@ class SimpleOperator(Operator):
self._sql_operator = sql_operator
self._sql_for_null = sql_for_null
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
@ -55,7 +61,7 @@ class InOperator(Operator):
- a queryset (subquery)
"""
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
if isinstance(value, QuerySet):
value = value.as_sql()
@ -69,7 +75,7 @@ class InOperator(Operator):
class GlobalInOperator(Operator):
"""An operator that implements Group IN."""
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
if isinstance(value, QuerySet):
value = value.as_sql()
@ -90,14 +96,13 @@ class LikeOperator(Operator):
self._pattern = pattern
self._case_sensitive = case_sensitive
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
pattern = self._pattern.format(value)
if self._case_sensitive:
return '%s LIKE \'%s\'' % (field.name, pattern)
else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern)
@ -106,7 +111,7 @@ class IExactOperator(Operator):
An operator for case insensitive string comparison.
"""
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value)
@ -120,7 +125,7 @@ class NotOperator(Operator):
def __init__(self, base_operator):
self._base_operator = base_operator
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
# Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
@ -135,7 +140,7 @@ class BetweenOperator(Operator):
- '<= value[1]' if value[0] is None or empty
"""
def to_sql(self, model_cls, field_name, value):
def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str:
field = getattr(model_cls, field_name)
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(
str(value[0])) > 0 else None
@ -151,10 +156,10 @@ class BetweenOperator(Operator):
# Define the set of builtin operators
_operators = {}
_operators: dict[str, Operator] = {}
def register_operator(name, sql):
def register_operator(name: str, sql: Operator):
_operators[name] = sql
@ -178,12 +183,12 @@ register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator())
class Cond(object):
class Cond:
"""
An abstract object for storing a single query condition Field + Operator + Value.
"""
def to_sql(self, model_cls):
def to_sql(self, model_cls: type[Model]) -> str:
raise NotImplementedError
@ -192,7 +197,7 @@ class FieldCond(Cond):
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value):
def __init__(self, field_name: str, operator: str, value: Any):
self._field_name = field_name
self._operator = _operators.get(operator)
if self._operator is None:
@ -201,7 +206,7 @@ class FieldCond(Cond):
self._operator = _operators['eq']
self._value = value
def to_sql(self, model_cls):
def to_sql(self, model_cls: type[Model]) -> str:
return self._operator.to_sql(model_cls, self._field_name, self._value)
def __deepcopy__(self, memodict={}):
@ -210,7 +215,7 @@ class FieldCond(Cond):
return res
class Q(object):
class Q:
AND_MODE = 'AND'
OR_MODE = 'OR'
@ -222,7 +227,7 @@ class Q(object):
self._mode = self.AND_MODE
@property
def is_empty(self):
def is_empty(self) -> bool:
"""
Checks if there are any conditions in Q object
Returns: Boolean
@ -252,7 +257,7 @@ class Q(object):
field_name, operator = key, 'eq'
return FieldCond(field_name, operator, value)
def to_sql(self, model_cls):
def to_sql(self, model_cls: type[Model]) -> str:
condition_sql = []
if self._conds:
@ -276,13 +281,13 @@ class Q(object):
return sql
def __or__(self, other):
def __or__(self, other) -> "Q":
return Q._construct_from(self, other, self.OR_MODE)
def __and__(self, other):
def __and__(self, other) -> "Q":
return Q._construct_from(self, other, self.AND_MODE)
def __invert__(self):
def __invert__(self) -> "Q":
q = copy(self)
q._negate = True
return q
@ -302,14 +307,14 @@ class Q(object):
return q
class QuerySet(object):
class QuerySet:
"""
A queryset is an object that represents a database query using a specific `Model`.
It is lazy, meaning that it does not hit the database until you iterate over its
matching rows (model instances).
"""
def __init__(self, model_cls, database):
def __init__(self, model_cls: type[Model], database: Database):
"""
Initializer. It is possible to create a queryset like this, but the standard
way is to use `MyModel.objects_in(database)`.
@ -350,6 +355,9 @@ class QuerySet(object):
return self._database.select(self.as_sql(), self._model_cls)
async def __aiter__(self):
from clickhouse_orm.aio.database import AioDatabase
assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
async for r in self._database.select(self.as_sql(), self._model_cls):
yield r
@ -365,6 +373,14 @@ class QuerySet(object):
def __str__(self):
return self.as_sql()
@overload
def __getitem__(self, s: int) -> Model:
...
@overload
def __getitem__(self, s: slice) -> "QuerySet":
...
def __getitem__(self, s):
if isinstance(s, int):
# Single index
@ -372,7 +388,6 @@ class QuerySet(object):
qs = copy(self)
qs._limits = (s, 1)
return next(iter(qs))
else:
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
start = s.start or 0
@ -383,7 +398,7 @@ class QuerySet(object):
qs._limits = (start, stop - start)
return qs
def limit_by(self, offset_limit, *fields_or_expr):
def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet":
"""
Adds a LIMIT BY clause to the query.
- `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
@ -400,7 +415,7 @@ class QuerySet(object):
qs._limit_by_fields = fields_or_expr
return qs
def select_fields_as_sql(self):
def select_fields_as_sql(self) -> str:
"""
Returns the selected fields or expressions as a SQL string.
"""
@ -409,7 +424,7 @@ class QuerySet(object):
fields = comma_join('`%s`' % field for field in self._fields)
return fields
def as_sql(self):
def as_sql(self) -> str:
"""
Returns the whole query as a SQL string.
"""
@ -419,7 +434,7 @@ class QuerySet(object):
if self._model_cls.is_system_model():
table_name = '`system`.' + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = u'SELECT %s%s\nFROM %s%s' % params
sql = 'SELECT %s%s\nFROM %s%s' % params
if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
@ -445,7 +460,7 @@ class QuerySet(object):
return sql
def order_by_as_sql(self):
def order_by_as_sql(self) -> str:
"""
Returns the contents of the query's `ORDER BY` clause as a string.
"""
@ -454,20 +469,20 @@ class QuerySet(object):
for field in self._order_by
])
def conditions_as_sql(self, prewhere=False):
def conditions_as_sql(self, prewhere=False) -> str:
"""
Returns the contents of the query's `WHERE` or `PREWHERE` clause as a string.
"""
q_object = self._prewhere_q if prewhere else self._where_q
return q_object.to_sql(self._model_cls)
def count(self):
def count(self) -> Union[int, Coroutine[int]]:
"""
Returns the number of matching model instances.
"""
if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = 'SELECT count() FROM (%s)' % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -475,7 +490,7 @@ class QuerySet(object):
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
return self._database.count(self._model_cls, conditions)
def order_by(self, *field_names):
def order_by(self, *field_names) -> "QuerySet":
"""
Returns a copy of this queryset with the ordering changed.
"""
@ -483,7 +498,7 @@ class QuerySet(object):
qs._order_by = field_names
return qs
def only(self, *field_names):
def only(self, *field_names) -> "QuerySet":
"""
Returns a copy of this queryset limited to the specified field names.
Useful when there are large fields that are not needed,
@ -493,8 +508,8 @@ class QuerySet(object):
qs._fields = field_names
return qs
def _filter_or_exclude(self, *q, **kwargs):
from .funcs import F
def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet":
from clickhouse_orm.funcs import F
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
@ -524,21 +539,21 @@ class QuerySet(object):
return qs
def filter(self, *q, **kwargs):
def filter(self, *q, **kwargs) -> "QuerySet":
"""
Returns a copy of this queryset that includes only rows matching the conditions.
Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
"""
return self._filter_or_exclude(*q, **kwargs)
def exclude(self, *q, **kwargs):
def exclude(self, *q, **kwargs) -> "QuerySet":
"""
Returns a copy of this queryset that excludes all rows matching the conditions.
Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
"""
return self._filter_or_exclude(*q, _inverse=True, **kwargs)
def paginate(self, page_num=1, page_size=100):
def paginate(self, page_num=1, page_size=100) -> Page:
"""
Returns a single page of model instances that match the queryset.
Note that `order_by` should be used first, to ensure a correct
@ -550,7 +565,8 @@ class QuerySet(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
"""
from .database import Page
from clickhouse_orm.database import Page
count = self.count()
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
@ -566,7 +582,7 @@ class QuerySet(object):
page_size=page_size
)
def distinct(self):
def distinct(self) -> "QuerySet":
"""
Adds a DISTINCT clause to the query, meaning that any duplicate rows
in the results will be omitted.
@ -575,12 +591,13 @@ class QuerySet(object):
qs._distinct = True
return qs
def final(self):
def final(self) -> "QuerySet":
"""
Adds a FINAL modifier to table, meaning data will be collapsed to final version.
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
"""
from .engines import CollapsingMergeTree, ReplacingMergeTree
from clickhouse_orm.engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError(
'final() method can be used only with the CollapsingMergeTree'
@ -591,7 +608,7 @@ class QuerySet(object):
qs._final = True
return qs
def delete(self):
def delete(self) -> "QuerySet":
"""
Deletes all records matched by this queryset's conditions.
Note that ClickHouse performs deletions in the background, so they are not immediate.
@ -602,7 +619,7 @@ class QuerySet(object):
self._database.raw(sql)
return self
def update(self, **kwargs):
def update(self, **kwargs) -> "QuerySet":
"""
Updates all records matched by this queryset's conditions.
Keyword arguments specify the field names and expressions to use for the update.
@ -627,7 +644,7 @@ class QuerySet(object):
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
assert not self._final, 'Mutations are not allowed after calling final()'
def aggregate(self, *args, **kwargs):
def aggregate(self, *args, **kwargs) -> "AggregateQuerySet":
"""
Returns an `AggregateQuerySet` over this query, with `args` serving as
grouping fields and `kwargs` serving as calculated fields. At least one
@ -650,7 +667,12 @@ class AggregateQuerySet(QuerySet):
A queryset used for aggregation.
"""
def __init__(self, base_qs, grouping_fields, calculated_fields):
def __init__(
self,
base_queryset: QuerySet,
grouping_fields: tuple[Any],
calculated_fields: dict[str, str]
):
"""
Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`.
@ -658,24 +680,26 @@ class AggregateQuerySet(QuerySet):
```
('event_type', 'event_subtype')
```
The calculated fields should be a mapping from name to a ClickHouse aggregation function. For example:
The calculated fields should be a mapping from name to a ClickHouse aggregation function.
For example:
```
{'weekday': 'toDayOfWeek(event_date)', 'number_of_events': 'count()'}
```
At least one calculated field is required.
"""
super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database)
super().__init__(base_queryset._model_cls, base_queryset._database)
assert calculated_fields, 'No calculated fields specified for aggregation'
self._fields = grouping_fields
self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields
self._order_by = list(base_qs._order_by)
self._where_q = base_qs._where_q
self._prewhere_q = base_qs._prewhere_q
self._limits = base_qs._limits
self._distinct = base_qs._distinct
self._order_by = list(base_queryset._order_by)
self._where_q = base_queryset._where_q
self._prewhere_q = base_queryset._prewhere_q
self._limits = base_queryset._limits
self._distinct = base_queryset._distinct
def group_by(self, *args):
def group_by(self, *args) -> "AggregateQuerySet":
"""
This method lets you specify the grouping fields explicitly. The `args` must
be names of grouping fields or calculated fields that this queryset was
@ -700,7 +724,7 @@ class AggregateQuerySet(QuerySet):
"""
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
def select_fields_as_sql(self):
def select_fields_as_sql(self) -> str:
"""
Returns the selected fields or expressions as a SQL string.
"""
@ -710,15 +734,17 @@ class AggregateQuerySet(QuerySet):
def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model
def count(self):
def count(self) -> Union[int, Coroutine[int]]:
"""
Returns the number of rows after aggregation.
"""
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = 'SELECT count() FROM (%s)' % self.as_sql()
raw = self._database.raw(sql)
if isinstance(raw, CoroutineType):
return raw
return int(raw) if raw else 0
def with_totals(self):
def with_totals(self) -> "AggregateQuerySet":
"""
Adds WITH TOTALS modifier ot GROUP BY, making query return extra row
with aggregate function calculated across all the rows. More information:

View File

@ -1,17 +1,17 @@
import uuid
from typing import Optional
from contextvars import ContextVar
from contextvars import ContextVar, Token
ctx_session_id: ContextVar[str] = ContextVar('ck.session_id')
ctx_session_timeout: ContextVar[int] = ContextVar('ck.session_timeout')
ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout')
class SessionContext:
def __init__(self, session: str, timeout: int):
def __init__(self, session: str, timeout: float):
self.session = session
self.timeout = timeout
self.token1 = None
self.token2 = None
self.token1: Optional[Token[str]] = None
self.token2: Optional[Token[float]] = None
def __enter__(self) -> str:
self.token1 = ctx_session_id.set(self.session)

View File

@ -114,7 +114,7 @@ def parse_array(array_string):
array_string = array_string[match.end():]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]", array_string)
match = re.search(r",|\]|\)", array_string)
values.append(array_string[0: match.start()])
array_string = array_string[match.end() - 1:]