Chore: fix linting on query.py

This commit is contained in:
olliemath 2021-07-27 22:56:45 +01:00
parent 0213aed397
commit 13655da35d

View File

@ -3,13 +3,13 @@ from __future__ import unicode_literals
import pytz
from copy import copy, deepcopy
from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql
# TODO
# - check that field names are valid
class Operator(object):
"""
Base class for filtering operators.
@ -20,10 +20,11 @@ class Operator(object):
Subclasses should implement this method. It returns an SQL string
that applies this operator on the given field and value.
"""
raise NotImplementedError # pragma: no cover
raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True):
from clickhouse_orm.funcs import F
if isinstance(value, F):
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -41,9 +42,9 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value])
if value == "\\N" and self._sql_for_null is not None:
return " ".join([field_name, self._sql_for_null])
return " ".join([field_name, self._sql_operator, value])
class InOperator(Operator):
@ -63,7 +64,7 @@ class InOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value)
return "%s IN (%s)" % (field_name, value)
class LikeOperator(Operator):
@ -79,12 +80,12 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
pattern = self._pattern.format(value)
if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern)
return "%s LIKE '%s'" % (field_name, pattern)
else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern)
return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field_name, pattern)
class IExactOperator(Operator):
@ -95,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value)
return "lowerUTF8(%s) = lowerUTF8(%s)" % (field_name, value)
class NotOperator(Operator):
@ -108,7 +109,7 @@ class NotOperator(Operator):
def to_sql(self, model_cls, field_name, value):
# Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
class BetweenOperator(Operator):
@ -126,35 +127,38 @@ class BetweenOperator(Operator):
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
return "%s BETWEEN %s AND %s" % (field_name, value0, value1)
if value0 and not value1:
return ' '.join([field_name, '>=', value0])
return " ".join([field_name, ">=", value0])
if value1 and not value0:
return ' '.join([field_name, '<=', value1])
return " ".join([field_name, "<=", value1])
# Define the set of builtin operators
_operators = {}
def register_operator(name, sql):
_operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
register_operator('gt', SimpleOperator('>'))
register_operator('gte', SimpleOperator('>='))
register_operator('lt', SimpleOperator('<'))
register_operator('lte', SimpleOperator('<='))
register_operator('between', BetweenOperator())
register_operator('in', InOperator())
register_operator('not_in', NotOperator(InOperator()))
register_operator('contains', LikeOperator('%{}%'))
register_operator('startswith', LikeOperator('{}%'))
register_operator('endswith', LikeOperator('%{}'))
register_operator('icontains', LikeOperator('%{}%', False))
register_operator('istartswith', LikeOperator('{}%', False))
register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator())
register_operator("eq", SimpleOperator("=", "IS NULL"))
register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
register_operator("gt", SimpleOperator(">"))
register_operator("gte", SimpleOperator(">="))
register_operator("lt", SimpleOperator("<"))
register_operator("lte", SimpleOperator("<="))
register_operator("between", BetweenOperator())
register_operator("in", InOperator())
register_operator("not_in", NotOperator(InOperator()))
register_operator("contains", LikeOperator("%{}%"))
register_operator("startswith", LikeOperator("{}%"))
register_operator("endswith", LikeOperator("%{}"))
register_operator("icontains", LikeOperator("%{}%", False))
register_operator("istartswith", LikeOperator("{}%", False))
register_operator("iendswith", LikeOperator("%{}", False))
register_operator("iexact", IExactOperator())
class Cond(object):
@ -170,19 +174,20 @@ class FieldCond(Cond):
"""
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value):
self._field_name = field_name
self._operator = _operators.get(operator)
if self._operator is None:
# The field name contains __ like my__field
self._field_name = field_name + '__' + operator
self._operator = _operators['eq']
self._field_name = field_name + "__" + operator
self._operator = _operators["eq"]
self._value = value
def to_sql(self, model_cls):
return self._operator.to_sql(model_cls, self._field_name, self._value)
def __deepcopy__(self, memodict={}):
def __deepcopy__(self, memo):
res = copy(self)
res._value = deepcopy(self._value)
return res
@ -190,8 +195,8 @@ class FieldCond(Cond):
class Q(object):
AND_MODE = 'AND'
OR_MODE = 'OR'
AND_MODE = "AND"
OR_MODE = "OR"
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
@ -224,10 +229,10 @@ class Q(object):
return q
def _build_cond(self, key, value):
if '__' in key:
field_name, operator = key.rsplit('__', 1)
if "__" in key:
field_name, operator = key.rsplit("__", 1)
else:
field_name, operator = key, 'eq'
field_name, operator = key, "eq"
return FieldCond(field_name, operator, value)
def to_sql(self, model_cls):
@ -241,16 +246,16 @@ class Q(object):
if not condition_sql:
# Empty Q() object returns everything
sql = '1'
sql = "1"
elif len(condition_sql) == 1:
# Skip not needed brackets over single condition
sql = condition_sql[0]
else:
# Each condition must be enclosed in brackets, or order of operations may be wrong
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql)
sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
if self._negate:
sql = 'NOT (%s)' % sql
sql = "NOT (%s)" % sql
return sql
@ -268,7 +273,7 @@ class Q(object):
def __bool__(self):
return not self.is_empty
def __deepcopy__(self, memodict={}):
def __deepcopy__(self, memo):
q = Q()
q._conds = [deepcopy(cond) for cond in self._conds]
q._negate = self._negate
@ -318,7 +323,7 @@ class QuerySet(object):
"""
return bool(self.count())
def __nonzero__(self): # Python 2 compatibility
def __nonzero__(self): # Python 2 compatibility
return type(self).__bool__(self)
def __str__(self):
@ -327,17 +332,17 @@ class QuerySet(object):
def __getitem__(self, s):
if isinstance(s, int):
# Single index
assert s >= 0, 'negative indexes are not supported'
assert s >= 0, "negative indexes are not supported"
qs = copy(self)
qs._limits = (s, 1)
return next(iter(qs))
else:
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
assert s.step in (None, 1), "step is not supported in slices"
start = s.start or 0
stop = s.stop or 2**63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
assert start <= stop, 'start of slice cannot be smaller than its end'
stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, "negative indexes are not supported"
assert start <= stop, "start of slice cannot be smaller than its end"
qs = copy(self)
qs._limits = (start, stop - start)
return qs
@ -353,7 +358,7 @@ class QuerySet(object):
offset_limit = (0, offset_limit)
offset = offset_limit[0]
limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
assert offset >= 0 and limit >= 0, "negative limits are not supported"
qs = copy(self)
qs._limit_by = (offset, limit)
qs._limit_by_fields = fields_or_expr
@ -363,44 +368,44 @@ class QuerySet(object):
"""
Returns the selected fields or expressions as a SQL string.
"""
fields = '*'
fields = "*"
if self._fields:
fields = comma_join('`%s`' % field for field in self._fields)
fields = comma_join("`%s`" % field for field in self._fields)
return fields
def as_sql(self):
"""
Returns the whole query as a SQL string.
"""
distinct = 'DISTINCT ' if self._distinct else ''
final = ' FINAL' if self._final else ''
table_name = '`%s`' % self._model_cls.table_name()
distinct = "DISTINCT " if self._distinct else ""
final = " FINAL" if self._final else ""
table_name = "`%s`" % self._model_cls.table_name()
if self._model_cls.is_system_model():
table_name = '`system`.' + table_name
table_name = "`system`." + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = u'SELECT %s%s\nFROM %s%s' % params
sql = "SELECT %s%s\nFROM %s%s" % params
if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
if self._where_q and not self._where_q.is_empty:
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields)
sql += "\nGROUP BY %s" % comma_join("`%s`" % field for field in self._grouping_fields)
if self._grouping_with_totals:
sql += ' WITH TOTALS'
sql += " WITH TOTALS"
if self._order_by:
sql += '\nORDER BY ' + self.order_by_as_sql()
sql += "\nORDER BY " + self.order_by_as_sql()
if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
sql += "\nLIMIT %d, %d" % self._limit_by
sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits:
sql += '\nLIMIT %d, %d' % self._limits
sql += "\nLIMIT %d, %d" % self._limits
return sql
@ -408,10 +413,12 @@ class QuerySet(object):
"""
Returns the contents of the query's `ORDER BY` clause as a string.
"""
return comma_join([
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field)
for field in self._order_by
])
return comma_join(
[
"%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
for field in self._order_by
]
)
def conditions_as_sql(self, prewhere=False):
"""
@ -426,7 +433,7 @@ class QuerySet(object):
"""
if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -455,8 +462,8 @@ class QuerySet(object):
def _filter_or_exclude(self, *q, **kwargs):
from .funcs import F
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
inverse = kwargs.pop("_inverse", False)
prewhere = kwargs.pop("prewhere", False)
qs = copy(self)
@ -510,19 +517,20 @@ class QuerySet(object):
`pages_total`, `number` (of the current page), and `page_size`.
"""
from .database import Page
count = self.count()
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = pages_total
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
return Page(
objects=list(self[offset : offset + page_size]),
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
def distinct(self):
@ -540,8 +548,11 @@ class QuerySet(object):
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
"""
from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines')
raise TypeError(
"final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines"
)
qs = copy(self)
qs._final = True
@ -554,7 +565,7 @@ class QuerySet(object):
"""
self._verify_mutation_allowed()
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions)
sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
self._database.raw(sql)
return self
@ -564,22 +575,22 @@ class QuerySet(object):
Keyword arguments specify the field names and expressions to use for the update.
Note that ClickHouse performs updates in the background, so they are not immediate.
"""
assert kwargs, 'No fields specified for update'
assert kwargs, "No fields specified for update"
self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions)
sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (self._model_cls.table_name(), fields, conditions)
self._database.raw(sql)
return self
def _verify_mutation_allowed(self):
'''
"""
Checks that the queryset's state allows mutations. Raises an AssertionError if not.
'''
assert not self._limits, 'Mutations are not allowed after slicing the queryset'
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
assert not self._final, 'Mutations are not allowed after calling final()'
"""
assert not self._limits, "Mutations are not allowed after slicing the queryset"
assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
assert not self._distinct, "Mutations are not allowed after calling distinct()"
assert not self._final, "Mutations are not allowed after calling final()"
def aggregate(self, *args, **kwargs):
"""
@ -619,7 +630,7 @@ class AggregateQuerySet(QuerySet):
At least one calculated field is required.
"""
super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database)
assert calculated_fields, 'No calculated fields specified for aggregation'
assert calculated_fields, "No calculated fields specified for aggregation"
self._fields = grouping_fields
self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields
@ -636,8 +647,9 @@ class AggregateQuerySet(QuerySet):
created with.
"""
for name in args:
assert name in self._fields or name in self._calculated_fields, \
'Cannot group by `%s` since it is not included in the query' % name
assert name in self._fields or name in self._calculated_fields, (
"Cannot group by `%s` since it is not included in the query" % name
)
qs = copy(self)
qs._grouping_fields = args
return qs
@ -652,22 +664,24 @@ class AggregateQuerySet(QuerySet):
"""
This method is not supported on `AggregateQuerySet`.
"""
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
def select_fields_as_sql(self):
"""
Returns the selected fields or expressions as a SQL string.
"""
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()])
return comma_join(
[str(f) for f in self._fields] + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
)
def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model
return self._database.select(self.as_sql()) # using an ad-hoc model
def count(self):
"""
Returns the number of rows after aggregation.
"""
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -682,7 +696,7 @@ class AggregateQuerySet(QuerySet):
return qs
def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet')
raise AssertionError("Cannot mutate an AggregateQuerySet")
# Expose only relevant classes in import *