diff --git a/clickhouse_orm/query.py b/clickhouse_orm/query.py index 675c98a..d9be1be 100644 --- a/clickhouse_orm/query.py +++ b/clickhouse_orm/query.py @@ -3,13 +3,13 @@ from __future__ import unicode_literals import pytz from copy import copy, deepcopy from math import ceil -from datetime import date, datetime from .utils import comma_join, string_or_func, arg_to_sql # TODO # - check that field names are valid + class Operator(object): """ Base class for filtering operators. @@ -20,10 +20,11 @@ class Operator(object): Subclasses should implement this method. It returns an SQL string that applies this operator on the given field and value. """ - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def _value_to_sql(self, field, value, quote=True): from clickhouse_orm.funcs import F + if isinstance(value, F): return value.to_sql() return field.to_db_string(field.to_python(value, pytz.utc), quote) @@ -41,9 +42,9 @@ class SimpleOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) - if value == '\\N' and self._sql_for_null is not None: - return ' '.join([field_name, self._sql_for_null]) - return ' '.join([field_name, self._sql_operator, value]) + if value == "\\N" and self._sql_for_null is not None: + return " ".join([field_name, self._sql_for_null]) + return " ".join([field_name, self._sql_operator, value]) class InOperator(Operator): @@ -63,7 +64,7 @@ class InOperator(Operator): pass else: value = comma_join([self._value_to_sql(field, v) for v in value]) - return '%s IN (%s)' % (field_name, value) + return "%s IN (%s)" % (field_name, value) class LikeOperator(Operator): @@ -79,12 +80,12 @@ class LikeOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) value = self._value_to_sql(field, value, quote=False) - value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') + value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_") pattern = self._pattern.format(value) if self._case_sensitive: - return '%s LIKE \'%s\'' % (field_name, pattern) + return "%s LIKE '%s'" % (field_name, pattern) else: - return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern) + return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field_name, pattern) class IExactOperator(Operator): @@ -95,7 +96,7 @@ class IExactOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) value = self._value_to_sql(field, value) - return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) + return "lowerUTF8(%s) = lowerUTF8(%s)" % (field_name, value) class NotOperator(Operator): @@ -108,7 +109,7 @@ class NotOperator(Operator): def to_sql(self, model_cls, field_name, value): # Negate the base operator - return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value) + return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value) class BetweenOperator(Operator): @@ -126,35 +127,38 @@ class BetweenOperator(Operator): value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None if value0 and value1: - return '%s BETWEEN %s AND %s' % (field_name, value0, value1) + return "%s BETWEEN %s AND %s" % (field_name, value0, value1) if value0 and not value1: - return ' '.join([field_name, '>=', value0]) + return " ".join([field_name, ">=", value0]) if value1 and not value0: - return ' '.join([field_name, '<=', value1]) + return " ".join([field_name, "<=", value1]) + # Define the set of builtin operators _operators = {} + def register_operator(name, sql): _operators[name] = sql -register_operator('eq', SimpleOperator('=', 'IS NULL')) -register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) -register_operator('gt', SimpleOperator('>')) -register_operator('gte', SimpleOperator('>=')) -register_operator('lt', SimpleOperator('<')) -register_operator('lte', SimpleOperator('<=')) -register_operator('between', BetweenOperator()) -register_operator('in', InOperator()) -register_operator('not_in', NotOperator(InOperator())) -register_operator('contains', LikeOperator('%{}%')) -register_operator('startswith', LikeOperator('{}%')) -register_operator('endswith', LikeOperator('%{}')) -register_operator('icontains', LikeOperator('%{}%', False)) -register_operator('istartswith', LikeOperator('{}%', False)) -register_operator('iendswith', LikeOperator('%{}', False)) -register_operator('iexact', IExactOperator()) + +register_operator("eq", SimpleOperator("=", "IS NULL")) +register_operator("ne", SimpleOperator("!=", "IS NOT NULL")) +register_operator("gt", SimpleOperator(">")) +register_operator("gte", SimpleOperator(">=")) +register_operator("lt", SimpleOperator("<")) +register_operator("lte", SimpleOperator("<=")) +register_operator("between", BetweenOperator()) +register_operator("in", InOperator()) +register_operator("not_in", NotOperator(InOperator())) +register_operator("contains", LikeOperator("%{}%")) +register_operator("startswith", LikeOperator("{}%")) +register_operator("endswith", LikeOperator("%{}")) +register_operator("icontains", LikeOperator("%{}%", False)) +register_operator("istartswith", LikeOperator("{}%", False)) +register_operator("iendswith", LikeOperator("%{}", False)) +register_operator("iexact", IExactOperator()) class Cond(object): @@ -170,19 +174,20 @@ class FieldCond(Cond): """ A single query condition made up of Field + Operator + Value. """ + def __init__(self, field_name, operator, value): self._field_name = field_name self._operator = _operators.get(operator) if self._operator is None: # The field name contains __ like my__field - self._field_name = field_name + '__' + operator - self._operator = _operators['eq'] + self._field_name = field_name + "__" + operator + self._operator = _operators["eq"] self._value = value def to_sql(self, model_cls): return self._operator.to_sql(model_cls, self._field_name, self._value) - def __deepcopy__(self, memodict={}): + def __deepcopy__(self, memo): res = copy(self) res._value = deepcopy(self._value) return res @@ -190,8 +195,8 @@ class FieldCond(Cond): class Q(object): - AND_MODE = 'AND' - OR_MODE = 'OR' + AND_MODE = "AND" + OR_MODE = "OR" def __init__(self, *filter_funcs, **filter_fields): self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()] @@ -224,10 +229,10 @@ class Q(object): return q def _build_cond(self, key, value): - if '__' in key: - field_name, operator = key.rsplit('__', 1) + if "__" in key: + field_name, operator = key.rsplit("__", 1) else: - field_name, operator = key, 'eq' + field_name, operator = key, "eq" return FieldCond(field_name, operator, value) def to_sql(self, model_cls): @@ -241,16 +246,16 @@ class Q(object): if not condition_sql: # Empty Q() object returns everything - sql = '1' + sql = "1" elif len(condition_sql) == 1: # Skip not needed brackets over single condition sql = condition_sql[0] else: # Each condition must be enclosed in brackets, or order of operations may be wrong - sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql) + sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql) if self._negate: - sql = 'NOT (%s)' % sql + sql = "NOT (%s)" % sql return sql @@ -268,7 +273,7 @@ class Q(object): def __bool__(self): return not self.is_empty - def __deepcopy__(self, memodict={}): + def __deepcopy__(self, memo): q = Q() q._conds = [deepcopy(cond) for cond in self._conds] q._negate = self._negate @@ -318,7 +323,7 @@ class QuerySet(object): """ return bool(self.count()) - def __nonzero__(self): # Python 2 compatibility + def __nonzero__(self): # Python 2 compatibility return type(self).__bool__(self) def __str__(self): @@ -327,17 +332,17 @@ class QuerySet(object): def __getitem__(self, s): if isinstance(s, int): # Single index - assert s >= 0, 'negative indexes are not supported' + assert s >= 0, "negative indexes are not supported" qs = copy(self) qs._limits = (s, 1) return next(iter(qs)) else: # Slice - assert s.step in (None, 1), 'step is not supported in slices' + assert s.step in (None, 1), "step is not supported in slices" start = s.start or 0 - stop = s.stop or 2**63 - 1 - assert start >= 0 and stop >= 0, 'negative indexes are not supported' - assert start <= stop, 'start of slice cannot be smaller than its end' + stop = s.stop or 2 ** 63 - 1 + assert start >= 0 and stop >= 0, "negative indexes are not supported" + assert start <= stop, "start of slice cannot be smaller than its end" qs = copy(self) qs._limits = (start, stop - start) return qs @@ -353,7 +358,7 @@ class QuerySet(object): offset_limit = (0, offset_limit) offset = offset_limit[0] limit = offset_limit[1] - assert offset >= 0 and limit >= 0, 'negative limits are not supported' + assert offset >= 0 and limit >= 0, "negative limits are not supported" qs = copy(self) qs._limit_by = (offset, limit) qs._limit_by_fields = fields_or_expr @@ -363,44 +368,44 @@ class QuerySet(object): """ Returns the selected fields or expressions as a SQL string. """ - fields = '*' + fields = "*" if self._fields: - fields = comma_join('`%s`' % field for field in self._fields) + fields = comma_join("`%s`" % field for field in self._fields) return fields def as_sql(self): """ Returns the whole query as a SQL string. """ - distinct = 'DISTINCT ' if self._distinct else '' - final = ' FINAL' if self._final else '' - table_name = '`%s`' % self._model_cls.table_name() + distinct = "DISTINCT " if self._distinct else "" + final = " FINAL" if self._final else "" + table_name = "`%s`" % self._model_cls.table_name() if self._model_cls.is_system_model(): - table_name = '`system`.' + table_name + table_name = "`system`." + table_name params = (distinct, self.select_fields_as_sql(), table_name, final) - sql = u'SELECT %s%s\nFROM %s%s' % params + sql = "SELECT %s%s\nFROM %s%s" % params if self._prewhere_q and not self._prewhere_q.is_empty: - sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True) + sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True) if self._where_q and not self._where_q.is_empty: - sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) + sql += "\nWHERE " + self.conditions_as_sql(prewhere=False) if self._grouping_fields: - sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) + sql += "\nGROUP BY %s" % comma_join("`%s`" % field for field in self._grouping_fields) if self._grouping_with_totals: - sql += ' WITH TOTALS' + sql += " WITH TOTALS" if self._order_by: - sql += '\nORDER BY ' + self.order_by_as_sql() + sql += "\nORDER BY " + self.order_by_as_sql() if self._limit_by: - sql += '\nLIMIT %d, %d' % self._limit_by - sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields) + sql += "\nLIMIT %d, %d" % self._limit_by + sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields) if self._limits: - sql += '\nLIMIT %d, %d' % self._limits + sql += "\nLIMIT %d, %d" % self._limits return sql @@ -408,10 +413,12 @@ class QuerySet(object): """ Returns the contents of the query's `ORDER BY` clause as a string. """ - return comma_join([ - '%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field) - for field in self._order_by - ]) + return comma_join( + [ + "%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field) + for field in self._order_by + ] + ) def conditions_as_sql(self, prewhere=False): """ @@ -426,7 +433,7 @@ class QuerySet(object): """ if self._distinct or self._limits: # Use a subquery, since a simple count won't be accurate - sql = u'SELECT count() FROM (%s)' % self.as_sql() + sql = "SELECT count() FROM (%s)" % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 @@ -455,8 +462,8 @@ class QuerySet(object): def _filter_or_exclude(self, *q, **kwargs): from .funcs import F - inverse = kwargs.pop('_inverse', False) - prewhere = kwargs.pop('prewhere', False) + inverse = kwargs.pop("_inverse", False) + prewhere = kwargs.pop("prewhere", False) qs = copy(self) @@ -510,19 +517,20 @@ class QuerySet(object): `pages_total`, `number` (of the current page), and `page_size`. """ from .database import Page + count = self.count() pages_total = int(ceil(count / float(page_size))) if page_num == -1: page_num = pages_total elif page_num < 1: - raise ValueError('Invalid page number: %d' % page_num) + raise ValueError("Invalid page number: %d" % page_num) offset = (page_num - 1) * page_size return Page( objects=list(self[offset : offset + page_size]), number_of_objects=count, pages_total=pages_total, number=page_num, - page_size=page_size + page_size=page_size, ) def distinct(self): @@ -540,8 +548,11 @@ class QuerySet(object): Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only. """ from .engines import CollapsingMergeTree, ReplacingMergeTree + if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): - raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') + raise TypeError( + "final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines" + ) qs = copy(self) qs._final = True @@ -554,7 +565,7 @@ class QuerySet(object): """ self._verify_mutation_allowed() conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) - sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions) + sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions) self._database.raw(sql) return self @@ -564,22 +575,22 @@ class QuerySet(object): Keyword arguments specify the field names and expressions to use for the update. Note that ClickHouse performs updates in the background, so they are not immediate. """ - assert kwargs, 'No fields specified for update' + assert kwargs, "No fields specified for update" self._verify_mutation_allowed() - fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) + fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) - sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions) + sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (self._model_cls.table_name(), fields, conditions) self._database.raw(sql) return self def _verify_mutation_allowed(self): - ''' + """ Checks that the queryset's state allows mutations. Raises an AssertionError if not. - ''' - assert not self._limits, 'Mutations are not allowed after slicing the queryset' - assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' - assert not self._distinct, 'Mutations are not allowed after calling distinct()' - assert not self._final, 'Mutations are not allowed after calling final()' + """ + assert not self._limits, "Mutations are not allowed after slicing the queryset" + assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)" + assert not self._distinct, "Mutations are not allowed after calling distinct()" + assert not self._final, "Mutations are not allowed after calling final()" def aggregate(self, *args, **kwargs): """ @@ -619,7 +630,7 @@ class AggregateQuerySet(QuerySet): At least one calculated field is required. """ super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database) - assert calculated_fields, 'No calculated fields specified for aggregation' + assert calculated_fields, "No calculated fields specified for aggregation" self._fields = grouping_fields self._grouping_fields = grouping_fields self._calculated_fields = calculated_fields @@ -636,8 +647,9 @@ class AggregateQuerySet(QuerySet): created with. """ for name in args: - assert name in self._fields or name in self._calculated_fields, \ - 'Cannot group by `%s` since it is not included in the query' % name + assert name in self._fields or name in self._calculated_fields, ( + "Cannot group by `%s` since it is not included in the query" % name + ) qs = copy(self) qs._grouping_fields = args return qs @@ -652,22 +664,24 @@ class AggregateQuerySet(QuerySet): """ This method is not supported on `AggregateQuerySet`. """ - raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') + raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet") def select_fields_as_sql(self): """ Returns the selected fields or expressions as a SQL string. """ - return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) + return comma_join( + [str(f) for f in self._fields] + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()] + ) def __iter__(self): - return self._database.select(self.as_sql()) # using an ad-hoc model + return self._database.select(self.as_sql()) # using an ad-hoc model def count(self): """ Returns the number of rows after aggregation. """ - sql = u'SELECT count() FROM (%s)' % self.as_sql() + sql = "SELECT count() FROM (%s)" % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 @@ -682,7 +696,7 @@ class AggregateQuerySet(QuerySet): return qs def _verify_mutation_allowed(self): - raise AssertionError('Cannot mutate an AggregateQuerySet') + raise AssertionError("Cannot mutate an AggregateQuerySet") # Expose only relevant classes in import *