Support PointField/RingField

This commit is contained in:
sswest 2022-05-22 15:58:59 +08:00
parent 4f157e1857
commit 17a5c30bfd
10 changed files with 168 additions and 59 deletions

View File

@ -20,10 +20,7 @@ dependencies = [
"iso8601 >= 0.1.12",
"setuptools"
]
version = "0.0.1"
version = "0.0.2"
[tool.setuptools.packages.find]
where = ["src"]
[project.optional-dependencies]
pkg = ["setuptools", "requests", "pytz", "iso8601>=0.1.12"]

View File

View File

@ -0,0 +1 @@
from .fields import PointField, Point

View File

@ -0,0 +1,97 @@
from clickhouse_orm.fields import Field, Float64Field
from clickhouse_orm.utils import POINT_REGEX, RING_VALID_REGEX
class Point:
    """A pair of coordinates stored as floats.

    Both constructor arguments are passed through ``float()``, so anything
    ``float()`` accepts (numbers, numeric strings) is allowed.
    """

    def __init__(self, x, y):
        self.x, self.y = float(x), float(y)

    def __repr__(self):
        return f'<Point x={self.x} y={self.y}>'

    def to_db_string(self):
        """Return the point rendered as ``(x,y)``."""
        return f'({self.x},{self.y})'
class Ring:
    """An ordered collection of points, rendered as ``[(x,y),(x,y),...]``.

    The sequence passed to the constructor is stored as-is in ``array``;
    each element must provide a ``to_db_string()`` method.
    """

    def __init__(self, points):
        self.array = points

    def __len__(self):
        return len(self.array)

    def __repr__(self):
        return f'<Ring {self.to_db_string()}>'

    @property
    def size(self):
        """Number of points in the ring."""
        return len(self.array)

    def to_db_string(self):
        """Return the comma-joined point literals wrapped in square brackets."""
        rendered_points = [member.to_db_string() for member in self.array]
        joined = ",".join(rendered_points)
        return f'[{joined}]'
def parse_point(array_string: str) -> "Point":
    """Parse a ``(x,y)`` literal into a :class:`Point`.

    Args:
        array_string: text of the form ``(x,y)`` where both components are
            accepted by ``float()``.

    Returns:
        A ``Point`` built from the two parsed components.

    Raises:
        ValueError: if the string is not wrapped in exactly one pair of
            parentheses, does not contain exactly two comma-separated
            components, or a component is not a valid float.
    """
    if len(array_string) < 2 or array_string[0] != '(' or array_string[-1] != ')':
        raise ValueError('Invalid point string: "%s"' % array_string)
    # Remove exactly one pair of parentheses. ``strip('()')`` would also eat
    # nested ones and silently accept malformed input such as "((1,2))".
    components = array_string[1:-1].split(',')
    if len(components) != 2:
        raise ValueError('Invalid point string: "%s"' % array_string)
    try:
        x, y = float(components[0]), float(components[1])
    except ValueError:
        # Report the whole offending literal rather than a single component.
        raise ValueError('Invalid point string: "%s"' % array_string) from None
    return Point(x, y)
def parse_ring(array_string: str) -> Ring:
    """Parse a ``[(x,y),(x,y),...]`` literal into a :class:`Ring`.

    Args:
        array_string: text form of a ring. The empty literal ``[]`` is
            accepted and yields an empty ring, matching what
            ``Ring([]).to_db_string()`` produces.

    Returns:
        A ``Ring`` holding one ``Point`` per ``(x,y)`` pair, in input order.

    Raises:
        ValueError: if the whole string does not match ``RING_VALID_REGEX``.
    """
    if array_string == '[]':
        return Ring([])
    # fullmatch (not match) so that trailing garbage after the closing
    # bracket is rejected instead of being silently ignored.
    if not RING_VALID_REGEX.fullmatch(array_string):
        raise ValueError('Invalid ring string: "%s"' % array_string)
    return Ring([
        Point(found.group('x'), found.group('y'))
        for found in POINT_REGEX.finditer(array_string)
    ])
