Functions WIP

This commit is contained in:
Itai Shirav 2018-10-16 16:42:30 +03:00
parent 9df82a44ec
commit 602d0671f1
6 changed files with 523 additions and 33 deletions

View File

@ -8,15 +8,18 @@ from calendar import timegm
from decimal import Decimal, localcontext
from .utils import escape, parse_array, comma_join
from .query import F
class Field(object):
'''
Abstract base class for all field types.
'''
creation_counter = 0
class_default = 0
db_type = None
name = None # this is set by the parent model
parent = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered
class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None):
assert (None, None) in {(default, alias), (alias, materialized), (default, materialized)}, \
@ -96,6 +99,26 @@ class Field(object):
inner_field = getattr(inner_field, 'inner_field', None)
return False
# Support comparison operators (for use in querysets)
def __lt__(self, other):
return F.less(self, other)
def __le__(self, other):
return F.lessOrEquals(self, other)
def __eq__(self, other):
return F.equals(self, other)
def __ne__(self, other):
return F.notEquals(self, other)
def __gt__(self, other):
return F.greater(self, other)
def __ge__(self, other):
return F.greaterOrEquals(self, other)
class StringField(Field):

View File

@ -43,7 +43,14 @@ class ModelBase(type):
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
_defaults=defaults
)
return super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
# Let each field know its parent and its own name
for n, f in fields:
setattr(f, 'parent', model)
setattr(f, 'name', n)
return model
@classmethod
def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):

View File

