- funcs support in limit_by

- aggregate functions wip
This commit is contained in:
Itai Shirav 2020-02-09 19:20:56 +02:00
parent 93747f7758
commit 25c4a6710e
5 changed files with 149 additions and 7 deletions

View File

@ -6,7 +6,7 @@ One of the ORM's core concepts is _expressions_, which are composed using functi
- When defining [field options](field_options.md) - `default`, `alias` and `materialized`. - When defining [field options](field_options.md) - `default`, `alias` and `materialized`.
- In [table engine](table_engines.md) parameters for engines in the `MergeTree` family. - In [table engine](table_engines.md) parameters for engines in the `MergeTree` family.
- In [queryset](querysets.md) methods such as `filter`, `exclude`, `order_by`, `extra`, `aggregate` and `limit_by`. - In [queryset](querysets.md) methods such as `filter`, `exclude`, `order_by`, `aggregate` and `limit_by`.
Using Expressions Using Expressions
----------------- -----------------

View File

@ -1311,7 +1311,95 @@ class F(Cond, FunctionOperatorsMixin):
def toIPv6(ipv6): def toIPv6(ipv6):
return F('toIPv6', ipv6) return F('toIPv6', ipv6)
# Aggregate functions
# Aggregate `any`: returns the expression F('any', x), with `x` forwarded
# unchanged as the function's single argument.
@staticmethod
def any(x):
return F('any', x)
# Aggregate `anyHeavy`: returns the expression F('anyHeavy', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def anyHeavy(x):
return F('anyHeavy', x)
# Aggregate `anyLast`: returns the expression F('anyLast', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def anyLast(x):
return F('anyLast', x)
# Aggregate `argMax`: returns the expression F('argMax', x, y); both
# arguments are forwarded in the order given.
# NOTE(review): presumably `x` is the value to return and `y` the value being
# maximised, as in ClickHouse's argMax(arg, val) - confirm against the docs.
@staticmethod
def argMax(x, y):
return F('argMax', x, y)
# Aggregate `argMin`: returns the expression F('argMin', x, y); both
# arguments are forwarded in the order given.
# NOTE(review): presumably `x` is the value to return and `y` the value being
# minimised, as in ClickHouse's argMin(arg, val) - confirm against the docs.
@staticmethod
def argMin(x, y):
return F('argMin', x, y)
# Aggregate `avg`: returns the expression F('avg', x), with `x` forwarded
# unchanged as the function's single argument.
@staticmethod
def avg(x):
return F('avg', x)
# Aggregate `corr`: returns the expression F('corr', x, y); both arguments
# are forwarded in the order given.
@staticmethod
def corr(x, y):
return F('corr', x, y)
# Aggregate `count`: takes no arguments and returns the expression
# F('count').
@staticmethod
def count():
return F('count')
# Aggregate `covarPop`: returns the expression F('covarPop', x, y); both
# arguments are forwarded in the order given.
@staticmethod
def covarPop(x, y):
return F('covarPop', x, y)
# Aggregate `covarSamp`: returns the expression F('covarSamp', x, y); both
# arguments are forwarded in the order given.
@staticmethod
def covarSamp(x, y):
return F('covarSamp', x, y)
# Aggregate `kurtPop`: returns the expression F('kurtPop', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def kurtPop(x):
return F('kurtPop', x)
# Aggregate `kurtSamp`: returns the expression F('kurtSamp', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def kurtSamp(x):
return F('kurtSamp', x)
# Aggregate `min`: returns the expression F('min', x), with `x` forwarded
# unchanged as the function's single argument.
@staticmethod
def min(x):
return F('min', x)
# Aggregate `max`: returns the expression F('max', x), with `x` forwarded
# unchanged as the function's single argument.
@staticmethod
def max(x):
return F('max', x)
# Aggregate `skewPop`: returns the expression F('skewPop', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def skewPop(x):
return F('skewPop', x)
# Aggregate `skewSamp`: returns the expression F('skewSamp', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def skewSamp(x):
return F('skewSamp', x)
# Aggregate `sum`: returns the expression F('sum', x), with `x` forwarded
# unchanged as the function's single argument.
@staticmethod
def sum(x):
return F('sum', x)
# Aggregate `uniq`: variadic - every positional argument is forwarded, in
# order, to F('uniq', ...).
@staticmethod
def uniq(*args):
return F('uniq', *args)
# Aggregate `uniqExact`: variadic - every positional argument is forwarded,
# in order, to F('uniqExact', ...).
@staticmethod
def uniqExact(*args):
return F('uniqExact', *args)
# Aggregate `uniqHLL12`: variadic - every positional argument is forwarded,
# in order, to F('uniqHLL12', ...).
# NOTE(review): presumably an approximate (HyperLogLog-based) distinct count,
# going by the name - confirm against the ClickHouse docs.
@staticmethod
def uniqHLL12(*args):
return F('uniqHLL12', *args)
# Aggregate `varPop`: returns the expression F('varPop', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def varPop(x):
return F('varPop', x)
# Aggregate `varSamp`: returns the expression F('varSamp', x), with `x`
# forwarded unchanged as the function's single argument.
@staticmethod
def varSamp(x):
return F('varSamp', x)
# Higher-order functions # Higher-order functions

