diff --git a/pyproject.toml b/pyproject.toml index d1353d9..0b00f14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,7 @@ dependencies = [ "iso8601 >= 0.1.12", "setuptools" ] -version = "0.0.1" +version = "0.0.2" [tool.setuptools.packages.find] where = ["src"] - -[project.optional-dependencies] -pkg = ["setuptools", "requests", "pytz", "iso8601>=0.1.12"] diff --git a/src/clickhouse_orm/contrib/__init__.py b/src/clickhouse_orm/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/clickhouse_orm/contrib/geo/__init__.py b/src/clickhouse_orm/contrib/geo/__init__.py new file mode 100644 index 0000000..fa03948 --- /dev/null +++ b/src/clickhouse_orm/contrib/geo/__init__.py @@ -0,0 +1 @@ +from .fields import PointField, Point diff --git a/src/clickhouse_orm/contrib/geo/fields.py b/src/clickhouse_orm/contrib/geo/fields.py new file mode 100644 index 0000000..6c1bb87 --- /dev/null +++ b/src/clickhouse_orm/contrib/geo/fields.py @@ -0,0 +1,97 @@ +from clickhouse_orm.fields import Field, Float64Field +from clickhouse_orm.utils import POINT_REGEX, RING_VALID_REGEX + + +class Point: + def __init__(self, x, y): + self.x = float(x) + self.y = float(y) + + def __repr__(self): + return f'' + + def to_db_string(self): + return f'({self.x},{self.y})' + + +class Ring: + def __init__(self, points): + self.array = points + + @property + def size(self): + return len(self.array) + + def __len__(self): + return len(self.array) + + def __repr__(self): + return f'' + + def to_db_string(self): + return f'[{",".join(pt.to_db_string() for pt in self.array)}]' + + +def parse_point(array_string: str) -> Point: + if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')': + raise ValueError('Invalid point string: "%s"' % array_string) + x, y = array_string.strip('()').split(',') + return Point(x, y) + + +def parse_ring(array_string: str) -> Ring: + if not RING_VALID_REGEX.match(array_string): + raise ValueError('Invalid ring string: "%s"' % array_string) + ring = [] + for point in POINT_REGEX.finditer(array_string): + x, y = point.group('x'), point.group('y') + ring.append(Point(x, y)) + return Ring(ring) + + +class PointField(Field): + class_default = Point(0, 0) + db_type = 'Point' + + def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): + super().__init__(default, alias, materialized, readonly, codec) + self.inner_field = Float64Field() + + def to_python(self, value, timezone_in_use): + if isinstance(value, str): + value = parse_point(value) + elif isinstance(value, (tuple, list)): + if len(value) != 2: + raise ValueError('PointField takes 2 value, but %s were given' % len(value)) + value = Point(value[0], value[1]) + if not isinstance(value, Point): + raise ValueError('PointField expects list or tuple and Point, not %s' % type(value)) + return value + + def validate(self, value): + pass + + def to_db_string(self, value, quote=True): + return value.to_db_string() + + +class RingField(Field): + class_default = [Point(0, 0)] + db_type = 'Ring' + + def to_python(self, value, timezone_in_use): + if isinstance(value, str): + value = parse_ring(value) + elif isinstance(value, (tuple, list)): + ring = [] + for point in value: + if len(point) != 2: + raise ValueError('Point takes 2 value, but %s were given' % len(value)) + ring.append(Point(point[0], point[1])) + value = Ring(ring) + if not isinstance(value, Ring): + raise ValueError('PointField expects list or tuple and Point, not %s' % type(value)) + return value + + def to_db_string(self, value, quote=True): + return value.to_db_string() diff --git a/src/clickhouse_orm/database.py b/src/clickhouse_orm/database.py index 703e982..e4396cc 100644 --- a/src/clickhouse_orm/database.py +++ b/src/clickhouse_orm/database.py @@ -254,7 +254,7 @@ class Database(object): - `model_class`: the model to count. - `conditions`: optional SQL conditions (contents of the WHERE clause). ''' - from infi.clickhouse_orm.query import Q + from clickhouse_orm.query import Q query = 'SELECT count() FROM $table' if conditions: if isinstance(conditions, Q): @@ -311,7 +311,7 @@ class Database(object): The result is a namedtuple containing `objects` (list), `number_of_objects`, `pages_total`, `number` (of the current page), and `page_size`. ''' - from infi.clickhouse_orm.query import Q + from clickhouse_orm.query import Q count = self.count(model_class, conditions) pages_total = int(ceil(count / float(page_size))) if page_num == -1: diff --git a/src/clickhouse_orm/engines.py b/src/clickhouse_orm/engines.py index 7fb83be..285a848 100644 --- a/src/clickhouse_orm/engines.py +++ b/src/clickhouse_orm/engines.py @@ -91,7 +91,7 @@ class MergeTree(Engine): elif not self.date_col: # Can't import it globally due to circular import - from infi.clickhouse_orm.database import DatabaseException + from clickhouse_orm.database import DatabaseException raise DatabaseException("Custom partitioning is not supported before ClickHouse 1.1.54310. " "Please update your server or use date_col syntax." "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/") diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py index 6e73e3f..2807e2c 100644 --- a/src/clickhouse_orm/fields.py +++ b/src/clickhouse_orm/fields.py @@ -1,15 +1,17 @@ from __future__ import unicode_literals +from calendar import timegm import datetime +from decimal import Decimal, localcontext +from logging import getLogger +from ipaddress import IPv4Address, IPv6Address +from uuid import UUID + import iso8601 import pytz -from calendar import timegm -from decimal import Decimal, localcontext -from uuid import UUID -from logging import getLogger from pytz import BaseTzInfo + from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names from .funcs import F, FunctionOperatorsMixin -from ipaddress import IPv4Address, IPv6Address logger = getLogger('clickhouse_orm') diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py index e3f95e3..0fc5e7f 100644 --- a/src/clickhouse_orm/models.py +++ b/src/clickhouse_orm/models.py @@ -200,7 +200,7 @@ class ModelBase(type): @classmethod def create_ad_hoc_field(cls, db_type): - import infi.clickhouse_orm.fields as orm_fields + import clickhouse_orm.fields as orm_fields # Enums if db_type.startswith('Enum'): return orm_fields.BaseEnumField.create_ad_hoc_field(db_type) diff --git a/src/clickhouse_orm/query.py b/src/clickhouse_orm/query.py index 897d45a..3074cb2 100644 --- a/src/clickhouse_orm/query.py +++ b/src/clickhouse_orm/query.py @@ -20,10 +20,10 @@ class Operator(object): Subclasses should implement this method. It returns an SQL string that applies this operator on the given field and value. """ - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _value_to_sql(self, field, value, quote=True): - from infi.clickhouse_orm.funcs import F + from clickhouse_orm.funcs import F if isinstance(value, F): return value.to_sql() return field.to_db_string(field.to_python(value, pytz.utc), quote) @@ -123,8 +123,10 @@ class BetweenOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) - value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None - value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None + value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len( + str(value[0])) > 0 else None + value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len( + str(value[1])) > 0 else None if value0 and value1: return '%s BETWEEN %s AND %s' % (field_name, value0, value1) if value0 and not value1: @@ -132,29 +134,32 @@ class BetweenOperator(Operator): if value1 and not value0: return ' '.join([field_name, '<=', value1]) + # Define the set of builtin operators _operators = {} + def register_operator(name, sql): _operators[name] = sql -register_operator('eq', SimpleOperator('=', 'IS NULL')) -register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) -register_operator('gt', SimpleOperator('>')) -register_operator('gte', SimpleOperator('>=')) -register_operator('lt', SimpleOperator('<')) -register_operator('lte', SimpleOperator('<=')) -register_operator('between', BetweenOperator()) -register_operator('in', InOperator()) -register_operator('not_in', NotOperator(InOperator())) -register_operator('contains', LikeOperator('%{}%')) -register_operator('startswith', LikeOperator('{}%')) -register_operator('endswith', LikeOperator('%{}')) -register_operator('icontains', LikeOperator('%{}%', False)) + +register_operator('eq', SimpleOperator('=', 'IS NULL')) +register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) +register_operator('gt', SimpleOperator('>')) +register_operator('gte', SimpleOperator('>=')) +register_operator('lt', SimpleOperator('<')) +register_operator('lte', SimpleOperator('<=')) +register_operator('between', BetweenOperator()) +register_operator('in', InOperator()) +register_operator('not_in', NotOperator(InOperator())) +register_operator('contains', LikeOperator('%{}%')) +register_operator('startswith', LikeOperator('{}%')) +register_operator('endswith', LikeOperator('%{}')) +register_operator('icontains', LikeOperator('%{}%', False)) register_operator('istartswith', LikeOperator('{}%', False)) -register_operator('iendswith', LikeOperator('%{}', False)) -register_operator('iexact', IExactOperator()) +register_operator('iendswith', LikeOperator('%{}', False)) +register_operator('iexact', IExactOperator()) class Cond(object): @@ -170,6 +175,7 @@ class FieldCond(Cond): """ A single query condition made up of Field + Operator + Value. """ + def __init__(self, field_name, operator, value): self._field_name = field_name self._operator = _operators.get(operator) @@ -189,12 +195,12 @@ class FieldCond(Cond): class Q(object): - AND_MODE = 'AND' OR_MODE = 'OR' def __init__(self, *filter_funcs, **filter_fields): - self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()] + self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in + filter_fields.items()] self._children = [] self._negate = False self._mode = self.AND_MODE @@ -318,7 +324,7 @@ class QuerySet(object): """ return bool(self.count()) - def __nonzero__(self): # Python 2 compatibility + def __nonzero__(self): # Python 2 compatibility return type(self).__bool__(self) def __str__(self): @@ -335,7 +341,7 @@ class QuerySet(object): # Slice assert s.step in (None, 1), 'step is not supported in slices' start = s.start or 0 - stop = s.stop or 2**63 - 1 + stop = s.stop or 2 ** 63 - 1 assert start >= 0 and stop >= 0, 'negative indexes are not supported' assert start <= stop, 'start of slice cannot be smaller than its end' qs = copy(self) @@ -518,7 +524,7 @@ class QuerySet(object): raise ValueError('Invalid page number: %d' % page_num) offset = (page_num - 1) * page_size return Page( - objects=list(self[offset : offset + page_size]), + objects=list(self[offset: offset + page_size]), number_of_objects=count, pages_total=pages_total, number=page_num, @@ -541,7 +547,8 @@ class QuerySet(object): """ from .engines import CollapsingMergeTree, ReplacingMergeTree if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): - raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') + raise TypeError( + 'final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') qs = copy(self) qs._final = True @@ -568,7 +575,8 @@ class QuerySet(object): self._verify_mutation_allowed() fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) - sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions) + sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % ( + self._model_cls.table_name(), fields, conditions) self._database.raw(sql) return self @@ -637,7 +645,7 @@ class AggregateQuerySet(QuerySet): """ for name in args: assert name in self._fields or name in self._calculated_fields, \ - 'Cannot group by `%s` since it is not included in the query' % name + 'Cannot group by `%s` since it is not included in the query' % name qs = copy(self) qs._grouping_fields = args return qs @@ -658,10 +666,11 @@ class AggregateQuerySet(QuerySet): """ Returns the selected fields or expressions as a SQL string. """ - return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) + return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in + self._calculated_fields.items()]) def __iter__(self): - return self._database.select(self.as_sql()) # using an ad-hoc model + return self._database.select(self.as_sql()) # using an ad-hoc model def count(self): """ diff --git a/src/clickhouse_orm/utils.py b/src/clickhouse_orm/utils.py index c0d0325..27cf09e 100644 --- a/src/clickhouse_orm/utils.py +++ b/src/clickhouse_orm/utils.py @@ -2,28 +2,29 @@ import codecs import re from datetime import date, datetime, tzinfo, timedelta - SPECIAL_CHARS = { - "\b" : "\\b", - "\f" : "\\f", - "\r" : "\\r", - "\n" : "\\n", - "\t" : "\\t", - "\0" : "\\0", - "\\" : "\\\\", - "'" : "\\'" + "\b": "\\b", + "\f": "\\f", + "\r": "\\r", + "\n": "\\n", + "\t": "\\t", + "\0": "\\0", + "\\": "\\\\", + "'": "\\'" } SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]") - +POINT_REGEX = re.compile(r"\((?P\d+(\.\d+)?),(?P\d+(\.\d+)?)\)") +RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]") def escape(value, quote=True): - ''' + """ If the value is a string, escapes any special characters and optionally surrounds it with single quotes. If the value is not a string (e.g. a number), converts it to one. - ''' + """ + def escape_one(match): return SPECIAL_CHARS[match.group(0)] @@ -48,7 +49,7 @@ def arg_to_sql(arg): Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans, None, numbers, timezones, arrays/iterables. """ - from infi.clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet + from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet if isinstance(arg, F): return arg.to_sql() if isinstance(arg, Field): @@ -109,12 +110,12 @@ def parse_array(array_string): match = re.search(r"[^\\]'", array_string) if match is None: raise ValueError('Missing closing quote: "%s"' % array_string) - values.append(array_string[1 : match.start() + 1]) + values.append(array_string[1: match.start() + 1]) array_string = array_string[match.end():] else: # Start of non-quoted value, find its end match = re.search(r",|\]", array_string) - values.append(array_string[0 : match.start()]) + values.append(array_string[0: match.start()]) array_string = array_string[match.end() - 1:] @@ -157,11 +158,13 @@ def get_subclass_names(locals, base_class): class NoValue: - ''' + """ A sentinel for fields with an expression for a default value, that were not assigned a value yet. - ''' + """ + def __repr__(self): return 'NO_VALUE' + NO_VALUE = NoValue()