@ -4,9 +4,9 @@ import six
import pytz
from copy import copy
from math import ceil
from .engines import CollapsingMergeTree
from .utils import comma_join
from datetime import date, datetime
from .utils import comma_join, is_iterable
# TODO
@ -25,6 +25,11 @@ class Operator(object):
"""
raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True):
if isinstance(value, F):
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
class SimpleOperator(Operator):
"""
@ -37,7 +42,7 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc))
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value])
@ -59,7 +64,7 @@ class InOperator(Operator):
elif isinstance(value, six.string_types):
pass
else:
value = comma_join([field.to_db_string(field.to_python(v, pytz.utc)) for v in value])
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value)
@ -75,7 +80,7 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc), quote=False)
value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
pattern = self._pattern.format(value)
if self._case_sensitive:
@ -91,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc))
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value)
@ -120,10 +125,8 @@ class BetweenOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value0 = field.to_db_string(
field.to_python(value[0], pytz.utc)) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = field.to_db_string(
field.to_python(value[1], pytz.utc)) if value[1] is not None or len(str(value[1])) > 0 else None
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
if value0 and not value1:
@ -156,11 +159,19 @@ register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator())
class FOV(object):
class Cond(object):
"""
An object for storing Field + Operator + Value.
An abstract object for storing a single query condition Field + Operator + Value.
"""
def to_sql(self, model_cls):
raise NotImplementedError
class FieldCond(Cond):
"""
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value):
self._field_name = field_name
self._operator = _operators.get(operator)
@ -174,13 +185,300 @@ class FOV(object):
return self._operator.to_sql(model_cls, self._field_name, self._value)
class F(Cond):
"""
Represents a database function call and its arguments.
It doubles as a query condition when the function returns a boolean result.
"""
def __init__(self, name, *args):
self.name = name
self.args = args
def to_sql(self, *args):
args_sql = comma_join(self.arg_to_sql(arg) for arg in self.args)
return self.name + '(' + args_sql + ')'
def arg_to_sql(self, arg):
from .fields import Field, StringField, DateTimeField, DateField
if isinstance(arg, F):
return arg.to_sql()
if isinstance(arg, Field):
return "`%s`" % arg.name
if isinstance(arg, six.string_types):
return StringField().to_db_string(arg)
if isinstance(arg, datetime):
return DateTimeField().to_db_string(arg)
if isinstance(arg, date):
return DateField().to_db_string(arg)
if isinstance(arg, bool):
return six.text_type(int(arg))
if arg is None:
return 'NULL'
if is_iterable(arg):
return '[' + comma_join(self.arg_to_sql(x) for x in arg) + ']'
return six.text_type(arg)
# Support comparison operators with F objects
def __lt__(self, other):
return F.less(self, other)
def __le__(self, other):
return F.lessOrEquals(self, other)
def __eq__(self, other):
return F.equals(self, other)
def __ne__(self, other):
return F.notEquals(self, other)
def __gt__(self, other):
return F.greater(self, other)
def __ge__(self, other):
return F.greaterOrEquals(self, other)
# Support arithmetic operations on F objects
def __add__(self, other):
return F.plus(self, other)
def __radd__(self, other):
return F.plus(other, self)
def __sub__(self, other):
return F.minus(self, other)
def __rsub__(self, other):
return F.minus(other, self)
def __mul__(self, other):
return F.multiply(self, other)
def __rmul__(self, other):
return F.multiply(other, self)
def __div__(self, other):
return F.divide(self, other)
def __rdiv__(self, other):
return F.divide(other, self)
def __mod__(self, other):
return F.modulo(self, other)
def __rmod__(self, other):
return F.modulo(other, self)
def __neg__(self):
return F.negate(self)
def __pos__(self):
return self
# Arithmetic functions
@staticmethod
def plus(a, b):
return F('plus', a, b)
@staticmethod
def minus(a, b):
return F('minus', a, b)
@staticmethod
def multiply(a, b):
return F('multiply', a, b)
@staticmethod
def divide(a, b):
return F('divide', a, b)
@staticmethod
def intDiv(a, b):
return F('intDiv', a, b)
@staticmethod
def intDivOrZero(a, b):
return F('intDivOrZero', a, b)
@staticmethod
def modulo(a, b):
return F('modulo', a, b)
@staticmethod
def negate(a):
return F('negate', a)
@staticmethod
def abs(a):
return F('abs', a)
@staticmethod
def gcd(a, b):
return F('gcd',a, b)
@staticmethod
def lcm(a, b):
return F('lcm', a, b)
# Comparison functions
@staticmethod
def equals(a, b):
return F('equals', a, b)
@staticmethod
def notEquals(a, b):
return F('notEquals', a, b)
@staticmethod
def less(a, b):
return F('less', a, b)
@staticmethod
def greater(a, b):
return F('greater', a, b)
@staticmethod
def lessOrEquals(a, b):
return F('lessOrEquals', a, b)
@staticmethod
def greaterOrEquals(a, b):
return F('greaterOrEquals', a, b)
# Functions for working with dates and times
@staticmethod
def toYear(d):
return F('toYear', d)
@staticmethod
def toMonth(d):
return F('toMonth', d)
@staticmethod
def toDayOfMonth(d):
return F('toDayOfMonth', d)
@staticmethod
def toDayOfWeek(d):
return F('toDayOfWeek', d)
@staticmethod
def toHour(d):
return F('toHour', d)
@staticmethod
def toMinute(d):
return F('toMinute', d)
@staticmethod
def toSecond(d):
return F('toSecond', d)
@staticmethod
def toMonday(d):
return F('toMonday', d)
@staticmethod
def toStartOfMonth(d):
return F('toStartOfMonth', d)
@staticmethod
def toStartOfQuarter(d):
return F('toStartOfQuarter', d)
@staticmethod
def toStartOfYear(d):
return F('toStartOfYear', d)
@staticmethod
def toStartOfMinute(d):
return F('toStartOfMinute', d)
@staticmethod
def toStartOfFiveMinute(d):
return F('toStartOfFiveMinute', d)
@staticmethod
def toStartOfFifteenMinutes(d):
return F('toStartOfFifteenMinutes', d)
@staticmethod
def toStartOfHour(d):
return F('toStartOfHour', d)
@staticmethod
def toStartOfDay(d):
return F('toStartOfDay', d)
@staticmethod
def toTime(d):
return F('toTime', d)
@staticmethod
def toRelativeYearNum(d, timezone=''):
return F('toRelativeYearNum', d, timezone)
@staticmethod
def toRelativeMonthNum(d, timezone=''):
return F('toRelativeMonthNum', d, timezone)
@staticmethod
def toRelativeWeekNum(d, timezone=''):
return F('toRelativeWeekNum', d, timezone)
@staticmethod
def toRelativeDayNum(d, timezone=''):
return F('toRelativeDayNum', d, timezone)
@staticmethod
def toRelativeHourNum(d, timezone=''):
return F('toRelativeHourNum', d, timezone)
@staticmethod
def toRelativeMinuteNum(d, timezone=''):
return F('toRelativeMinuteNum', d, timezone)
@staticmethod
def toRelativeSecondNum(d, timezone=''):
return F('toRelativeSecondNum', d, timezone)
@staticmethod
def now():
return F('now')
@staticmethod
def today():
return F('today')
@staticmethod
def yesterday(d):
return F('yesterday')
@staticmethod
def timeSlot(d):
return F('timeSlot', d)
@staticmethod
def timeSlots(start_time, duration):
return F('timeSlots', start_time, duration)
@staticmethod
def formatDateTime(d, format, timezone=''):
return F('formatDateTime', d, format, timezone)
class Q(object):
AND_MODE = 'AND'
OR_MODE = 'OR'
AND_MODE = ' AND '
OR_MODE = ' OR '
def __init__(self, **filter_fields):
self._fovs = [self._build_fov(k, v) for k, v in six.iteritems(filter_fields)]
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in six.iteritems(filter_fields)]
self._l_child = None
self._r_child = None
self._negate = False
@ -194,16 +492,16 @@ class Q(object):
q._mode = mode # AND/OR
return q
def _build_fov(self, key, value):
def _build_cond(self, key, value):
if '__' in key:
field_name, operator = key.rsplit('__', 1)
else:
field_name, operator = key, 'eq'
return FOV(field_name, operator, value)
return FieldCond(field_name, operator, value)
def to_sql(self, model_cls):
if self._fovs:
sql = ' {} '.format(self._mode).join(fov.to_sql(model_cls) for fov in self._fovs)
if self._conds:
sql = self._mode.join(cond.to_sql(model_cls) for cond in self._conds)
else:
if self._l_child and self._r_child:
sql = '({} {} {})'.format(
@ -353,10 +651,16 @@ class QuerySet(object):
Add q object to query if it specified.
"""
qs = copy(self)
if q:
qs._q = list(self._q) + list(q)
else:
qs._q = list(self._q) + [Q(**filter_fields)]
qs._q = list(self._q)
for arg in q:
if isinstance(arg, Q):
qs._q.append(arg)
elif isinstance(arg, F):
qs._q.append(Q(arg))
else:
raise TypeError('Invalid argument "%r" to queryset filter' % arg)
if filter_fields:
qs._q += [Q(**filter_fields)]
return qs
def exclude(self, **filter_fields):
@ -519,3 +823,5 @@ class AggregateQuerySet(QuerySet):
sql = u'SELECT count() FROM (%s)' % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0