View File

@ -5,12 +5,11 @@ from copy import copy, deepcopy
from math import ceil from math import ceil
from .engines import CollapsingMergeTree from .engines import CollapsingMergeTree
from datetime import date, datetime from datetime import date, datetime
from .utils import comma_join from .utils import comma_join, string_or_func
# TODO # TODO
# - check that field names are valid # - check that field names are valid
# - operators for arrays: length, has, empty
class Operator(object): class Operator(object):
""" """
@ -345,11 +344,11 @@ class QuerySet(object):
qs._limits = (start, stop - start) qs._limits = (start, stop - start)
return qs return qs
def limit_by(self, offset_limit, *fields): def limit_by(self, offset_limit, *fields_or_expr):
""" """
Adds a LIMIT BY clause to the query. Adds a LIMIT BY clause to the query.
- `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit). - `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
- `fields`: the field names to use in the clause. - `fields_or_expr`: the field names or expressions to use in the clause.
""" """
if isinstance(offset_limit, int): if isinstance(offset_limit, int):
# Single limit # Single limit
@ -359,7 +358,7 @@ class QuerySet(object):
assert offset >= 0 and limit >= 0, 'negative limits are not supported' assert offset >= 0 and limit >= 0, 'negative limits are not supported'
qs = copy(self) qs = copy(self)
qs._limit_by = (offset, limit) qs._limit_by = (offset, limit)
qs._limit_by_fields = fields qs._limit_by_fields = fields_or_expr
return qs return qs
def select_fields_as_sql(self): def select_fields_as_sql(self):
@ -403,7 +402,7 @@ class QuerySet(object):
if self._limit_by: if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by sql += '\nLIMIT %d, %d' % self._limit_by
sql += ' BY %s' % comma_join('`%s`' % field for field in self._limit_by_fields) sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits: if self._limits:
sql += '\nLIMIT %d, %d' % self._limits sql += '\nLIMIT %d, %d' % self._limits

View File

@ -30,6 +30,15 @@ class FuncsTestCase(TestCaseWithData):
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
def _test_aggr(self, func, expected_value=None):
    """
    Runs `func` as an aggregate named `value` over the `Person` queryset of
    `self.database`, logs the generated SQL and the outcome, and returns the
    aggregated value.
    - `func`: the aggregate expression to evaluate.
    - `expected_value`: when not None, the first row's `value` must equal it.
    Returns the first row's `value`, or None if the query produced no rows.
    """
    qs = Person.objects_in(self.database).aggregate(value=func)
    logger.info(qs.as_sql())
    result = list(qs)
    # Read the value once; an empty result is reported as '<empty>' in the log
    value = result[0].value if result else None
    logger.info('\t==> %s', value if result else '<empty>')
    if expected_value is not None:
        # Fail with a readable message instead of an IndexError on result[0]
        self.assertTrue(result, 'aggregate query returned no rows')
        self.assertEqual(value, expected_value)
    return value