class PointField(Field):
    """Model field with database type ``Point`` holding :class:`Point` values."""

    class_default = Point(0, 0)
    db_type = 'Point'

    def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
        super().__init__(default, alias, materialized, readonly, codec)
        self.inner_field = Float64Field()

    def to_python(self, value, timezone_in_use):
        """Convert a string, 2-item tuple/list or ``Point`` into a ``Point``.

        Raises ``ValueError`` for a tuple/list whose length is not 2 and for
        any other unsupported type. ``timezone_in_use`` is not used.
        """
        if isinstance(value, str):
            return parse_point(value)
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('PointField takes 2 value, but %s were given' % len(value))
            first, second = value
            return Point(first, second)
        if isinstance(value, Point):
            return value
        raise ValueError('PointField expects list or tuple and Point, not %s' % type(value))

    def validate(self, value):
        """No additional validation is performed."""
        pass

    def to_db_string(self, value, quote=True):
        """Delegate rendering to the value itself; ``quote`` is not used."""
        return value.to_db_string()
class RingField(Field):
    """Model field with database type ``Ring`` holding :class:`Ring` values."""

    # The default must be a Ring, not a bare list: to_db_string() calls
    # value.to_db_string(), which a plain list does not provide.
    class_default = Ring([Point(0, 0)])
    db_type = 'Ring'

    @staticmethod
    def _to_point(item):
        """Turn one ring element (a ``Point`` or a 2-item sequence) into a ``Point``."""
        if isinstance(item, Point):
            return item
        if len(item) != 2:
            # Report the size of the offending element, not of the whole ring.
            raise ValueError('Point takes 2 value, but %s were given' % len(item))
        return Point(item[0], item[1])

    def to_python(self, value, timezone_in_use):
        """Convert a string, a tuple/list of points, or a ``Ring`` into a ``Ring``.

        Elements of a tuple/list may be ``Point`` instances or 2-item
        sequences. ``timezone_in_use`` is not used.

        Raises:
            ValueError: if an element does not have exactly 2 items, if a
                string cannot be parsed, or if ``value`` has an unsupported type.
        """
        if isinstance(value, str):
            value = parse_ring(value)
        elif isinstance(value, (tuple, list)):
            value = Ring([self._to_point(item) for item in value])
        if not isinstance(value, Ring):
            raise ValueError('RingField expects list or tuple and Ring, not %s' % type(value))
        return value

    def to_db_string(self, value, quote=True):
        """Delegate rendering to the value itself; ``quote`` is not used."""
        return value.to_db_string()

View File

@ -254,7 +254,7 @@ class Database(object):
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
'''
from infi.clickhouse_orm.query import Q
from clickhouse_orm.query import Q
query = 'SELECT count() FROM $table'
if conditions:
if isinstance(conditions, Q):
@ -311,7 +311,7 @@ class Database(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
'''
from infi.clickhouse_orm.query import Q
from clickhouse_orm.query import Q
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:

View File

@ -91,7 +91,7 @@ class MergeTree(Engine):
elif not self.date_col:
# Can't import it globally due to circular import
from infi.clickhouse_orm.database import DatabaseException
from clickhouse_orm.database import DatabaseException
raise DatabaseException("Custom partitioning is not supported before ClickHouse 1.1.54310. "
"Please update your server or use date_col syntax."
"https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/")

View File

@ -1,15 +1,17 @@
from __future__ import unicode_literals
from calendar import timegm
import datetime
from decimal import Decimal, localcontext
from logging import getLogger
from ipaddress import IPv4Address, IPv6Address
from uuid import UUID
import iso8601
import pytz
from calendar import timegm
from decimal import Decimal, localcontext
from uuid import UUID
from logging import getLogger
from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .funcs import F, FunctionOperatorsMixin
from ipaddress import IPv4Address, IPv6Address
logger = getLogger('clickhouse_orm')

View File

@ -200,7 +200,7 @@ class ModelBase(type):
@classmethod
def create_ad_hoc_field(cls, db_type):
import infi.clickhouse_orm.fields as orm_fields
import clickhouse_orm.fields as orm_fields
# Enums
if db_type.startswith('Enum'):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)

View File

