- funcs support in limit_by

- aggregate functions wip
This commit is contained in:
Itai Shirav 2020-02-09 19:20:56 +02:00
parent 93747f7758
commit 25c4a6710e
5 changed files with 149 additions and 7 deletions

View File

@ -6,7 +6,7 @@ One of the ORM's core concepts is _expressions_, which are composed using functi
- When defining [field options](field_options.md) - `default`, `alias` and `materialized`.
- In [table engine](table_engines.md) parameters for engines in the `MergeTree` family.
- In [queryset](querysets.md) methods such as `filter`, `exclude`, `order_by`, `extra`, `aggregate` and `limit_by`.
- In [queryset](querysets.md) methods such as `filter`, `exclude`, `order_by`, `aggregate` and `limit_by`.
Using Expressions
-----------------

View File

@ -1311,7 +1311,95 @@ class F(Cond, FunctionOperatorsMixin):
def toIPv6(ipv6):
return F('toIPv6', ipv6)
# Aggregate functions
@staticmethod
def any(x):
return F('any', x)
@staticmethod
def anyHeavy(x):
return F('anyHeavy', x)
@staticmethod
def anyLast(x):
return F('anyLast', x)
@staticmethod
def argMax(x, y):
return F('argMax', x, y)
@staticmethod
def argMin(x, y):
return F('argMin', x, y)
@staticmethod
def avg(x):
return F('avg', x)
@staticmethod
def corr(x, y):
return F('corr', x, y)
@staticmethod
def count():
return F('count')
@staticmethod
def covarPop(x, y):
return F('covarPop', x, y)
@staticmethod
def covarSamp(x, y):
return F('covarSamp', x, y)
@staticmethod
def kurtPop(x):
return F('kurtPop', x)
@staticmethod
def kurtSamp(x):
return F('kurtSamp', x)
@staticmethod
def min(x):
return F('min', x)
@staticmethod
def max(x):
return F('max', x)
@staticmethod
def skewPop(x):
return F('skewPop', x)
@staticmethod
def skewSamp(x):
return F('skewSamp', x)
@staticmethod
def sum(x):
return F('sum', x)
@staticmethod
def uniq(*args):
return F('uniq', *args)
@staticmethod
def uniqExact(*args):
return F('uniqExact', *args)
@staticmethod
def uniqHLL12(*args):
return F('uniqHLL12', *args)
@staticmethod
def varPop(x):
return F('varPop', x)
@staticmethod
def varSamp(x):
return F('varSamp', x)
# Higher-order functions

View File

@ -5,12 +5,11 @@ from copy import copy, deepcopy
from math import ceil
from .engines import CollapsingMergeTree
from datetime import date, datetime
from .utils import comma_join
from .utils import comma_join, string_or_func
# TODO
# - check that field names are valid
# - operators for arrays: length, has, empty
class Operator(object):
"""
@ -345,11 +344,11 @@ class QuerySet(object):
qs._limits = (start, stop - start)
return qs
def limit_by(self, offset_limit, *fields):
def limit_by(self, offset_limit, *fields_or_expr):
"""
Adds a LIMIT BY clause to the query.
- `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
- `fields`: the field names to use in the clause.
- `fields_or_expr`: the field names or expressions to use in the clause.
"""
if isinstance(offset_limit, int):
# Single limit
@ -359,7 +358,7 @@ class QuerySet(object):
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
qs = copy(self)
qs._limit_by = (offset, limit)
qs._limit_by_fields = fields
qs._limit_by_fields = fields_or_expr
return qs
def select_fields_as_sql(self):
@ -403,7 +402,7 @@ class QuerySet(object):
if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by
sql += ' BY %s' % comma_join('`%s`' % field for field in self._limit_by_fields)
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits:
sql += '\nLIMIT %d, %d' % self._limits

View File

@ -30,6 +30,15 @@ class FuncsTestCase(TestCaseWithData):
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
def _test_aggr(self, func, expected_value=None):
qs = Person.objects_in(self.database).aggregate(value=func)
logger.info(qs.as_sql())
result = list(qs)
logger.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value is not None:
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
def test_func_to_sql(self):
# No args
self.assertEqual(F('func').to_sql(), 'func()')
@ -514,3 +523,28 @@ class FuncsTestCase(TestCaseWithData):
# These require support for tuples:
# self._test_func(F.IPv4CIDRToRange(F.toIPv4('192.168.5.2'), 16), ['192.168.0.0','192.168.255.255'])
# self._test_func(F.IPv6CIDRToRange(x, y))
def test_aggregate_funcs(self):
self._test_aggr(F.any(Person.first_name))
self._test_aggr(F.anyHeavy(Person.first_name))
self._test_aggr(F.anyLast(Person.first_name))
self._test_aggr(F.argMin(Person.first_name, Person.height))
self._test_aggr(F.argMax(Person.first_name, Person.height))
self._test_aggr(F.round(F.avg(Person.height), 4), sum(p.height for p in self._sample_data()) / 100)
self._test_aggr(F.corr(Person.height, Person.height), 1)
self._test_aggr(F.count(), 100)
self._test_aggr(F.round(F.covarPop(Person.height, Person.height), 2), 0)
self._test_aggr(F.round(F.covarSamp(Person.height, Person.height), 2), 0)
self._test_aggr(F.kurtPop(Person.height))
self._test_aggr(F.kurtSamp(Person.height))
self._test_aggr(F.min(Person.height), 1.59)
self._test_aggr(F.max(Person.height), 1.80)
self._test_aggr(F.skewPop(Person.height))
self._test_aggr(F.skewSamp(Person.height))
self._test_aggr(F.round(F.sum(Person.height), 4), sum(p.height for p in self._sample_data()))
self._test_aggr(F.uniq(Person.first_name, Person.last_name), 100)
self._test_aggr(F.uniqExact(Person.first_name, Person.last_name), 100)
self._test_aggr(F.uniqHLL12(Person.first_name, Person.last_name), 99)
self._test_aggr(F.varPop(Person.height))
self._test_aggr(F.varSamp(Person.height))

View File

@ -339,6 +339,22 @@ class AggregateTestCase(TestCaseWithData):
self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2)
def test_aggregate_with_filter__funcs(self):
# When filter comes before aggregate
qs = Person.objects_in(self.database).filter(Person.first_name=='Warren').aggregate(average_height=F.avg(Person.height), count=F.count())
print(qs.as_sql())
self.assertEqual(qs.count(), 1)
for row in qs:
self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2)
# When filter comes after aggregate
qs = Person.objects_in(self.database).aggregate(average_height=F.avg(Person.height), count=F.count()).filter(Person.first_name=='Warren')
print(qs.as_sql())
self.assertEqual(qs.count(), 1)
for row in qs:
self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2)
def test_aggregate_with_implicit_grouping(self):
qs = Person.objects_in(self.database).aggregate('first_name', average_height='avg(height)', count='count()')
print(qs.as_sql())
@ -453,6 +469,11 @@ class AggregateTestCase(TestCaseWithData):
order_by('first_name', '-height').limit_by(1, 'first_name')
self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen')
# Test with funcs
qs = Person.objects_in(self.database).aggregate('first_name', 'last_name', 'height', n=F.count()).\
order_by('first_name', '-height').limit_by(1, F.upper(Person.first_name))
self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen')
# Test with limit and offset, also mixing LIMIT with LIMIT BY
qs = Person.objects_in(self.database).filter(height__gt=1.67).order_by('height', 'first_name')
limited_qs = qs.limit_by((0, 3), 'height')