def test_func_to_sql(self): def test_func_to_sql(self):
# No args # No args
self.assertEqual(F('func').to_sql(), 'func()') self.assertEqual(F('func').to_sql(), 'func()')
@ -514,3 +523,28 @@ class FuncsTestCase(TestCaseWithData):
# These require support for tuples: # These require support for tuples:
# self._test_func(F.IPv4CIDRToRange(F.toIPv4('192.168.5.2'), 16), ['192.168.0.0','192.168.255.255']) # self._test_func(F.IPv4CIDRToRange(F.toIPv4('192.168.5.2'), 16), ['192.168.0.0','192.168.255.255'])
# self._test_func(F.IPv6CIDRToRange(x, y)) # self._test_func(F.IPv6CIDRToRange(x, y))
# Exercises every aggregate wrapper on F through self._test_aggr.
# Calls that pass no expected value only check that the query runs and yields
# a row; calls that pass one also assert the returned value.
def test_aggregate_funcs(self):
self._test_aggr(F.any(Person.first_name))
self._test_aggr(F.anyHeavy(Person.first_name))
self._test_aggr(F.anyLast(Person.first_name))
self._test_aggr(F.argMin(Person.first_name, Person.height))
self._test_aggr(F.argMax(Person.first_name, Person.height))
# Expected average is recomputed in Python from the sample data.
# NOTE(review): the divisor 100 presumably equals the number of sample rows (cf. the count() check below) - confirm.
self._test_aggr(F.round(F.avg(Person.height), 4), sum(p.height for p in self._sample_data()) / 100)
# A column correlated with itself is expected to give exactly 1
self._test_aggr(F.corr(Person.height, Person.height), 1)
self._test_aggr(F.count(), 100)
# Covariance of height with itself is expected to round to 0 at 2 decimal places
self._test_aggr(F.round(F.covarPop(Person.height, Person.height), 2), 0)
self._test_aggr(F.round(F.covarSamp(Person.height, Person.height), 2), 0)
self._test_aggr(F.kurtPop(Person.height))
self._test_aggr(F.kurtSamp(Person.height))
self._test_aggr(F.min(Person.height), 1.59)
self._test_aggr(F.max(Person.height), 1.80)
self._test_aggr(F.skewPop(Person.height))
self._test_aggr(F.skewSamp(Person.height))
# Expected total is recomputed in Python from the sample data (not rounded on the Python side)
self._test_aggr(F.round(F.sum(Person.height), 4), sum(p.height for p in self._sample_data()))
self._test_aggr(F.uniq(Person.first_name, Person.last_name), 100)
self._test_aggr(F.uniqExact(Person.first_name, Person.last_name), 100)
# NOTE(review): 99 rather than 100 - presumably because uniqHLL12 is an approximation; confirm.
self._test_aggr(F.uniqHLL12(Person.first_name, Person.last_name), 99)
self._test_aggr(F.varPop(Person.height))
self._test_aggr(F.varSamp(Person.height))

View File

@ -339,6 +339,22 @@ class AggregateTestCase(TestCaseWithData):
self.assertAlmostEqual(row.average_height, 1.675, places=4) self.assertAlmostEqual(row.average_height, 1.675, places=4)
self.assertEqual(row.count, 2) self.assertEqual(row.count, 2)
def test_aggregate_with_filter__funcs(self):
    # Aggregating with F expressions must give the same single row whether
    # the filter is applied before or after the aggregate() call.
    def verify(queryset):
        # Shared checks: one row, average height 1.675, count 2
        print(queryset.as_sql())
        self.assertEqual(queryset.count(), 1)
        for record in queryset:
            self.assertAlmostEqual(record.average_height, 1.675, places=4)
            self.assertEqual(record.count, 2)

    # When filter comes before aggregate
    filtered_first = Person.objects_in(self.database).filter(Person.first_name=='Warren')
    verify(filtered_first.aggregate(average_height=F.avg(Person.height), count=F.count()))

    # When filter comes after aggregate
    aggregated_first = Person.objects_in(self.database).aggregate(average_height=F.avg(Person.height), count=F.count())
    verify(aggregated_first.filter(Person.first_name=='Warren'))
def test_aggregate_with_implicit_grouping(self): def test_aggregate_with_implicit_grouping(self):
qs = Person.objects_in(self.database).aggregate('first_name', average_height='avg(height)', count='count()') qs = Person.objects_in(self.database).aggregate('first_name', average_height='avg(height)', count='count()')
print(qs.as_sql()) print(qs.as_sql())
@ -453,6 +469,11 @@ class AggregateTestCase(TestCaseWithData):
order_by('first_name', '-height').limit_by(1, 'first_name') order_by('first_name', '-height').limit_by(1, 'first_name')
self.assertEqual(qs.count(), 94) self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen') self.assertEqual(list(qs)[89].last_name, 'Bowen')
# Test with funcs
qs = Person.objects_in(self.database).aggregate('first_name', 'last_name', 'height', n=F.count()).\
order_by('first_name', '-height').limit_by(1, F.upper(Person.first_name))
self.assertEqual(qs.count(), 94)
self.assertEqual(list(qs)[89].last_name, 'Bowen')
# Test with limit and offset, also mixing LIMIT with LIMIT BY # Test with limit and offset, also mixing LIMIT with LIMIT BY
qs = Person.objects_in(self.database).filter(height__gt=1.67).order_by('height', 'first_name') qs = Person.objects_in(self.database).filter(height__gt=1.67).order_by('height', 'first_name')
limited_qs = qs.limit_by((0, 3), 'height') limited_qs = qs.limit_by((0, 3), 'height')