diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..eebf614 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,558 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold to be exceeded before program exits with error. +fail-under=10.0 + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the ignore-list. The +# regex matches against paths. +ignore-paths=protobuf + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +init-hook='from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))' + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. 
+limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=no + +# Min Python version to use for version dependend checks. Will default to the +# version used to run pylint. +py-version=3.9 + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + consider-using-f-string + +# Enable the message, report, category or checker with the given id(s). 
You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=colorized + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. 
+logging-modules=logging + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear and the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +#notes-rgx= + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. 
The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. 
+callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=no + +# Signatures are removed from the similarity computation +ignore-signatures=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. 
Overrides argument- +# naming-style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i,j,k,ex,Run,_,r,s,n,x,y,z,f,v,db,c,id,t,ap,pk,qs,q,Q + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. 
+include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +#variable-rgx= + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. 
This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=mcs + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. 
+# 函数或方法的最大参数个数 +max-args=10 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=18 + +# Maximum number of locals for function / method body. +max-locals=25 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +# 函数或方法内出现的最大return/yield语句的数量 +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3df5f50..c7ebe86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "iso8601 >= 0.1.12", "setuptools" ] -version = "0.0.6" +version = "0.1.0" [tool.setuptools.packages.find] where = ["src"] diff --git a/src/clickhouse_orm/__init__.py b/src/clickhouse_orm/__init__.py index 292e25c..8a46492 100644 --- a/src/clickhouse_orm/__init__.py +++ b/src/clickhouse_orm/__init__.py @@ -1,4 +1,4 @@ -__import__("pkg_resources").declare_namespace(__name__) +from inspect import isclass from clickhouse_orm.database import * from clickhouse_orm.engines import * @@ -9,5 +9,4 @@ from clickhouse_orm.models import * from clickhouse_orm.query import * from clickhouse_orm.system_models import * -from inspect import isclass __all__ = [c.__name__ for c in locals().values() if isclass(c)] diff --git a/src/clickhouse_orm/aio/database.py b/src/clickhouse_orm/aio/database.py index ef15f51..0b68669 100644 --- a/src/clickhouse_orm/aio/database.py +++ 
b/src/clickhouse_orm/aio/database.py @@ -1,7 +1,8 @@ import datetime import logging +from io import BytesIO from math import ceil -from typing import Type, Optional, Generator +from typing import Optional, AsyncGenerator import httpx import pytz @@ -11,29 +12,13 @@ from clickhouse_orm.utils import parse_tsv, import_submodules from clickhouse_orm.database import Database, ServerError, DatabaseException, logger, Page +# pylint: disable=C0116 + class AioDatabase(Database): + _client_class = httpx.AsyncClient - def __init__( - self, db_name, db_url='http://localhost:18123/', username=None, - password=None, readonly=False, auto_create=True, timeout=60, - verify_ssl_cert=True, log_statements=False - ): - self.db_name = db_name - self.db_url = db_url - self.readonly = False - self._readonly = readonly - self.auto_create = auto_create - self.timeout = timeout - self.request_session = httpx.AsyncClient(verify=verify_ssl_cert, timeout=timeout) - if username: - self.request_session.auth = (username, password or '') - self.log_statements = log_statements - self.settings = {} - self._db_check = False - self.db_exists = False - - async def db_check(self): - if self._db_check: + async def init(self): + if self._init: return self.db_exists = await self._is_existing_database() if self._readonly: @@ -52,7 +37,7 @@ class AioDatabase(Database): self.server_timezone = pytz.utc self.has_codec_support = self.server_version >= (19, 1, 16) self.has_low_cardinality_support = self.server_version >= (19, 0) - self._db_check = True + self._init = True async def close(self): await self.request_session.aclose() @@ -76,9 +61,9 @@ class AioDatabase(Database): """ from clickhouse_orm.query import Q - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) query = 'SELECT count() FROM $table' @@ -94,9 +79,9 @@ class 
AioDatabase(Database): """ Creates the database on the ClickHouse server if it does not already exist. """ - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) await self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) @@ -106,50 +91,65 @@ class AioDatabase(Database): """ Deletes the database on the ClickHouse server. """ - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) await self._send('DROP DATABASE `%s`' % self.db_name) self.db_exists = False - async def create_table(self, model_class: Type[MODEL]) -> None: + async def create_table(self, model_class: type[MODEL]) -> None: """ Creates a table for the given model class, if it does not exist already. 
""" - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) - if model_class.is_system_model(): raise DatabaseException("You can't create system table") + if model_class.is_temporary_model() and self.session_id is None: + raise DatabaseException( + "Creating a temporary table must be within the lifetime of a session " + ) if getattr(model_class, 'engine') is None: - raise DatabaseException("%s class must define an engine" % model_class.__name__) + raise DatabaseException(f"%s class must define an engine" % model_class.__name__) await self._send(model_class.create_table_sql(self)) - async def drop_table(self, model_class: Type[MODEL]) -> None: + async def create_temporary_table(self, model_class: type[MODEL], table_name: str = None): + """ + Creates a temporary table for the given model class, if it does not exist already. + And you can specify the temporary table name explicitly. + """ + if not self._init: + raise DatabaseException( + 'The AioDatabase object must execute the init method before it can be used' + ) + + await self._send(model_class.create_temporary_table_sql(self, table_name)) + + async def drop_table(self, model_class: type[MODEL]) -> None: """ Drops the database table of the given model class, if it exists. 
""" - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) if model_class.is_system_model(): raise DatabaseException("You can't drop system table") await self._send(model_class.drop_table_sql(self)) - async def does_table_exist(self, model_class: Type[MODEL]) -> bool: + async def does_table_exist(self, model_class: type[MODEL]) -> bool: """ Checks whether a table for the given model class already exists. Note that this only checks for existence of a table with the expected name. """ - if not self._db_check: + if not self._init: raise DatabaseException( - 'The AioDatabase object must execute the `db_check` method before it can be used' + 'The AioDatabase object must execute the init method before it can be used' ) sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'" @@ -185,8 +185,6 @@ class AioDatabase(Database): - `model_instances`: any iterable containing instances of a single model class. - `batch_size`: number of records to send per chunk (use a lower number if your records are very large). """ - from io import BytesIO - i = iter(model_instances) try: first_instance = next(i) @@ -202,7 +200,7 @@ class AioDatabase(Database): fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated' query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt) - def gen(): + async def gen(): buf = BytesIO() buf.write(self._substitute(query, model_class).encode('utf-8')) first_instance.set_database(self) @@ -227,9 +225,9 @@ class AioDatabase(Database): async def select( self, query: str, - model_class: Optional[Type[MODEL]] = None, + model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None - ) -> Generator[MODEL, None, None]: + ) -> AsyncGenerator[MODEL, None]: """ Performs a query and returns a generator of model instances. 
@@ -269,7 +267,7 @@ class AioDatabase(Database): async def paginate( self, - model_class: Type[MODEL], + model_class: type[MODEL], order_by: str, page_num: int = 1, page_size: int = 100, @@ -355,16 +353,16 @@ class AioDatabase(Database): try: r = await self._send('SELECT timezone()') return pytz.timezone(r.text.strip()) - except ServerError as e: - logger.exception('Cannot determine server timezone (%s), assuming UTC', e) + except ServerError as err: + logger.exception('Cannot determine server timezone (%s), assuming UTC', err) return pytz.utc async def _get_server_version(self, as_tuple=True): try: r = await self._send('SELECT version();') ver = r.text - except ServerError as e: - logger.exception('Cannot determine server version (%s), assuming 1.1.0', e) + except ServerError as err: - logger.exception('Cannot determine server version (%s), assuming 1.1.0', err) ver = '1.1.0' return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver @@ -375,3 +373,6 @@ class AioDatabase(Database): query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name query = self._substitute(query, MigrationHistory) return set(obj.module_name async for obj in self.select(query)) + + +__all__ = ["AioDatabase"] diff --git a/src/clickhouse_orm/contrib/geo/fields.py b/src/clickhouse_orm/contrib/geo/fields.py index 9eecd0a..df4d2cb 100644 --- a/src/clickhouse_orm/contrib/geo/fields.py +++ b/src/clickhouse_orm/contrib/geo/fields.py @@ -1,3 +1,6 @@ +from typing import Any, Optional, Union + +from clickhouse_orm import F from clickhouse_orm.fields import Field, Float64Field from clickhouse_orm.utils import POINT_REGEX, RING_VALID_REGEX @@ -53,8 +56,15 @@ class PointField(Field): class_default = Point(0, 0) db_type = 'Point' - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - db_column=None): + def __init__( + self, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: 
Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): super().__init__(default, alias, materialized, readonly, codec, db_column) self.inner_field = Float64Field() diff --git a/src/clickhouse_orm/database.py b/src/clickhouse_orm/database.py index 5c767d1..8e058a1 100644 --- a/src/clickhouse_orm/database.py +++ b/src/clickhouse_orm/database.py @@ -1,18 +1,18 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import re import logging import datetime +from io import BytesIO from math import ceil from string import Template from collections import namedtuple -from typing import Type, Optional, Generator, Union, Any +from typing import Optional, Generator, Union, Any import pytz import httpx from .models import ModelBase, MODEL from .utils import parse_tsv, import_submodules -from .query import Q from .session import ctx_session_id, ctx_session_timeout @@ -24,7 +24,6 @@ class DatabaseException(Exception): """ Raised when a database operation fails. """ - pass class ServerError(DatabaseException): @@ -40,7 +39,7 @@ class ServerError(DatabaseException): # just skip custom init # if non-standard message format self.message = message - super(ServerError, self).__init__(message) + super().__init__(message) ERROR_PATTERNS = ( # ClickHouse prior to v19.3.3 @@ -82,14 +81,15 @@ class ServerError(DatabaseException): return "{} ({})".format(self.message, self.code) -class Database(object): +class Database: """ Database instances connect to a specific ClickHouse database for running queries, inserting data and other operations. """ + _client_class = httpx.Client def __init__(self, db_name, db_url='http://localhost:8123/', - username=None, password=None, readonly=False, autocreate=True, + username=None, password=None, readonly=False, auto_create=True, timeout=60, verify_ssl_cert=True, log_statements=False): """ Initializes a database instance. 
Unless it's readonly, the database will be @@ -100,7 +100,8 @@ class Database(object): - `username`: optional connection credentials. - `password`: optional connection credentials. - `readonly`: use a read-only connection. - - `autocreate`: automatically create the database if it does not exist (unless in readonly mode). + - `autocreate`: automatically create the database + if it does not exist (unless in readonly mode). - `timeout`: the connection timeout in seconds. - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS. - `log_statements`: when True, all database statements are logged. @@ -108,26 +109,43 @@ class Database(object): self.db_name = db_name self.db_url = db_url self.readonly = False + self._readonly = readonly + self.auto_create = auto_create self.timeout = timeout - self.request_session = httpx.Client(verify=verify_ssl_cert, timeout=timeout) + self.request_session = self._client_class(verify=verify_ssl_cert, timeout=timeout) if username: self.request_session.auth = (username, password or '') self.log_statements = log_statements self.settings = {} self.db_exists = False # this is required before running _is_existing_database + self.connection_readonly = False + self.server_version = None + self.server_timezone = None + self.has_codec_support = None + self.has_low_cardinality_support = None + self._init = False + if self._client_class is httpx.Client: + self.init() + + def init(self): + if self._init: + return self.db_exists = self._is_existing_database() - if readonly: + if self._readonly: if not self.db_exists: raise DatabaseException( 'Database does not exist, and cannot be created under readonly connection' ) self.connection_readonly = self._is_connection_readonly() self.readonly = True - elif autocreate and not self.db_exists: + elif self.auto_create and not self.db_exists: self.create_database() self.server_version = self._get_server_version() # Versions 1.1.53981 and below don't have timezone function - 
self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc + if self.server_version > (1, 1, 53981): + self.server_timezone = self._get_server_timezone() + else: + self.server_timezone = pytz.utc # Versions 19.1.16 and above support codec compression self.has_codec_support = self.server_version >= (19, 1, 16) # Version 19.0 and above support LowCardinality @@ -147,7 +165,7 @@ class Database(object): self._send('DROP DATABASE `%s`' % self.db_name) self.db_exists = False - def create_table(self, model_class: Type[MODEL]) -> None: + def create_table(self, model_class: type[MODEL]) -> None: """ Creates a table for the given model class, if it does not exist already. """ @@ -157,7 +175,7 @@ class Database(object): raise DatabaseException("%s class must define an engine" % model_class.__name__) self._send(model_class.create_table_sql(self)) - def drop_table(self, model_class: Type[MODEL]) -> None: + def drop_table(self, model_class: type[MODEL]) -> None: """ Drops the database table of the given model class, if it exists. """ @@ -165,7 +183,7 @@ class Database(object): raise DatabaseException("You can't drop system table") self._send(model_class.drop_table_sql(self)) - def does_table_exist(self, model_class: Type[MODEL]) -> bool: + def does_table_exist(self, model_class: type[MODEL]) -> bool: """ Checks whether a table for the given model class already exists. Note that this only checks for existence of a table with the expected name. @@ -215,9 +233,9 @@ class Database(object): Insert records into the database. - `model_instances`: any iterable containing instances of a single model class. - - `batch_size`: number of records to send per chunk (use a lower number if your records are very large). + - `batch_size`: number of records to send per chunk + (use a lower number if your records are very large). 
""" - from io import BytesIO i = iter(model_instances) try: first_instance = next(i) @@ -257,8 +275,8 @@ class Database(object): def count( self, - model_class: Optional[Type[MODEL]], - conditions: Optional[Union[str, Q]] = None + model_class: Optional[type[MODEL]], + conditions: Optional[Union[str, 'Q']] = None ) -> int: """ Counts the number of records in the model's table. @@ -267,6 +285,7 @@ class Database(object): - `conditions`: optional SQL conditions (contents of the WHERE clause). """ from clickhouse_orm.query import Q + query = 'SELECT count() FROM $table' if conditions: if isinstance(conditions, Q): @@ -279,7 +298,7 @@ class Database(object): def select( self, query: str, - model_class: Optional[Type[MODEL]] = None, + model_class: Optional[type[MODEL]] = None, settings: Optional[dict] = None ) -> Generator[MODEL, None, None]: """ @@ -297,7 +316,8 @@ class Database(object): lines = r.iter_lines() field_names = parse_tsv(next(lines)) field_types = parse_tsv(next(lines)) - model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types)) + if not model_class: + model_class = ModelBase.create_ad_hoc_model(zip(field_names, field_types)) for line in lines: # skip blank line left by WITH TOTALS modifier if line: @@ -318,7 +338,7 @@ class Database(object): def paginate( self, - model_class: Type[MODEL], + model_class: type[MODEL], order_by: str, page_num: int = 1, page_size: int = 100, @@ -371,7 +391,8 @@ class Database(object): containing the migrations. - `up_to` - number of the last migration to apply. 
""" - from .migrations import MigrationHistory + from .migrations import MigrationHistory # pylint: disable=C0415 + logger = logging.getLogger('migrations') applied_migrations = self._get_applied_migrations(migrations_package_name) modules = import_submodules(migrations_package_name) @@ -380,7 +401,11 @@ class Database(object): logger.info('Applying migration %s...', name) for operation in modules[name].operations: operation.apply(self) - self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())]) + self.insert([MigrationHistory( + package_name=migrations_package_name, + module_name=name, + applied=datetime.date.today()) + ]) if int(name[:4]) >= up_to: break @@ -398,7 +423,8 @@ class Database(object): return params def _get_applied_migrations(self, migrations_package_name): - from .migrations import MigrationHistory + from .migrations import MigrationHistory # pylint: disable=C0415 + self.create_table(MigrationHistory) query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name query = self._substitute(query, MigrationHistory) @@ -450,16 +476,16 @@ class Database(object): try: r = self._send('SELECT timezone()') return pytz.timezone(r.text.strip()) - except ServerError as e: - logger.exception('Cannot determine server timezone (%s), assuming UTC', e) + except ServerError as err: + logger.exception('Cannot determine server timezone (%s), assuming UTC', err) return pytz.utc def _get_server_version(self, as_tuple=True): try: r = self._send('SELECT version();') ver = r.text - except ServerError as e: - logger.exception('Cannot determine server version (%s), assuming 1.1.0', e) + except ServerError as err: + logger.exception('Cannot determine server version (%s), assuming 1.1.0', err) ver = '1.1.0' return tuple(int(n) for n in ver.split('.') if n.isdigit()) if as_tuple else ver diff --git a/src/clickhouse_orm/engines.py b/src/clickhouse_orm/engines.py index fea9e1f..3c39751 
100644 --- a/src/clickhouse_orm/engines.py +++ b/src/clickhouse_orm/engines.py @@ -1,15 +1,22 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging +from typing import TYPE_CHECKING, Optional, Union from .utils import comma_join, get_subclass_names + +if TYPE_CHECKING: + from clickhouse_orm.database import Database + from clickhouse_orm.models import Model + from clickhouse_orm.funcs import F + logger = logging.getLogger('clickhouse_orm') -class Engine(object): +class Engine: - def create_table_sql(self, db): + def create_table_sql(self, db: Database) -> str: raise NotImplementedError() # pragma: no cover @@ -34,9 +41,15 @@ class Memory(Engine): class MergeTree(Engine): def __init__( - self, date_col=None, order_by=(), sampling_expr=None, - index_granularity=8192, replica_table_path=None, - replica_name=None, partition_key=None, primary_key=None + self, + date_col: Optional[str] = None, + order_by: Union[list, tuple] = (), + sampling_expr: Optional[F] = None, + index_granularity: int = 8192, + replica_table_path: Optional[str] = None, + replica_name: Optional[str] = None, + partition_key: Optional[Union[list, tuple]] = None, + primary_key: Optional[Union[list, tuple]] = None ): assert type(order_by) in (list, tuple), 'order_by must be a list or tuple' assert date_col is None or isinstance(date_col, str), 'date_col must be string if present' @@ -73,7 +86,7 @@ class MergeTree(Engine): 'Use `order_by` attribute instead') self.order_by = value - def create_table_sql(self, db): + def create_table_sql(self, db: Database) -> str: name = self.__class__.__name__ if self.replica_name: name = 'Replicated' + name @@ -108,7 +121,7 @@ class MergeTree(Engine): params = self._build_sql_params(db) return '%s(%s) %s' % (name, comma_join(params), partition_sql) - def _build_sql_params(self, db): + def _build_sql_params(self, db: Database) -> list[str]: params = [] if self.replica_name: params += ["'%s'" % 
self.replica_table_path, "'%s'" % self.replica_name] @@ -140,8 +153,8 @@ class CollapsingMergeTree(MergeTree): ) self.sign_col = sign_col - def _build_sql_params(self, db): - params = super(CollapsingMergeTree, self)._build_sql_params(db) + def _build_sql_params(self, db: Database) -> list[str]: + params = super()._build_sql_params(db) params.append(self.sign_col) return params @@ -161,7 +174,7 @@ class SummingMergeTree(MergeTree): 'summing_cols must be a list or tuple' self.summing_cols = summing_cols - def _build_sql_params(self, db): + def _build_sql_params(self, db: Database) -> list[str]: params = super(SummingMergeTree, self)._build_sql_params(db) if self.summing_cols: params.append('(%s)' % comma_join(self.summing_cols)) @@ -181,7 +194,7 @@ class ReplacingMergeTree(MergeTree): ) self.ver_col = ver_col - def _build_sql_params(self, db): + def _build_sql_params(self, db: Database) -> list[str]: params = super(ReplacingMergeTree, self)._build_sql_params(db) if self.ver_col: params.append(self.ver_col) @@ -195,9 +208,17 @@ class Buffer(Engine): Read more [here](https://clickhouse.tech/docs/en/engines/table-engines/special/buffer/). 
""" - #Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes) - def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, - max_rows=1000000, min_bytes=10000000, max_bytes=100000000): + def __init__( + self, + main_model: type[Model], + num_layers: int = 16, + min_time: int = 10, + max_time: int = 100, + min_rows: int = 10000, + max_rows: int = 1000000, + min_bytes: int = 10000000, + max_bytes: int = 100000000 + ): self.main_model = main_model self.num_layers = num_layers self.min_time = min_time @@ -207,7 +228,7 @@ class Buffer(Engine): self.min_bytes = min_bytes self.max_bytes = max_bytes - def create_table_sql(self, db): + def create_table_sql(self, db: Database) -> str: # Overriden create_table_sql example: # sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)' sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % ( @@ -226,11 +247,11 @@ class Merge(Engine): https://clickhouse.tech/docs/en/engines/table-engines/special/merge/ """ - def __init__(self, table_regex): + def __init__(self, table_regex: str): assert isinstance(table_regex, str), "'table_regex' parameter must be string" self.table_regex = table_regex - def create_table_sql(self, db): + def create_table_sql(self, db: Database) -> str: return "Merge(`%s`, '%s')" % (db.db_name, self.table_regex) @@ -258,23 +279,22 @@ class Distributed(Engine): self.sharding_key = sharding_key @property - def table_name(self): - # TODO: circular import is bad - from .models import ModelBase + def table_name(self) -> str: + from clickhouse_orm.models import Model table = self.table - if isinstance(table, ModelBase): + if isinstance(table, Model): return table.table_name() return table - def create_table_sql(self, db): + def create_table_sql(self, db: Database) -> str: name = self.__class__.__name__ params = self._build_sql_params(db) return '%s(%s)' % (name, ', '.join(params)) - def 
_build_sql_params(self, db): + def _build_sql_params(self, db: Database) -> list[str]: if self.table_name is None: raise ValueError("Cannot create {} engine: specify an underlying table".format( self.__class__.__name__)) diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py index 6dd342d..0377965 100644 --- a/src/clickhouse_orm/fields.py +++ b/src/clickhouse_orm/fields.py @@ -1,10 +1,14 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations + +import re +from enum import Enum +from uuid import UUID from calendar import timegm import datetime from decimal import Decimal, localcontext from logging import getLogger from ipaddress import IPv4Address, IPv6Address -from uuid import UUID +from typing import TYPE_CHECKING, Any, Optional, Union, Iterable import iso8601 import pytz @@ -13,6 +17,10 @@ from pytz import BaseTzInfo from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names from .funcs import F, FunctionOperatorsMixin +if TYPE_CHECKING: + from clickhouse_orm.models import Model + from clickhouse_orm.database import Database + logger = getLogger('clickhouse_orm') @@ -20,20 +28,27 @@ class Field(FunctionOperatorsMixin): """ Abstract base class for all field types. 
""" - name = None # this is set by the parent model - parent = None # this is set by the parent model - creation_counter = 0 # used for keeping the model fields ordered - class_default = 0 # should be overridden by concrete subclasses - db_type = None # should be overridden by concrete subclasses + name: str = None # this is set by the parent model + parent: type["Model"] = None # this is set by the parent model + creation_counter: int = 0 # used for keeping the model fields ordered + class_default: Any = 0 # should be overridden by concrete subclasses + db_type: str # should be overridden by concrete subclasses - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - db_column=None): + def __init__( + self, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): assert [default, alias, materialized].count(None) >= 2, \ "Only one of default, alias and materialized parameters can be given" assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "", \ "Alias parameter must be a string or function object, if given" - assert materialized is None or isinstance(materialized, F) or isinstance(materialized, - str) and materialized != "", \ + assert (materialized is None or isinstance(materialized, F) or + isinstance(materialized, str) and materialized != ""), \ "Materialized parameter must be a string or function object, if given" assert readonly is None or type( readonly) is bool, "readonly parameter must be bool if given" @@ -78,7 +93,8 @@ class Field(FunctionOperatorsMixin): """ if value < min_value or value > max_value: raise ValueError('%s out of range - %s is not between %s and %s' % ( - self.__class__.__name__, value, min_value, max_value)) + self.__class__.__name__, value, min_value, max_value + )) def to_db_string(self, value, quote=True): """ @@ 
-87,7 +103,7 @@ class Field(FunctionOperatorsMixin): """ return escape(value, quote) - def get_sql(self, with_default_expression=True, db=None): + def get_sql(self, with_default_expression=True, db=None) -> str: """ Returns an SQL expression describing the field (e.g. for CREATE TABLE). @@ -107,7 +123,7 @@ class Field(FunctionOperatorsMixin): """Returns field type arguments""" return [] - def _extra_params(self, db): + def _extra_params(self, db: Database) -> str: sql = '' if self.alias: sql += ' ALIAS %s' % string_or_func(self.alias) @@ -122,7 +138,7 @@ class Field(FunctionOperatorsMixin): sql += ' CODEC(%s)' % self.codec return sql - def isinstance(self, types): + def isinstance(self, types) -> bool: """ Checks if the instance if one of the types provided or if any of the inner_field child is one of the types provided, returns True if field or any inner_field is one of ths provided, False otherwise @@ -145,7 +161,7 @@ class StringField(Field): class_default = '' db_type = 'String' - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> str: if isinstance(value, str): return value if isinstance(value, bytes): @@ -155,13 +171,20 @@ class StringField(Field): class FixedStringField(StringField): - def __init__(self, length, default=None, alias=None, materialized=None, readonly=None, - db_column=None): + def __init__( + self, + length: int, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: Optional[bool] = None, + db_column: Optional[str] = None + ): self._length = length self.db_type = 'FixedString(%d)' % length super(FixedStringField, self).__init__(default, alias, materialized, readonly, db_column) - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> str: value = super(FixedStringField, self).to_python(value, timezone_in_use) return value.rstrip('\0') @@ -169,8 +192,9 @@ class FixedStringField(StringField): if 
isinstance(value, str): value = value.encode('UTF-8') if len(value) > self._length: - raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % ( - len(value), self._length)) + raise ValueError( + f'Value of {len(value)} bytes is too long for FixedStringField({self._length})' + ) class DateField(Field): @@ -179,7 +203,7 @@ class DateField(Field): class_default = min_value db_type = 'Date' - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> datetime.date: if isinstance(value, datetime.datetime): return value.astimezone(pytz.utc).date() if value.tzinfo else value.date() if isinstance(value, datetime.date): @@ -195,7 +219,7 @@ class DateField(Field): def validate(self, value): self._range_check(value, DateField.min_value, DateField.max_value) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: return escape(value.isoformat(), quote) @@ -203,8 +227,16 @@ class DateTimeField(Field): class_default = datetime.datetime.fromtimestamp(0, pytz.utc) db_type = 'DateTime' - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - db_column=None, timezone=None): + def __init__( + self, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None, + timezone: Optional[Union[BaseTzInfo, str]] = None + ): super().__init__(default, alias, materialized, readonly, codec, db_column) # assert not timezone, 'Temporarily field timezone is not supported' if timezone: @@ -217,7 +249,7 @@ class DateTimeField(Field): args.append(escape(self.timezone.zone)) return args - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> datetime.datetime: if isinstance(value, datetime.datetime): return value if value.tzinfo else value.replace(tzinfo=pytz.utc) if isinstance(value, 
datetime.date): @@ -245,15 +277,34 @@ class DateTimeField(Field): return dt raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: return escape('%010d' % timegm(value.utctimetuple()), quote) class DateTime64Field(DateTimeField): db_type = 'DateTime64' - def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, - db_column=None, timezone=None, precision=6): + """ + + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + """ + + def __init__( + self, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None, + timezone: Optional[Union[BaseTzInfo, str]] = None, + precision: int = 6 + ): super().__init__(default, alias, materialized, readonly, codec, db_column, timezone) assert precision is None or isinstance(precision, int), 'Precision must be int type' self.precision = precision @@ -264,7 +315,7 @@ class DateTime64Field(DateTimeField): args.append(escape(self.timezone.zone)) return args - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: """ Returns the field's value prepared for writing to the database @@ -278,7 +329,7 @@ class DateTime64Field(DateTimeField): quote ) - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> datetime.datetime: try: return super().to_python(value, timezone_in_use) except ValueError: @@ -302,13 +353,13 @@ class BaseIntField(Field): Abstract base class for all integer-type fields. 
""" - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> int: try: return int(value) except: raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: # There's no need to call escape since numbers do not contain # special characters, and never need quoting return str(value) @@ -370,13 +421,13 @@ class BaseFloatField(Field): Abstract base class for all float-type fields. """ - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> float: try: return float(value) except: raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: # There's no need to call escape since numbers do not contain # special characters, and never need quoting return str(value) @@ -395,8 +446,16 @@ class DecimalField(Field): Base class for all decimal fields. Can also be used directly. 
""" - def __init__(self, precision, scale, default=None, alias=None, materialized=None, - readonly=None, db_column=None): + def __init__( + self, + precision: int, + scale: int, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + db_column: Optional[str] = None + ): assert 1 <= precision <= 38, 'Precision must be between 1 and 38' assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision' self.precision = precision @@ -409,7 +468,7 @@ class DecimalField(Field): self.min_value = -self.max_value super(DecimalField, self).__init__(default, alias, materialized, readonly, db_column) - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> Decimal: if not isinstance(value, Decimal): try: value = Decimal(value) @@ -419,7 +478,7 @@ class DecimalField(Field): raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value)) return self._round(value) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: # There's no need to call escape since numbers do not contain # special characters, and never need quoting return str(value) @@ -432,25 +491,45 @@ class DecimalField(Field): class Decimal32Field(DecimalField): - - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None, - db_column=None): + def __init__( + self, + scale: int, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + db_column: Optional[str] = None + ): super().__init__(9, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal32(%d)' % scale class Decimal64Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None, - db_column=None): + def __init__( + self, + scale: int, + default: Any = None, + alias: Optional[Union[F, 
str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + db_column: Optional[str] = None + ): super().__init__(18, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal64(%d)' % scale class Decimal128Field(DecimalField): - def __init__(self, scale, default=None, alias=None, materialized=None, - readonly=None, db_column=None): + def __init__( + self, + scale: int, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + db_column: Optional[str] = None + ): super().__init__(38, scale, default, alias, materialized, readonly, db_column) self.db_type = 'Decimal128(%d)' % scale @@ -460,8 +539,16 @@ class BaseEnumField(Field): Abstract base class for all enum-type fields. """ - def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, - codec=None, db_column=None): + def __init__( + self, + enum_cls: type[Enum], + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): self.enum_cls = enum_cls if default is None: default = list(enum_cls)[0] @@ -488,20 +575,18 @@ class BaseEnumField(Field): pass raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value)) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: return escape(value.name, quote) def get_db_type_args(self): return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls] @classmethod - def create_ad_hoc_field(cls, db_type): + def create_ad_hoc_field(cls, db_type) -> BaseEnumField: """ Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)" this method returns a matching enum field. 
""" - import re - from enum import Enum members = {} for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type): members[match.group(1)] = int(match.group(2)) @@ -521,8 +606,16 @@ class Enum16Field(BaseEnumField): class ArrayField(Field): class_default = [] - def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, - codec=None, db_column=None): + def __init__( + self, + inner_field: Field, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): assert isinstance(inner_field, Field), \ "The first argument of ArrayField must be a Field instance" assert not isinstance(inner_field, ArrayField), \ @@ -543,22 +636,81 @@ class ArrayField(Field): for v in value: self.inner_field.validate(v) - def to_db_string(self, value, quote=True): + def to_db_string(self, value, quote=True) -> str: array = [self.inner_field.to_db_string(v, quote=True) for v in value] return '[' + comma_join(array) + ']' - def get_sql(self, with_default_expression=True, db=None): + def get_sql(self, with_default_expression=True, db=None) -> str: sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db) if with_default_expression and self.codec and db and db.has_codec_support: sql += ' CODEC(%s)' % self.codec return sql +class TupleField(Field): + class_default = () + + def __init__( + self, + name_fields: list[tuple[str, Field]], + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: bool = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): + self.names = {} + self.inner_fields = [] + for (name, field) in name_fields: + if name in self.names: + raise ValueError('The Field name conflict') + assert isinstance(field, Field), \ + "The first argument of TupleField must be a Field instance" + assert not 
isinstance(field, (ArrayField, TupleField)), \ + "Multidimensional array fields are not supported by the ORM" + self.names[name] = field + self.inner_fields.append(field) + self.class_default = tuple(field.class_default for field in self.inner_fields) + super().__init__(default, alias, materialized, readonly, codec, db_column) + + def to_python(self, value, timezone_in_use) -> tuple: + if isinstance(value, str): + value = parse_array(value) + value = (self.inner_fields[i].to_python(v, timezone_in_use) + for i, v in enumerate(value)) + elif isinstance(value, bytes): + value = parse_array(value.decode('UTF-8')) + value = (self.inner_fields[i].to_python(v, timezone_in_use) + for i, v in enumerate(value)) + elif not isinstance(value, (list, tuple)): + raise ValueError('TupleField expects list or tuple, not %s' % type(value)) + return tuple(self.inner_fields[i].to_python(v, timezone_in_use) + for i, v in enumerate(value)) + + def validate(self, value): + for i, v in enumerate(value): + self.inner_fields[i].validate(v) + + def to_db_string(self, value, quote=True) -> str: + array = [self.inner_fields[i].to_db_string(v, quote=True) for i, v in enumerate(value)] + return '(' + comma_join(array) + ')' + + def get_sql(self, with_default_expression=True, db=None) -> str: + inner_sql = ', '.join('%s %s' % (name, field.get_sql(False)) + for name, field in self.names.items()) + + sql = 'Tuple(%s)' % inner_sql + if with_default_expression and self.codec and db and db.has_codec_support: + sql += ' CODEC(%s)' % self.codec + return sql + + class UUIDField(Field): class_default = UUID(int=0) db_type = 'UUID' - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> UUID: if isinstance(value, UUID): return value elif isinstance(value, bytes): @@ -580,7 +732,7 @@ class IPv4Field(Field): class_default = 0 db_type = 'IPv4' - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> IPv4Address: if isinstance(value, 
IPv4Address): return value elif isinstance(value, (bytes, str, int)): @@ -596,7 +748,7 @@ class IPv6Field(Field): class_default = 0 db_type = 'IPv6' - def to_python(self, value, timezone_in_use): + def to_python(self, value, timezone_in_use) -> IPv6Address: if isinstance(value, IPv6Address): return value elif isinstance(value, (bytes, str, int)): @@ -611,8 +763,16 @@ class IPv6Field(Field): class NullableField(Field): class_default = None - def __init__(self, inner_field, default=None, alias=None, materialized=None, - extra_null_values=None, codec=None): + def __init__( + self, + inner_field: Field, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + extra_null_values: Optional[Iterable] = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): assert isinstance(inner_field, Field), \ "The first argument of NullableField must be a Field instance." \ " Not: {}".format(inner_field) @@ -620,8 +780,9 @@ class NullableField(Field): self._null_values = [None] if extra_null_values: self._null_values.extend(extra_null_values) - super(NullableField, self).__init__(default, alias, materialized, readonly=None, - codec=codec) + super().__init__( + default, alias, materialized, readonly=None, codec=codec, db_column=db_column + ) def to_python(self, value, timezone_in_use): if value == '\\N' or value in self._null_values: @@ -645,8 +806,16 @@ class NullableField(Field): class LowCardinalityField(Field): - def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, - codec=None): + def __init__( + self, + inner_field: Field, + default: Any = None, + alias: Optional[Union[F, str]] = None, + materialized: Optional[Union[F, str]] = None, + readonly: Optional[bool] = None, + codec: Optional[str] = None, + db_column: Optional[str] = None + ): assert isinstance(inner_field, Field), \ "The first argument of LowCardinalityField must be a Field instance." 
\ " Not: {}".format(inner_field) @@ -657,7 +826,7 @@ class LowCardinalityField(Field): " Use Array(LowCardinality) instead" self.inner_field = inner_field self.class_default = self.inner_field.class_default - super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec) + super().__init__(default, alias, materialized, readonly, codec, db_column) def to_python(self, value, timezone_in_use): return self.inner_field.to_python(value, timezone_in_use) diff --git a/src/clickhouse_orm/funcs.py b/src/clickhouse_orm/funcs.py index 2d0d1d6..0904221 100644 --- a/src/clickhouse_orm/funcs.py +++ b/src/clickhouse_orm/funcs.py @@ -75,7 +75,7 @@ def parametric(func): return wrapper -class FunctionOperatorsMixin(object): +class FunctionOperatorsMixin: """ A mixin for implementing Python operators using F objects. """ @@ -248,7 +248,7 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): def __repr__(self): return self.to_sql() - def to_sql(self, *args): + def to_sql(self, *args) -> str: """ Generates an SQL string for this function and its arguments. For example if the function name is a symbol of a binary operator: @@ -898,7 +898,6 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): @staticmethod def replace(haystack, pattern, replacement): return F('replace', haystack, pattern, replacement) - replaceAll = replace @staticmethod def replaceAll(haystack, pattern, replacement): diff --git a/src/clickhouse_orm/migrations.py b/src/clickhouse_orm/migrations.py index af7dc51..efcc19f 100644 --- a/src/clickhouse_orm/migrations.py +++ b/src/clickhouse_orm/migrations.py @@ -1,30 +1,31 @@ from .models import Model, BufferModel from .fields import DateField, StringField from .engines import MergeTree -from .utils import escape, get_subclass_names +from .utils import get_subclass_names import logging + logger = logging.getLogger('migrations') -class Operation(): - ''' +class Operation: + """ Base class for migration operations. 
- ''' + """ def apply(self, database): raise NotImplementedError() # pragma: no cover class ModelOperation(Operation): - ''' + """ Base class for migration operations that work on a specific model. - ''' + """ def __init__(self, model_class): - ''' + """ Initializer. - ''' + """ self.model_class = model_class self.table_name = model_class.table_name() @@ -38,9 +39,9 @@ class ModelOperation(Operation): class CreateTable(ModelOperation): - ''' + """ A migration operation that creates a table for a given model class. - ''' + """ def apply(self, database): logger.info(' Create table %s', self.table_name) @@ -50,14 +51,14 @@ class CreateTable(ModelOperation): class AlterTable(ModelOperation): - ''' + """ A migration operation that compares the table of a given model class to the model's fields, and alters the table to match the model. The operation can: - add new columns - drop obsolete columns - modify column types Default values are not altered by this operation. - ''' + """ def _get_table_fields(self, database): query = "DESC `%s`.`%s`" % (database.db_name, self.table_name) @@ -113,11 +114,11 @@ class AlterTable(ModelOperation): class AlterTableWithBuffer(ModelOperation): - ''' + """ A migration operation for altering a buffer table and its underlying on-disk table. The buffer table is dropped, the on-disk table is altered, and then the buffer table is re-created. - ''' + """ def apply(self, database): if issubclass(self.model_class, BufferModel): @@ -129,9 +130,9 @@ class AlterTableWithBuffer(ModelOperation): class DropTable(ModelOperation): - ''' + """ A migration operation that drops the table of a given model class. - ''' + """ def apply(self, database): logger.info(' Drop table %s', self.table_name) @@ -139,12 +140,12 @@ class DropTable(ModelOperation): class AlterConstraints(ModelOperation): - ''' + """ A migration operation that adds new constraints from the model to the database table, and drops obsolete ones. 
Constraints are identified by their names, so a change in an existing constraint will not be detected unless its name was changed too. ClickHouse does not check that the constraints hold for existing data in the table. - ''' + """ def apply(self, database): logger.info(' Alter constraints for %s', self.table_name) @@ -163,9 +164,9 @@ class AlterConstraints(ModelOperation): self._alter_table(database, 'DROP CONSTRAINT `%s`' % name) def _get_constraint_names(self, database): - ''' + """ Returns a set containing the names of existing constraints in the table. - ''' + """ import re table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name) matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def) @@ -173,19 +174,19 @@ class AlterConstraints(ModelOperation): class AlterIndexes(ModelOperation): - ''' + """ A migration operation that adds new indexes from the model to the database table, and drops obsolete ones. Indexes are identified by their names, so a change in an existing index will not be detected unless its name was changed too. - ''' + """ def __init__(self, model_class, reindex=False): - ''' + """ Initializer. By default ClickHouse does not build indexes over existing data, only for new data. Passing `reindex=True` will run `OPTIMIZE TABLE` in order to build the indexes over the existing data. - ''' + """ super().__init__(model_class) self.reindex = reindex @@ -211,9 +212,9 @@ class AlterIndexes(ModelOperation): database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name) def _get_index_names(self, database): - ''' + """ Returns a set containing the names of existing indexes in the table. - ''' + """ import re table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name) matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def) @@ -221,9 +222,9 @@ class AlterIndexes(ModelOperation): class RunPython(Operation): - ''' + """ A migration operation that executes a Python function. 
- ''' + """ def __init__(self, func): ''' Initializer. The given Python function will be called with a single @@ -238,9 +239,9 @@ class RunPython(Operation): class RunSQL(Operation): - ''' + """ A migration operation that executes arbitrary SQL statements. - ''' + """ def __init__(self, sql): ''' @@ -259,9 +260,9 @@ class RunSQL(Operation): class MigrationHistory(Model): - ''' + """ A model for storing which migrations were already applied to the containing database. - ''' + """ package_name = StringField() module_name = StringField() diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py index 491a83e..b64cd82 100644 --- a/src/clickhouse_orm/models.py +++ b/src/clickhouse_orm/models.py @@ -1,9 +1,9 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import sys from collections import OrderedDict from itertools import chain from logging import getLogger -from typing import TypeVar, Dict +from typing import TypeVar, Optional, TYPE_CHECKING, Any import pytz @@ -13,6 +13,10 @@ from .query import QuerySet from .funcs import F from .engines import Merge, Distributed, Memory + +if TYPE_CHECKING: + from clickhouse_orm.database import Database + logger = getLogger('clickhouse_orm') @@ -21,16 +25,16 @@ class Constraint: Defines a model constraint. """ - name = None # this is set by the parent model - parent = None # this is set by the parent model + name: Optional[str] = None # this is set by the parent model + parent: Optional[type["Model"]] = None # this is set by the parent model - def __init__(self, expr): + def __init__(self, expr: F): """ Initializer. Expects an expression that ClickHouse will verify when inserting data. """ self.expr = expr - def create_table_sql(self): + def create_table_sql(self) -> str: """ Returns the SQL statement for defining this constraint during table creation. """ @@ -42,10 +46,10 @@ class Index: Defines a data-skipping index. 
""" - name = None # this is set by the parent model - parent = None # this is set by the parent model + name: Optional[str] = None # this is set by the parent model + parent: Optional[type["Model"]] = None # this is set by the parent model - def __init__(self, expr, type, granularity): + def __init__(self, expr: F, type: str, granularity: int): """ Initializer. @@ -58,11 +62,13 @@ class Index: self.type = type self.granularity = granularity - def create_table_sql(self): + def create_table_sql(self) -> str: """ Returns the SQL statement for defining this index during table creation. """ - return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity) + return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % ( + self.name, arg_to_sql(self.expr), self.type, self.granularity + ) @staticmethod def minmax(): @@ -73,7 +79,7 @@ class Index: return 'minmax' @staticmethod - def set(max_rows): + def set(max_rows: int) -> str: """ An index that stores unique values of the specified expression (no more than max_rows rows, or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable @@ -82,7 +88,8 @@ class Index: return 'set(%d)' % max_rows @staticmethod - def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): + def ngrambf_v1(n: int, size_of_bloom_filter_in_bytes: int, + number_of_hash_functions: int, random_seed: int) -> str: """ An index that stores a Bloom filter containing all ngrams from a block of data. Works only with strings. Can be used for optimization of equals, like and in expressions. @@ -93,10 +100,13 @@ class Index: - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. 
""" - return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) + return 'ngrambf_v1(%d, %d, %d, %d)' % ( + n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed + ) @staticmethod - def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed): + def tokenbf_v1(size_of_bloom_filter_in_bytes: int, number_of_hash_functions: int, + random_seed: int) -> str: """ An index that stores a Bloom filter containing string tokens. Tokens are sequences separated by non-alphanumeric characters. @@ -106,16 +116,18 @@ class Index: - `number_of_hash_functions` — The number of hash functions used in the Bloom filter. - `random_seed` — The seed for Bloom filter hash functions. """ - return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed) + return 'tokenbf_v1(%d, %d, %d)' % ( + size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed + ) @staticmethod - def bloom_filter(false_positive=0.025): - ''' + def bloom_filter(false_positive: float = 0.025) -> str: + """ An index that stores a Bloom filter containing values of the index expression. 
- `false_positive` - the probability (between 0 and 1) of receiving a false positive response from the filter - ''' + """ return 'bloom_filter(%f)' % false_positive @@ -126,7 +138,7 @@ class ModelBase(type): ad_hoc_model_cache = {} - def __new__(cls, name, bases, attrs): + def __new__(mcs, name, bases, attrs): # Collect fields, constraints and indexes from parent classes fields = {} @@ -147,7 +159,8 @@ class ModelBase(type): elif isinstance(obj, Index): indexes[n] = obj - # Convert fields to a list of (name, field) tuples, in the order they were listed in the class + # Convert fields to a list of (name, field) tuples + # in the order they were listed in the class fields = sorted(fields.items(), key=lambda item: item[1].creation_counter) # Build a dictionary of default values @@ -172,7 +185,7 @@ class ModelBase(type): _defaults=defaults, _has_funcs_as_defaults=has_funcs_as_defaults ) - model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs) + model = super(ModelBase, mcs).__new__(mcs, str(name), bases, attrs) # Let each field, constraint and index know its parent and its own name for n, obj in chain(fields, constraints.items(), indexes.items()): @@ -265,6 +278,10 @@ class Model(metaclass=ModelBase): cpu_percent = Float32Field() engine = Memory() """ + _has_funcs_as_defaults: bool + _constraints: dict[str, Constraint] + _indexes: dict[str, Index] + _writable_fields: dict engine = None @@ -278,7 +295,7 @@ class Model(metaclass=ModelBase): _database = None - _fields: Dict[str, Field] + _fields: dict[str, Field] def __init__(self, **kwargs): """ @@ -314,31 +331,32 @@ class Model(metaclass=ModelBase): raise tp.with_traceback(tp(new_msg), tb) super(Model, self).__setattr__(name, value) - def set_database(self, db): + def set_database(self, db: Database): """ Sets the `Database` that this model instance belongs to. This is done automatically when the instance is read from the database or written to it. 
""" # This can not be imported globally due to circular import - from .database import Database + from clickhouse_orm.database import Database + assert isinstance(db, Database), "database must be database.Database instance" self._database = db - def get_database(self): + def get_database(self) -> Database: """ Gets the `Database` that this model instance belongs to. Returns `None` unless the instance was read from the database or written to it. """ return self._database - def get_field(self, name): + def get_field(self, name: str) -> Optional[Field]: """ Gets a `Field` instance given its name, or `None` if not found. """ return self._fields.get(name) @classmethod - def table_name(cls): + def table_name(cls) -> str: """ Returns the model's database table name. By default this is the class name converted to lowercase. Override this if you want to use @@ -347,7 +365,7 @@ class Model(metaclass=ModelBase): return cls.__name__.lower() @classmethod - def has_funcs_as_defaults(cls): + def has_funcs_as_defaults(cls) -> bool: """ Return True if some of the model's fields use a function expression as a default value. This requires special handling when inserting instances. @@ -355,7 +373,7 @@ class Model(metaclass=ModelBase): return cls._has_funcs_as_defaults @classmethod - def create_table_sql(cls, db): + def create_table_sql(cls, db: Database) -> str: """ Returns the SQL statement for creating a table for this model. """ @@ -377,7 +395,7 @@ class Model(metaclass=ModelBase): return '\n'.join(parts) @classmethod - def drop_table_sql(cls, db): + def drop_table_sql(cls, db: Database) -> str: """ Returns the SQL command for deleting this model's table. """ @@ -432,7 +450,7 @@ class Model(metaclass=ModelBase): parts.append(name + '=' + field.to_db_string(data[name], quote=False)) return '\t'.join(parts) - def to_db_string(self): + def to_db_string(self) -> bytes: """ Returns the instance as a bytestring ready to be inserted into the database. 
""" @@ -440,7 +458,7 @@ class Model(metaclass=ModelBase): s += '\n' return s.encode('utf-8') - def to_dict(self, include_readonly=True, field_names=None): + def to_dict(self, include_readonly=True, field_names=None) -> dict[str, Any]: """ Returns the instance's column values as a dict. @@ -456,14 +474,14 @@ class Model(metaclass=ModelBase): return {name: data[name] for name in fields} @classmethod - def objects_in(cls, database): + def objects_in(cls, database: Database) -> QuerySet: """ Returns a `QuerySet` for selecting instances of this model class. """ return QuerySet(cls, database) @classmethod - def fields(cls, writable=False): + def fields(cls, writable: bool = False) -> dict[str, Field]: """ Returns an `OrderedDict` of the model's fields (from name to `Field` instance). If `writable` is true, only writable fields are included. @@ -473,21 +491,21 @@ class Model(metaclass=ModelBase): return cls._writable_fields if writable else cls._fields @classmethod - def is_read_only(cls): + def is_read_only(cls) -> bool: """ Returns true if the model is marked as read only. """ return cls._readonly @classmethod - def is_system_model(cls): + def is_system_model(cls) -> bool: """ Returns true if the model represents a system table. """ return cls._system @classmethod - def is_temporary_model(cls): + def is_temporary_model(cls) -> bool: """ Returns true if the model represents a temporary table. """ @@ -497,7 +515,7 @@ class Model(metaclass=ModelBase): class BufferModel(Model): @classmethod - def create_table_sql(cls, db) -> str: + def create_table_sql(cls, db: Database) -> str: """ Returns the SQL statement for creating a table for this model. """ @@ -522,7 +540,7 @@ class MergeModel(Model): _table = StringField(readonly=True) @classmethod - def create_table_sql(cls, db) -> str: + def create_table_sql(cls, db: Database) -> str: """ Returns the SQL statement for creating a table for this model. 
""" @@ -545,15 +563,14 @@ class DistributedModel(Model): Model class for use with a `Distributed` engine. """ - def set_database(self, db): + def set_database(self, db: Database): """ Sets the `Database` that this model instance belongs to. This is done automatically when the instance is read from the database or written to it. """ assert isinstance(self.engine, Distributed),\ "engine must be an instance of engines.Distributed" - res = super(DistributedModel, self).set_database(db) - return res + super().set_database(db) @classmethod def fix_engine_table(cls): @@ -607,7 +624,7 @@ class DistributedModel(Model): cls.engine.table = storage_models[0] @classmethod - def create_table_sql(cls, db) -> str: + def create_table_sql(cls, db: Database) -> str: """ Returns the SQL statement for creating a table for this model. """ @@ -637,7 +654,7 @@ class TemporaryModel(Model): _temporary = True @classmethod - def create_table_sql(cls, db) -> str: + def create_table_sql(cls, db: Database) -> str: assert isinstance(cls.engine, Memory), "engine must be engines.Memory instance" parts = ['CREATE TEMPORARY TABLE IF NOT EXISTS `%s` (' % cls.table_name()] diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py index 8b6cb09..d122431 100644 --- a/src/clickhouse_orm/query.py +++ b/src/clickhouse_orm/query.py @@ -1,21 +1,26 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations from math import ceil from copy import copy, deepcopy +from types import CoroutineType +from typing import TYPE_CHECKING, overload, Any, Union, Coroutine, Generic import pytz from .utils import comma_join, string_or_func, arg_to_sql +# pylint: disable=R0903, W0212, C0415 -# TODO -# - check that field names are valid +if TYPE_CHECKING: + from clickhouse_orm.models import Model + from clickhouse_orm.database import Database, Page -class Operator(object): + +class Operator: """ Base class for filtering operators. 
""" - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: """ Subclasses should implement this method. It returns an SQL string that applies this operator on the given field and value. @@ -24,6 +29,7 @@ class Operator(object): def _value_to_sql(self, field, value, quote=True): from clickhouse_orm.funcs import F + if isinstance(value, F): return value.to_sql() return field.to_db_string(field.to_python(value, pytz.utc), quote) @@ -38,7 +44,7 @@ class SimpleOperator(Operator): self._sql_operator = sql_operator self._sql_for_null = sql_for_null - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) if value == '\\N' and self._sql_for_null is not None: @@ -55,7 +61,7 @@ class InOperator(Operator): - a queryset (subquery) """ - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) if isinstance(value, QuerySet): value = value.as_sql() @@ -69,7 +75,7 @@ class InOperator(Operator): class GlobalInOperator(Operator): """An operator that implements Group IN.""" - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) if isinstance(value, QuerySet): value = value.as_sql() @@ -90,15 +96,14 @@ class LikeOperator(Operator): self._pattern = pattern self._case_sensitive = case_sensitive - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value, quote=False) value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') pattern = self._pattern.format(value) if 
self._case_sensitive: return '%s LIKE \'%s\'' % (field.name, pattern) - else: - return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern) + return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field.name, pattern) class IExactOperator(Operator): @@ -106,7 +111,7 @@ class IExactOperator(Operator): An operator for case insensitive string comparison. """ - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field.name, value) @@ -120,7 +125,7 @@ class NotOperator(Operator): def __init__(self, base_operator): self._base_operator = base_operator - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: # Negate the base operator return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value) @@ -135,7 +140,7 @@ class BetweenOperator(Operator): - '<= value[1]' if value[0] is None or empty """ - def to_sql(self, model_cls, field_name, value): + def to_sql(self, model_cls: type[Model], field_name: str, value: Any) -> str: field = getattr(model_cls, field_name) value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len( str(value[0])) > 0 else None @@ -151,10 +156,10 @@ class BetweenOperator(Operator): # Define the set of builtin operators -_operators = {} +_operators: dict[str, Operator] = {} -def register_operator(name, sql): +def register_operator(name: str, sql: Operator): _operators[name] = sql @@ -178,12 +183,12 @@ register_operator('iendswith', LikeOperator('%{}', False)) register_operator('iexact', IExactOperator()) -class Cond(object): +class Cond: """ An abstract object for storing a single query condition Field + Operator + Value. 
""" - def to_sql(self, model_cls): + def to_sql(self, model_cls: type[Model]) -> str: raise NotImplementedError @@ -192,7 +197,7 @@ class FieldCond(Cond): A single query condition made up of Field + Operator + Value. """ - def __init__(self, field_name, operator, value): + def __init__(self, field_name: str, operator: str, value: Any): self._field_name = field_name self._operator = _operators.get(operator) if self._operator is None: @@ -201,7 +206,7 @@ class FieldCond(Cond): self._operator = _operators['eq'] self._value = value - def to_sql(self, model_cls): + def to_sql(self, model_cls: type[Model]) -> str: return self._operator.to_sql(model_cls, self._field_name, self._value) def __deepcopy__(self, memodict={}): @@ -210,7 +215,7 @@ class FieldCond(Cond): return res -class Q(object): +class Q: AND_MODE = 'AND' OR_MODE = 'OR' @@ -222,7 +227,7 @@ class Q(object): self._mode = self.AND_MODE @property - def is_empty(self): + def is_empty(self) -> bool: """ Checks if there are any conditions in Q object Returns: Boolean @@ -252,7 +257,7 @@ class Q(object): field_name, operator = key, 'eq' return FieldCond(field_name, operator, value) - def to_sql(self, model_cls): + def to_sql(self, model_cls: type[Model]) -> str: condition_sql = [] if self._conds: @@ -276,13 +281,13 @@ class Q(object): return sql - def __or__(self, other): + def __or__(self, other) -> "Q": return Q._construct_from(self, other, self.OR_MODE) - def __and__(self, other): + def __and__(self, other) -> "Q": return Q._construct_from(self, other, self.AND_MODE) - def __invert__(self): + def __invert__(self) -> "Q": q = copy(self) q._negate = True return q @@ -302,14 +307,14 @@ class Q(object): return q -class QuerySet(object): +class QuerySet: """ A queryset is an object that represents a database query using a specific `Model`. It is lazy, meaning that it does not hit the database until you iterate over its matching rows (model instances). 
""" - def __init__(self, model_cls, database): + def __init__(self, model_cls: type[Model], database: Database): """ Initializer. It is possible to create a queryset like this, but the standard way is to use `MyModel.objects_in(database)`. @@ -350,6 +355,9 @@ class QuerySet(object): return self._database.select(self.as_sql(), self._model_cls) async def __aiter__(self): + from clickhouse_orm.aio.database import AioDatabase + + assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'" async for r in self._database.select(self.as_sql(), self._model_cls): yield r @@ -365,6 +373,14 @@ class QuerySet(object): def __str__(self): return self.as_sql() + @overload + def __getitem__(self, s: int) -> Model: + ... + + @overload + def __getitem__(self, s: slice) -> "QuerySet": + ... + def __getitem__(self, s): if isinstance(s, int): # Single index @@ -372,18 +388,17 @@ class QuerySet(object): qs = copy(self) qs._limits = (s, 1) return next(iter(qs)) - else: - # Slice - assert s.step in (None, 1), 'step is not supported in slices' - start = s.start or 0 - stop = s.stop or 2 ** 63 - 1 - assert start >= 0 and stop >= 0, 'negative indexes are not supported' - assert start <= stop, 'start of slice cannot be smaller than its end' - qs = copy(self) - qs._limits = (start, stop - start) - return qs + # Slice + assert s.step in (None, 1), 'step is not supported in slices' + start = s.start or 0 + stop = s.stop or 2 ** 63 - 1 + assert start >= 0 and stop >= 0, 'negative indexes are not supported' + assert start <= stop, 'start of slice cannot be smaller than its end' + qs = copy(self) + qs._limits = (start, stop - start) + return qs - def limit_by(self, offset_limit, *fields_or_expr): + def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet": """ Adds a LIMIT BY clause to the query. - `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit). 
@@ -400,7 +415,7 @@ class QuerySet(object): qs._limit_by_fields = fields_or_expr return qs - def select_fields_as_sql(self): + def select_fields_as_sql(self) -> str: """ Returns the selected fields or expressions as a SQL string. """ @@ -409,7 +424,7 @@ class QuerySet(object): fields = comma_join('`%s`' % field for field in self._fields) return fields - def as_sql(self): + def as_sql(self) -> str: """ Returns the whole query as a SQL string. """ @@ -419,7 +434,7 @@ class QuerySet(object): if self._model_cls.is_system_model(): table_name = '`system`.' + table_name params = (distinct, self.select_fields_as_sql(), table_name, final) - sql = u'SELECT %s%s\nFROM %s%s' % params + sql = 'SELECT %s%s\nFROM %s%s' % params if self._prewhere_q and not self._prewhere_q.is_empty: sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True) @@ -445,7 +460,7 @@ class QuerySet(object): return sql - def order_by_as_sql(self): + def order_by_as_sql(self) -> str: """ Returns the contents of the query's `ORDER BY` clause as a string. """ @@ -454,20 +469,20 @@ class QuerySet(object): for field in self._order_by ]) - def conditions_as_sql(self, prewhere=False): + def conditions_as_sql(self, prewhere=False) -> str: """ Returns the contents of the query's `WHERE` or `PREWHERE` clause as a string. """ q_object = self._prewhere_q if prewhere else self._where_q return q_object.to_sql(self._model_cls) - def count(self): + def count(self) -> Union[int, Coroutine[int]]: """ Returns the number of matching model instances. 
""" if self._distinct or self._limits: # Use a subquery, since a simple count won't be accurate - sql = u'SELECT count() FROM (%s)' % self.as_sql() + sql = 'SELECT count() FROM (%s)' % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 @@ -475,7 +490,7 @@ class QuerySet(object): conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) return self._database.count(self._model_cls, conditions) - def order_by(self, *field_names): + def order_by(self, *field_names) -> "QuerySet": """ Returns a copy of this queryset with the ordering changed. """ @@ -483,7 +498,7 @@ class QuerySet(object): qs._order_by = field_names return qs - def only(self, *field_names): + def only(self, *field_names) -> "QuerySet": """ Returns a copy of this queryset limited to the specified field names. Useful when there are large fields that are not needed, @@ -493,8 +508,8 @@ class QuerySet(object): qs._fields = field_names return qs - def _filter_or_exclude(self, *q, **kwargs): - from .funcs import F + def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet": + from clickhouse_orm.funcs import F inverse = kwargs.pop('_inverse', False) prewhere = kwargs.pop('prewhere', False) @@ -524,21 +539,21 @@ class QuerySet(object): return qs - def filter(self, *q, **kwargs): + def filter(self, *q, **kwargs) -> "QuerySet": """ Returns a copy of this queryset that includes only rows matching the conditions. Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE. """ return self._filter_or_exclude(*q, **kwargs) - def exclude(self, *q, **kwargs): + def exclude(self, *q, **kwargs) -> "QuerySet": """ Returns a copy of this queryset that excludes all rows matching the conditions. Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE. 
""" return self._filter_or_exclude(*q, _inverse=True, **kwargs) - def paginate(self, page_num=1, page_size=100): + def paginate(self, page_num=1, page_size=100) -> Page: """ Returns a single page of model instances that match the queryset. Note that `order_by` should be used first, to ensure a correct @@ -550,7 +565,8 @@ class QuerySet(object): The result is a namedtuple containing `objects` (list), `number_of_objects`, `pages_total`, `number` (of the current page), and `page_size`. """ - from .database import Page + from clickhouse_orm.database import Page + count = self.count() pages_total = int(ceil(count / float(page_size))) if page_num == -1: @@ -566,7 +582,7 @@ class QuerySet(object): page_size=page_size ) - def distinct(self): + def distinct(self) -> "QuerySet": """ Adds a DISTINCT clause to the query, meaning that any duplicate rows in the results will be omitted. @@ -575,12 +591,13 @@ class QuerySet(object): qs._distinct = True return qs - def final(self): + def final(self) -> "QuerySet": """ Adds a FINAL modifier to table, meaning data will be collapsed to final version. Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only. """ - from .engines import CollapsingMergeTree, ReplacingMergeTree + from clickhouse_orm.engines import CollapsingMergeTree, ReplacingMergeTree + if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): raise TypeError( 'final() method can be used only with the CollapsingMergeTree' @@ -591,7 +608,7 @@ class QuerySet(object): qs._final = True return qs - def delete(self): + def delete(self) -> "QuerySet": """ Deletes all records matched by this queryset's conditions. Note that ClickHouse performs deletions in the background, so they are not immediate. @@ -602,7 +619,7 @@ class QuerySet(object): self._database.raw(sql) return self - def update(self, **kwargs): + def update(self, **kwargs) -> "QuerySet": """ Updates all records matched by this queryset's conditions. 
Keyword arguments specify the field names and expressions to use for the update. @@ -627,7 +644,7 @@ class QuerySet(object): assert not self._distinct, 'Mutations are not allowed after calling distinct()' assert not self._final, 'Mutations are not allowed after calling final()' - def aggregate(self, *args, **kwargs): + def aggregate(self, *args, **kwargs) -> "AggregateQuerySet": """ Returns an `AggregateQuerySet` over this query, with `args` serving as grouping fields and `kwargs` serving as calculated fields. At least one @@ -650,7 +667,12 @@ class AggregateQuerySet(QuerySet): A queryset used for aggregation. """ - def __init__(self, base_qs, grouping_fields, calculated_fields): + def __init__( + self, + base_queryset: QuerySet, + grouping_fields: tuple[Any], + calculated_fields: dict[str, str] + ): """ Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`. @@ -658,24 +680,26 @@ class AggregateQuerySet(QuerySet): ``` ('event_type', 'event_subtype') ``` - The calculated fields should be a mapping from name to a ClickHouse aggregation function. For example: + The calculated fields should be a mapping from name to a ClickHouse aggregation function. + + For example: ``` {'weekday': 'toDayOfWeek(event_date)', 'number_of_events': 'count()'} ``` At least one calculated field is required. 
""" - super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database) + super().__init__(base_queryset._model_cls, base_queryset._database) assert calculated_fields, 'No calculated fields specified for aggregation' self._fields = grouping_fields self._grouping_fields = grouping_fields self._calculated_fields = calculated_fields - self._order_by = list(base_qs._order_by) - self._where_q = base_qs._where_q - self._prewhere_q = base_qs._prewhere_q - self._limits = base_qs._limits - self._distinct = base_qs._distinct + self._order_by = list(base_queryset._order_by) + self._where_q = base_queryset._where_q + self._prewhere_q = base_queryset._prewhere_q + self._limits = base_queryset._limits + self._distinct = base_queryset._distinct - def group_by(self, *args): + def group_by(self, *args) -> "AggregateQuerySet": """ This method lets you specify the grouping fields explicitly. The `args` must be names of grouping fields or calculated fields that this queryset was @@ -700,7 +724,7 @@ class AggregateQuerySet(QuerySet): """ raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') - def select_fields_as_sql(self): + def select_fields_as_sql(self) -> str: """ Returns the selected fields or expressions as a SQL string. """ @@ -710,15 +734,17 @@ class AggregateQuerySet(QuerySet): def __iter__(self): return self._database.select(self.as_sql()) # using an ad-hoc model - def count(self): + def count(self) -> Union[int, Coroutine[int]]: """ Returns the number of rows after aggregation. """ - sql = u'SELECT count() FROM (%s)' % self.as_sql() + sql = 'SELECT count() FROM (%s)' % self.as_sql() raw = self._database.raw(sql) + if isinstance(raw, CoroutineType): + return raw return int(raw) if raw else 0 - def with_totals(self): + def with_totals(self) -> "AggregateQuerySet": """ Adds WITH TOTALS modifier ot GROUP BY, making query return extra row with aggregate function calculated across all the rows. 
More information: diff --git a/src/clickhouse_orm/session.py b/src/clickhouse_orm/session.py index ccd99bf..ef5be08 100644 --- a/src/clickhouse_orm/session.py +++ b/src/clickhouse_orm/session.py @@ -1,17 +1,17 @@ import uuid from typing import Optional -from contextvars import ContextVar +from contextvars import ContextVar, Token ctx_session_id: ContextVar[str] = ContextVar('ck.session_id') -ctx_session_timeout: ContextVar[int] = ContextVar('ck.session_timeout') +ctx_session_timeout: ContextVar[float] = ContextVar('ck.session_timeout') class SessionContext: - def __init__(self, session: str, timeout: int): + def __init__(self, session: str, timeout: float): self.session = session self.timeout = timeout - self.token1 = None - self.token2 = None + self.token1: Optional[Token[str]] = None + self.token2: Optional[Token[float]] = None def __enter__(self) -> str: self.token1 = ctx_session_id.set(self.session) diff --git a/src/clickhouse_orm/utils.py b/src/clickhouse_orm/utils.py index 27cf09e..8f97eb8 100644 --- a/src/clickhouse_orm/utils.py +++ b/src/clickhouse_orm/utils.py @@ -114,7 +114,7 @@ def parse_array(array_string): array_string = array_string[match.end():] else: # Start of non-quoted value, find its end - match = re.search(r",|\]", array_string) + match = re.search(r",|\]|\)", array_string) values.append(array_string[0: match.start()]) array_string = array_string[match.end() - 1:]