@ -23,7 +23,7 @@ class Operator(object):
raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True):
from infi.clickhouse_orm.funcs import F
from clickhouse_orm.funcs import F
if isinstance(value, F):
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -123,8 +123,10 @@ class BetweenOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(
str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(
str(value[1])) > 0 else None
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
if value0 and not value1:
@ -132,13 +134,16 @@ class BetweenOperator(Operator):
if value1 and not value0:
return ' '.join([field_name, '<=', value1])
# Define the set of builtin operators
_operators = {}
def register_operator(name, sql):
_operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
register_operator('gt', SimpleOperator('>'))
@ -170,6 +175,7 @@ class FieldCond(Cond):
"""
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value):
self._field_name = field_name
self._operator = _operators.get(operator)
@ -189,12 +195,12 @@ class FieldCond(Cond):
class Q(object):
AND_MODE = 'AND'
OR_MODE = 'OR'
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in
filter_fields.items()]
self._children = []
self._negate = False
self._mode = self.AND_MODE
@ -335,7 +341,7 @@ class QuerySet(object):
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
start = s.start or 0
stop = s.stop or 2**63 - 1
stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
assert start <= stop, 'start of slice cannot be smaller than its end'
qs = copy(self)
@ -518,7 +524,7 @@ class QuerySet(object):
raise ValueError('Invalid page number: %d' % page_num)
offset = (page_num - 1) * page_size
return Page(
objects=list(self[offset : offset + page_size]),
objects=list(self[offset: offset + page_size]),
number_of_objects=count,
pages_total=pages_total,
number=page_num,
@ -541,7 +547,8 @@ class QuerySet(object):
"""
from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines')
raise TypeError(
'final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines')
qs = copy(self)
qs._final = True
@ -568,7 +575,8 @@ class QuerySet(object):
self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (
self._model_cls.table_name(), fields, conditions)
self._database.raw(sql)
return self
@ -658,7 +666,8 @@ class AggregateQuerySet(QuerySet):
"""
Returns the selected fields or expressions as a SQL string.
"""
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()])
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
self._calculated_fields.items()])
def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model

View File

@ -2,28 +2,29 @@ import codecs
import re
from datetime import date, datetime, tzinfo, timedelta
SPECIAL_CHARS = {
"\b" : "\\b",
"\f" : "\\f",
"\r" : "\\r",
"\n" : "\\n",
"\t" : "\\t",
"\0" : "\\0",
"\\" : "\\\\",
"'" : "\\'"
"\b": "\\b",
"\f": "\\f",
"\r": "\\r",
"\n": "\\n",
"\t": "\\t",
"\0": "\\0",
"\\": "\\\\",
"'": "\\'"
}
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
# Text form of a single float coordinate: optional sign followed by either a
# decimal number with optional exponent, or the special values inf / nan.
# Signed and exponent forms are needed because Python's float formatting
# produces them (e.g. str(-1.5) == '-1.5', str(1e-07) == '1e-07').
_FLOAT_PATTERN = r"[-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|inf|nan)"
# Matches one "(x,y)" pair and captures both numbers as groups "x" and "y".
POINT_REGEX = re.compile(r"\((?P<x>%s),(?P<y>%s)\)" % (_FLOAT_PATTERN, _FLOAT_PATTERN))
# Matches a bracketed, comma-separated list of one or more "(x,y)" pairs.
RING_VALID_REGEX = re.compile(r"\[(?:\(%s,%s\),)*\(%s,%s\)\]" % ((_FLOAT_PATTERN,) * 4))
def escape(value, quote=True):
'''
"""
If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one.
'''
"""
def escape_one(match):
return SPECIAL_CHARS[match.group(0)]
@ -48,7 +49,7 @@ def arg_to_sql(arg):
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
None, numbers, timezones, arrays/iterables.
"""
from infi.clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
from clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
if isinstance(arg, F):
return arg.to_sql()
if isinstance(arg, Field):
@ -109,12 +110,12 @@ def parse_array(array_string):
match = re.search(r"[^\\]'", array_string)
if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1 : match.start() + 1])
values.append(array_string[1: match.start() + 1])
array_string = array_string[match.end():]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]", array_string)
values.append(array_string[0 : match.start()])
values.append(array_string[0: match.start()])
array_string = array_string[match.end() - 1:]
@ -157,11 +158,13 @@ def get_subclass_names(locals, base_class):
class NoValue:
'''
"""
A sentinel for fields with an expression for a default value,
that were not assigned a value yet.
'''
"""
def __repr__(self):
return 'NO_VALUE'
NO_VALUE = NoValue()