View File

@ -98,3 +98,14 @@ def comma_join(items):
Joins an iterable of strings with commas.
"""
return ', '.join(items)
def is_iterable(obj):
"""
Checks if the given object is iterable.
"""
try:
iter(obj)
return True
except TypeError:
return False

View File

@ -3,11 +3,13 @@ from __future__ import unicode_literals, print_function
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.query import Q
from infi.clickhouse_orm.query import Q, F
from .base_test_with_data import *
import logging
from datetime import date, datetime
from logging import getLogger
logger = getLogger('tests')
try:
Enum # exists in Python 3.4+
except NameError:
@ -21,11 +23,11 @@ class QuerySetTestCase(TestCaseWithData):
self.database.insert(self._sample_data())
def _test_qs(self, qs, expected_count):
logging.info(qs.as_sql())
logger.info(qs.as_sql())
count = 0
for instance in qs:
count += 1
logging.info('\t[%d]\t%s' % (count, instance.to_dict()))
logger.info('\t[%d]\t%s' % (count, instance.to_dict()))
self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count)
@ -290,6 +292,17 @@ class QuerySetTestCase(TestCaseWithData):
for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)):
self.assertEqual(exp_color, item.color)
def test_mixed_filter(self):
qs = Person.objects_in(self.database)
qs = qs.filter(Q(first_name='a'), F('greater', Person.height, 1.7), last_name='b')
self.assertEqual(qs.conditions_as_sql(),
"first_name = 'a' AND greater(`height`, 1.7) AND last_name = 'b'")
def test_invalid_filter(self):
qs = Person.objects_in(self.database)
with self.assertRaises(TypeError):
qs.filter('foo')
class AggregateTestCase(TestCaseWithData):
@ -419,6 +432,136 @@ class AggregateTestCase(TestCaseWithData):
self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1')
class FuncsTestCase(TestCaseWithData):
def setUp(self):
super(FuncsTestCase, self).setUp()
self.database.insert(self._sample_data())
def _test_qs(self, qs, expected_count):
logger.info(qs.as_sql())
count = 0
for instance in qs:
count += 1
logger.info('\t[%d]\t%s' % (count, instance.to_dict()))
self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=None):
sql = 'SELECT %s AS value' % func.to_sql()
logger.info(sql)
result = list(self.database.select(sql))
logger.info('\t==> %s', result[0].value)
if expected_value is not None:
self.assertEqual(result[0].value, expected_value)
def test_func_to_sql(self):
# No args
self.assertEqual(F('func').to_sql(), 'func()')
# String args
self.assertEqual(F('func', "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
# Numeric args
self.assertEqual(F('func', 1, 1.1, Decimal('3.3')).to_sql(), "func(1, 1.1, 3.3)")
# Date args
self.assertEqual(F('func', date(2018, 12, 31)).to_sql(), "func('2018-12-31')")
# Datetime args
self.assertEqual(F('func', datetime(2018, 12, 31)).to_sql(), "func('1546214400')")
# Boolean args
self.assertEqual(F('func', True, False).to_sql(), "func(1, 0)")
# Null args
self.assertEqual(F('func', None).to_sql(), "func(NULL)")
# Fields as args
self.assertEqual(F('func', SampleModel.color).to_sql(), "func(`color`)")
# Funcs as args
self.assertEqual(F('func', F('sqrt', 25)).to_sql(), 'func(sqrt(25))')
# Iterables as args
x = [1, 'z', F('foo', 17)]
for y in [x, tuple(x), iter(x)]:
self.assertEqual(F('func', y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
self.assertEqual(F('func', [(1, 2), (3, 4)]).to_sql(), "func([[1, 2], [3, 4]])")
def test_filter_float_field(self):
qs = Person.objects_in(self.database)
# Height > 2
self._test_qs(qs.filter(F.greater(Person.height, 2)), 0)
self._test_qs(qs.filter(Person.height > 2), 0)
# Height > 1.61
self._test_qs(qs.filter(F.greater(Person.height, 1.61)), 96)
self._test_qs(qs.filter(Person.height > 1.61), 96)
# Height < 1.61
self._test_qs(qs.filter(F.less(Person.height, 1.61)), 4)
self._test_qs(qs.filter(Person.height < 1.61), 4)
def test_filter_date_field(self):
qs = Person.objects_in(self.database)
# People born on the 30th
self._test_qs(qs.filter(F('equals', F('toDayOfMonth', Person.birthday), 30)), 3)
self._test_qs(qs.filter(F('toDayOfMonth', Person.birthday) == 30), 3)
self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3)
# People born on Sunday
self._test_qs(qs.filter(F('equals', F('toDayOfWeek', Person.birthday), 7)), 18)
self._test_qs(qs.filter(F('toDayOfWeek', Person.birthday) == 7), 18)
self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18)
# People born on 1976-10-01
self._test_qs(qs.filter(F('equals', Person.birthday, '1976-10-01')), 1)
self._test_qs(qs.filter(F('equals', Person.birthday, date(1976, 10, 01))), 1)
self._test_qs(qs.filter(Person.birthday == date(1976, 10, 01)), 1)
def test_func_as_field_value(self):
qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96)
self._test_qs(qs.exclude(birthday=F.today()), 100)
self._test_qs(qs.filter(birthday__between=['1970-01-01', F.today()]), 100)
def test_comparison_operators(self):
one = F.plus(1, 0)
two = F.plus(1, 1)
self._test_func(one > one, 0)
self._test_func(two > one, 1)
self._test_func(one >= two, 0)
self._test_func(one >= one, 1)
self._test_func(one < one, 0)
self._test_func(one < two, 1)
self._test_func(two <= one, 0)
self._test_func(one <= one, 1)
self._test_func(one == two, 0)
self._test_func(one == one, 1)
self._test_func(one != one, 0)
self._test_func(one != two, 1)
def test_arithmetic_operators(self):
one = F.plus(1, 0)
two = F.plus(1, 1)
# +
self._test_func(one + two, 3)
self._test_func(one + 2, 3)
self._test_func(2 + one, 3)
# -
self._test_func(one - two, -1)
self._test_func(one - 2, -1)
self._test_func(1 - two, -1)
# *
self._test_func(one * two, 2)
self._test_func(one * 2, 2)
self._test_func(1 * two, 2)
# /
self._test_func(one / two, 0.5)
self._test_func(one / 2, 0.5)
self._test_func(1 / two, 0.5)
# %
self._test_func(one % two, 1)
self._test_func(one % 2, 1)
self._test_func(1 % two, 1)
# sign
self._test_func(-one, -1)
self._test_func(--one, 1)
self._test_func(+one, 1)
Color = Enum('Color', u'red blue green yellow brown white black')