Chore: fix linting on query.py

This commit is contained in:
olliemath 2021-07-27 22:56:45 +01:00
parent 0213aed397
commit 13655da35d

View File

@ -3,13 +3,13 @@ from __future__ import unicode_literals
import pytz import pytz
from copy import copy, deepcopy from copy import copy, deepcopy
from math import ceil from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql from .utils import comma_join, string_or_func, arg_to_sql
# TODO # TODO
# - check that field names are valid # - check that field names are valid
class Operator(object): class Operator(object):
""" """
Base class for filtering operators. Base class for filtering operators.
@ -20,10 +20,11 @@ class Operator(object):
Subclasses should implement this method. It returns an SQL string Subclasses should implement this method. It returns an SQL string
that applies this operator on the given field and value. that applies this operator on the given field and value.
""" """
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True): def _value_to_sql(self, field, value, quote=True):
from clickhouse_orm.funcs import F from clickhouse_orm.funcs import F
if isinstance(value, F): if isinstance(value, F):
return value.to_sql() return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote) return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -41,9 +42,9 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None: if value == "\\N" and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null]) return " ".join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value]) return " ".join([field_name, self._sql_operator, value])
class InOperator(Operator): class InOperator(Operator):
@ -63,7 +64,7 @@ class InOperator(Operator):
pass pass
else: else:
value = comma_join([self._value_to_sql(field, v) for v in value]) value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value) return "%s IN (%s)" % (field_name, value)
class LikeOperator(Operator): class LikeOperator(Operator):
@ -79,12 +80,12 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False) value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
pattern = self._pattern.format(value) pattern = self._pattern.format(value)
if self._case_sensitive: if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern) return "%s LIKE '%s'" % (field_name, pattern)
else: else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern) return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field_name, pattern)
class IExactOperator(Operator): class IExactOperator(Operator):
@ -95,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) return "lowerUTF8(%s) = lowerUTF8(%s)" % (field_name, value)
class NotOperator(Operator): class NotOperator(Operator):
@ -108,7 +109,7 @@ class NotOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
# Negate the base operator # Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value) return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
class BetweenOperator(Operator): class BetweenOperator(Operator):
@ -126,35 +127,38 @@ class BetweenOperator(Operator):
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1: if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1) return "%s BETWEEN %s AND %s" % (field_name, value0, value1)
if value0 and not value1: if value0 and not value1:
return ' '.join([field_name, '>=', value0]) return " ".join([field_name, ">=", value0])
if value1 and not value0: if value1 and not value0:
return ' '.join([field_name, '<=', value1]) return " ".join([field_name, "<=", value1])
# Define the set of builtin operators # Define the set of builtin operators
_operators = {} _operators = {}
def register_operator(name, sql): def register_operator(name, sql):
_operators[name] = sql _operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) register_operator("eq", SimpleOperator("=", "IS NULL"))
register_operator('gt', SimpleOperator('>')) register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
register_operator('gte', SimpleOperator('>=')) register_operator("gt", SimpleOperator(">"))
register_operator('lt', SimpleOperator('<')) register_operator("gte", SimpleOperator(">="))
register_operator('lte', SimpleOperator('<=')) register_operator("lt", SimpleOperator("<"))
register_operator('between', BetweenOperator()) register_operator("lte", SimpleOperator("<="))
register_operator('in', InOperator()) register_operator("between", BetweenOperator())
register_operator('not_in', NotOperator(InOperator())) register_operator("in", InOperator())
register_operator('contains', LikeOperator('%{}%')) register_operator("not_in", NotOperator(InOperator()))
register_operator('startswith', LikeOperator('{}%')) register_operator("contains", LikeOperator("%{}%"))
register_operator('endswith', LikeOperator('%{}')) register_operator("startswith", LikeOperator("{}%"))
register_operator('icontains', LikeOperator('%{}%', False)) register_operator("endswith", LikeOperator("%{}"))
register_operator('istartswith', LikeOperator('{}%', False)) register_operator("icontains", LikeOperator("%{}%", False))
register_operator('iendswith', LikeOperator('%{}', False)) register_operator("istartswith", LikeOperator("{}%", False))
register_operator('iexact', IExactOperator()) register_operator("iendswith", LikeOperator("%{}", False))
register_operator("iexact", IExactOperator())
class Cond(object): class Cond(object):
@ -170,19 +174,20 @@ class FieldCond(Cond):
""" """
A single query condition made up of Field + Operator + Value. A single query condition made up of Field + Operator + Value.
""" """
def __init__(self, field_name, operator, value): def __init__(self, field_name, operator, value):
self._field_name = field_name self._field_name = field_name
self._operator = _operators.get(operator) self._operator = _operators.get(operator)
if self._operator is None: if self._operator is None:
# The field name contains __ like my__field # The field name contains __ like my__field
self._field_name = field_name + '__' + operator self._field_name = field_name + "__" + operator
self._operator = _operators['eq'] self._operator = _operators["eq"]
self._value = value self._value = value
def to_sql(self, model_cls): def to_sql(self, model_cls):
return self._operator.to_sql(model_cls, self._field_name, self._value) return self._operator.to_sql(model_cls, self._field_name, self._value)
def __deepcopy__(self, memodict={}): def __deepcopy__(self, memo):
res = copy(self) res = copy(self)
res._value = deepcopy(self._value) res._value = deepcopy(self._value)
return res return res
@ -190,8 +195,8 @@ class FieldCond(Cond):
class Q(object): class Q(object):
AND_MODE = 'AND' AND_MODE = "AND"
OR_MODE = 'OR' OR_MODE = "OR"
def __init__(self, *filter_funcs, **filter_fields): def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()] self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
@ -224,10 +229,10 @@ class Q(object):
return q return q
def _build_cond(self, key, value): def _build_cond(self, key, value):
if '__' in key: if "__" in key:
field_name, operator = key.rsplit('__', 1) field_name, operator = key.rsplit("__", 1)
else: else:
field_name, operator = key, 'eq' field_name, operator = key, "eq"
return FieldCond(field_name, operator, value) return FieldCond(field_name, operator, value)
def to_sql(self, model_cls): def to_sql(self, model_cls):
@ -241,16 +246,16 @@ class Q(object):
if not condition_sql: if not condition_sql:
# Empty Q() object returns everything # Empty Q() object returns everything
sql = '1' sql = "1"
elif len(condition_sql) == 1: elif len(condition_sql) == 1:
# Skip not needed brackets over single condition # Skip not needed brackets over single condition
sql = condition_sql[0] sql = condition_sql[0]
else: else:
# Each condition must be enclosed in brackets, or order of operations may be wrong # Each condition must be enclosed in brackets, or order of operations may be wrong
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql) sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
if self._negate: if self._negate:
sql = 'NOT (%s)' % sql sql = "NOT (%s)" % sql
return sql return sql
@ -268,7 +273,7 @@ class Q(object):
def __bool__(self): def __bool__(self):
return not self.is_empty return not self.is_empty
def __deepcopy__(self, memodict={}): def __deepcopy__(self, memo):
q = Q() q = Q()
q._conds = [deepcopy(cond) for cond in self._conds] q._conds = [deepcopy(cond) for cond in self._conds]
q._negate = self._negate q._negate = self._negate
@ -318,7 +323,7 @@ class QuerySet(object):
""" """
return bool(self.count()) return bool(self.count())
def __nonzero__(self): # Python 2 compatibility def __nonzero__(self): # Python 2 compatibility
return type(self).__bool__(self) return type(self).__bool__(self)
def __str__(self): def __str__(self):
@ -327,17 +332,17 @@ class QuerySet(object):
def __getitem__(self, s): def __getitem__(self, s):
if isinstance(s, int): if isinstance(s, int):
# Single index # Single index
assert s >= 0, 'negative indexes are not supported' assert s >= 0, "negative indexes are not supported"
qs = copy(self) qs = copy(self)
qs._limits = (s, 1) qs._limits = (s, 1)
return next(iter(qs)) return next(iter(qs))
else: else:
# Slice # Slice
assert s.step in (None, 1), 'step is not supported in slices' assert s.step in (None, 1), "step is not supported in slices"
start = s.start or 0 start = s.start or 0
stop = s.stop or 2**63 - 1 stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported' assert start >= 0 and stop >= 0, "negative indexes are not supported"
assert start <= stop, 'start of slice cannot be smaller than its end' assert start <= stop, "start of slice cannot be smaller than its end"
qs = copy(self) qs = copy(self)
qs._limits = (start, stop - start) qs._limits = (start, stop - start)
return qs return qs
@ -353,7 +358,7 @@ class QuerySet(object):
offset_limit = (0, offset_limit) offset_limit = (0, offset_limit)
offset = offset_limit[0] offset = offset_limit[0]
limit = offset_limit[1] limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported' assert offset >= 0 and limit >= 0, "negative limits are not supported"
qs = copy(self) qs = copy(self)
qs._limit_by = (offset, limit) qs._limit_by = (offset, limit)
qs._limit_by_fields = fields_or_expr qs._limit_by_fields = fields_or_expr
@ -363,44 +368,44 @@ class QuerySet(object):
""" """
Returns the selected fields or expressions as a SQL string. Returns the selected fields or expressions as a SQL string.
""" """
fields = '*' fields = "*"
if self._fields: if self._fields:
fields = comma_join('`%s`' % field for field in self._fields) fields = comma_join("`%s`" % field for field in self._fields)
return fields return fields
def as_sql(self): def as_sql(self):
""" """
Returns the whole query as a SQL string. Returns the whole query as a SQL string.
""" """
distinct = 'DISTINCT ' if self._distinct else '' distinct = "DISTINCT " if self._distinct else ""
final = ' FINAL' if self._final else '' final = " FINAL" if self._final else ""
table_name = '`%s`' % self._model_cls.table_name() table_name = "`%s`" % self._model_cls.table_name()
if self._model_cls.is_system_model(): if self._model_cls.is_system_model():
table_name = '`system`.' + table_name table_name = "`system`." + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final) params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = u'SELECT %s%s\nFROM %s%s' % params sql = "SELECT %s%s\nFROM %s%s" % params
if self._prewhere_q and not self._prewhere_q.is_empty: if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True) sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
if self._where_q and not self._where_q.is_empty: if self._where_q and not self._where_q.is_empty:
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
if self._grouping_fields: if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) sql += "\nGROUP BY %s" % comma_join("`%s`" % field for field in self._grouping_fields)
if self._grouping_with_totals: if self._grouping_with_totals:
sql += ' WITH TOTALS' sql += " WITH TOTALS"
if self._order_by: if self._order_by:
sql += '\nORDER BY ' + self.order_by_as_sql() sql += "\nORDER BY " + self.order_by_as_sql()
if self._limit_by: if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by sql += "\nLIMIT %d, %d" % self._limit_by
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields) sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits: if self._limits:
sql += '\nLIMIT %d, %d' % self._limits sql += "\nLIMIT %d, %d" % self._limits
return sql return sql
@ -408,10 +413,12 @@ class QuerySet(object):
""" """
Returns the contents of the query's `ORDER BY` clause as a string. Returns the contents of the query's `ORDER BY` clause as a string.
""" """
return comma_join([ return comma_join(
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field) [
for field in self._order_by "%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
]) for field in self._order_by
]
)
def conditions_as_sql(self, prewhere=False): def conditions_as_sql(self, prewhere=False):
""" """
@ -426,7 +433,7 @@ class QuerySet(object):
""" """
if self._distinct or self._limits: if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate # Use a subquery, since a simple count won't be accurate
sql = u'SELECT count() FROM (%s)' % self.as_sql() sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql) raw = self._database.raw(sql)
return int(raw) if raw else 0 return int(raw) if raw else 0
@ -455,8 +462,8 @@ class QuerySet(object):
def _filter_or_exclude(self, *q, **kwargs): def _filter_or_exclude(self, *q, **kwargs):
from .funcs import F from .funcs import F
inverse = kwargs.pop('_inverse', False) inverse = kwargs.pop("_inverse", False)
prewhere = kwargs.pop('prewhere', False) prewhere = kwargs.pop("prewhere", False)
qs = copy(self) qs = copy(self)
@ -510,19 +517,20 @@ class QuerySet(object):
`pages_total`, `number` (of the current page), and `page_size`. `pages_total`, `number` (of the current page), and `page_size`.
""" """
from .database import Page from .database import Page
count = self.count() count = self.count()
pages_total = int(ceil(count / float(page_size))) pages_total = int(ceil(count / float(page_size)))
if page_num == -1: if page_num == -1:
page_num = pages_total page_num = pages_total
elif page_num < 1: elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num) raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size offset = (page_num - 1) * page_size
return Page( return Page(
objects=list(self[offset : offset + page_size]), objects=list(self[offset : offset + page_size]),
number_of_objects=count, number_of_objects=count,
pages_total=pages_total, pages_total=pages_total,
number=page_num, number=page_num,
page_size=page_size page_size=page_size,
) )
def distinct(self): def distinct(self):
@ -540,8 +548,11 @@ class QuerySet(object):
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only. Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
""" """
from .engines import CollapsingMergeTree, ReplacingMergeTree from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') raise TypeError(
"final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines"
)
qs = copy(self) qs = copy(self)
qs._final = True qs._final = True
@ -554,7 +565,7 @@ class QuerySet(object):
""" """
self._verify_mutation_allowed() self._verify_mutation_allowed()
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions) sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
self._database.raw(sql) self._database.raw(sql)
return self return self
@ -564,22 +575,22 @@ class QuerySet(object):
Keyword arguments specify the field names and expressions to use for the update. Keyword arguments specify the field names and expressions to use for the update.
Note that ClickHouse performs updates in the background, so they are not immediate. Note that ClickHouse performs updates in the background, so they are not immediate.
""" """
assert kwargs, 'No fields specified for update' assert kwargs, "No fields specified for update"
self._verify_mutation_allowed() self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions) sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (self._model_cls.table_name(), fields, conditions)
self._database.raw(sql) self._database.raw(sql)
return self return self
def _verify_mutation_allowed(self): def _verify_mutation_allowed(self):
''' """
Checks that the queryset's state allows mutations. Raises an AssertionError if not. Checks that the queryset's state allows mutations. Raises an AssertionError if not.
''' """
assert not self._limits, 'Mutations are not allowed after slicing the queryset' assert not self._limits, "Mutations are not allowed after slicing the queryset"
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
assert not self._distinct, 'Mutations are not allowed after calling distinct()' assert not self._distinct, "Mutations are not allowed after calling distinct()"
assert not self._final, 'Mutations are not allowed after calling final()' assert not self._final, "Mutations are not allowed after calling final()"
def aggregate(self, *args, **kwargs): def aggregate(self, *args, **kwargs):
""" """
@ -619,7 +630,7 @@ class AggregateQuerySet(QuerySet):
At least one calculated field is required. At least one calculated field is required.
""" """
super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database) super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database)
assert calculated_fields, 'No calculated fields specified for aggregation' assert calculated_fields, "No calculated fields specified for aggregation"
self._fields = grouping_fields self._fields = grouping_fields
self._grouping_fields = grouping_fields self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields self._calculated_fields = calculated_fields
@ -636,8 +647,9 @@ class AggregateQuerySet(QuerySet):
created with. created with.
""" """
for name in args: for name in args:
assert name in self._fields or name in self._calculated_fields, \ assert name in self._fields or name in self._calculated_fields, (
'Cannot group by `%s` since it is not included in the query' % name "Cannot group by `%s` since it is not included in the query" % name
)
qs = copy(self) qs = copy(self)
qs._grouping_fields = args qs._grouping_fields = args
return qs return qs
@ -652,22 +664,24 @@ class AggregateQuerySet(QuerySet):
""" """
This method is not supported on `AggregateQuerySet`. This method is not supported on `AggregateQuerySet`.
""" """
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
def select_fields_as_sql(self): def select_fields_as_sql(self):
""" """
Returns the selected fields or expressions as a SQL string. Returns the selected fields or expressions as a SQL string.
""" """
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) return comma_join(
[str(f) for f in self._fields] + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
)
def __iter__(self): def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model return self._database.select(self.as_sql()) # using an ad-hoc model
def count(self): def count(self):
""" """
Returns the number of rows after aggregation. Returns the number of rows after aggregation.
""" """
sql = u'SELECT count() FROM (%s)' % self.as_sql() sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql) raw = self._database.raw(sql)
return int(raw) if raw else 0 return int(raw) if raw else 0
@ -682,7 +696,7 @@ class AggregateQuerySet(QuerySet):
return qs return qs
def _verify_mutation_allowed(self): def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet') raise AssertionError("Cannot mutate an AggregateQuerySet")
# Expose only relevant classes in import * # Expose only relevant classes in import *