Functions WIP

This commit is contained in:
Itai Shirav 2018-10-16 16:42:30 +03:00
parent 9df82a44ec
commit 602d0671f1
6 changed files with 523 additions and 33 deletions

View File

@ -8,15 +8,18 @@ from calendar import timegm
from decimal import Decimal, localcontext from decimal import Decimal, localcontext
from .utils import escape, parse_array, comma_join from .utils import escape, parse_array, comma_join
from .query import F
class Field(object): class Field(object):
''' '''
Abstract base class for all field types. Abstract base class for all field types.
''' '''
creation_counter = 0 name = None # this is set by the parent model
class_default = 0 parent = None # this is set by the parent model
db_type = None creation_counter = 0 # used for keeping the model fields ordered
class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None): def __init__(self, default=None, alias=None, materialized=None, readonly=None):
assert (None, None) in {(default, alias), (alias, materialized), (default, materialized)}, \ assert (None, None) in {(default, alias), (alias, materialized), (default, materialized)}, \
@ -96,6 +99,26 @@ class Field(object):
inner_field = getattr(inner_field, 'inner_field', None) inner_field = getattr(inner_field, 'inner_field', None)
return False return False
# Support comparison operators (for use in querysets)
def __lt__(self, other):
return F.less(self, other)
def __le__(self, other):
return F.lessOrEquals(self, other)
def __eq__(self, other):
return F.equals(self, other)
def __ne__(self, other):
return F.notEquals(self, other)
def __gt__(self, other):
return F.greater(self, other)
def __ge__(self, other):
return F.greaterOrEquals(self, other)
class StringField(Field): class StringField(Field):

View File

@ -43,7 +43,14 @@ class ModelBase(type):
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
_defaults=defaults _defaults=defaults
) )
return super(ModelBase, cls).__new__(cls, str(name), bases, attrs) model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
# Let each field know its parent and its own name
for n, f in fields:
setattr(f, 'parent', model)
setattr(f, 'name', n)
return model
@classmethod @classmethod
def create_ad_hoc_model(cls, fields, model_name='AdHocModel'): def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):

View File

@ -4,9 +4,9 @@ import six
import pytz import pytz
from copy import copy from copy import copy
from math import ceil from math import ceil
from .engines import CollapsingMergeTree from .engines import CollapsingMergeTree
from .utils import comma_join from datetime import date, datetime
from .utils import comma_join, is_iterable
# TODO # TODO
@ -25,6 +25,11 @@ class Operator(object):
""" """
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True):
if isinstance(value, F):
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
class SimpleOperator(Operator): class SimpleOperator(Operator):
""" """
@ -37,7 +42,7 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc)) value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None: if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null]) return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value]) return ' '.join([field_name, self._sql_operator, value])
@ -59,7 +64,7 @@ class InOperator(Operator):
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
pass pass
else: else:
value = comma_join([field.to_db_string(field.to_python(v, pytz.utc)) for v in value]) value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value) return '%s IN (%s)' % (field_name, value)
@ -75,7 +80,7 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc), quote=False) value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
pattern = self._pattern.format(value) pattern = self._pattern.format(value)
if self._case_sensitive: if self._case_sensitive:
@ -91,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = field.to_db_string(field.to_python(value, pytz.utc)) value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value)
@ -120,10 +125,8 @@ class BetweenOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value0 = field.to_db_string( value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
field.to_python(value[0], pytz.utc)) if value[0] is not None or len(str(value[0])) > 0 else None value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
value1 = field.to_db_string(
field.to_python(value[1], pytz.utc)) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1: if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1) return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
if value0 and not value1: if value0 and not value1:
@ -156,11 +159,19 @@ register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator()) register_operator('iexact', IExactOperator())
class FOV(object): class Cond(object):
""" """
An object for storing Field + Operator + Value. An abstract object for storing a single query condition Field + Operator + Value.
""" """
def to_sql(self, model_cls):
raise NotImplementedError
class FieldCond(Cond):
"""
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value): def __init__(self, field_name, operator, value):
self._field_name = field_name self._field_name = field_name
self._operator = _operators.get(operator) self._operator = _operators.get(operator)
@ -174,13 +185,300 @@ class FOV(object):
return self._operator.to_sql(model_cls, self._field_name, self._value) return self._operator.to_sql(model_cls, self._field_name, self._value)
class F(Cond):
"""
Represents a database function call and its arguments.
It doubles as a query condition when the function returns a boolean result.
"""
def __init__(self, name, *args):
self.name = name
self.args = args
def to_sql(self, *args):
args_sql = comma_join(self.arg_to_sql(arg) for arg in self.args)
return self.name + '(' + args_sql + ')'
def arg_to_sql(self, arg):
from .fields import Field, StringField, DateTimeField, DateField
if isinstance(arg, F):
return arg.to_sql()
if isinstance(arg, Field):
return "`%s`" % arg.name
if isinstance(arg, six.string_types):
return StringField().to_db_string(arg)
if isinstance(arg, datetime):
return DateTimeField().to_db_string(arg)
if isinstance(arg, date):
return DateField().to_db_string(arg)
if isinstance(arg, bool):
return six.text_type(int(arg))
if arg is None:
return 'NULL'
if is_iterable(arg):
return '[' + comma_join(self.arg_to_sql(x) for x in arg) + ']'
return six.text_type(arg)
# Support comparison operators with F objects
def __lt__(self, other):
return F.less(self, other)
def __le__(self, other):
return F.lessOrEquals(self, other)
def __eq__(self, other):
return F.equals(self, other)
def __ne__(self, other):
return F.notEquals(self, other)
def __gt__(self, other):
return F.greater(self, other)
def __ge__(self, other):
return F.greaterOrEquals(self, other)
# Support arithmetic operations on F objects
def __add__(self, other):
return F.plus(self, other)
def __radd__(self, other):
return F.plus(other, self)
def __sub__(self, other):
return F.minus(self, other)
def __rsub__(self, other):
return F.minus(other, self)
def __mul__(self, other):
return F.multiply(self, other)
def __rmul__(self, other):
return F.multiply(other, self)
def __div__(self, other):
return F.divide(self, other)
def __rdiv__(self, other):
return F.divide(other, self)
def __mod__(self, other):
return F.modulo(self, other)
def __rmod__(self, other):
return F.modulo(other, self)
def __neg__(self):
return F.negate(self)
def __pos__(self):
return self
# Arithmetic functions
@staticmethod
def plus(a, b):
return F('plus', a, b)
@staticmethod
def minus(a, b):
return F('minus', a, b)
@staticmethod
def multiply(a, b):
return F('multiply', a, b)
@staticmethod
def divide(a, b):
return F('divide', a, b)
@staticmethod
def intDiv(a, b):
return F('intDiv', a, b)
@staticmethod
def intDivOrZero(a, b):
return F('intDivOrZero', a, b)
@staticmethod
def modulo(a, b):
return F('modulo', a, b)
@staticmethod
def negate(a):
return F('negate', a)
@staticmethod
def abs(a):
return F('abs', a)
@staticmethod
def gcd(a, b):
return F('gcd',a, b)
@staticmethod
def lcm(a, b):
return F('lcm', a, b)
# Comparison functions
@staticmethod
def equals(a, b):
return F('equals', a, b)
@staticmethod
def notEquals(a, b):
return F('notEquals', a, b)
@staticmethod
def less(a, b):
return F('less', a, b)
@staticmethod
def greater(a, b):
return F('greater', a, b)
@staticmethod
def lessOrEquals(a, b):
return F('lessOrEquals', a, b)
@staticmethod
def greaterOrEquals(a, b):
return F('greaterOrEquals', a, b)
# Functions for working with dates and times
@staticmethod
def toYear(d):
return F('toYear', d)
@staticmethod
def toMonth(d):
return F('toMonth', d)
@staticmethod
def toDayOfMonth(d):
return F('toDayOfMonth', d)
@staticmethod
def toDayOfWeek(d):
return F('toDayOfWeek', d)
@staticmethod
def toHour(d):
return F('toHour', d)
@staticmethod
def toMinute(d):
return F('toMinute', d)
@staticmethod
def toSecond(d):
return F('toSecond', d)
@staticmethod
def toMonday(d):
return F('toMonday', d)
@staticmethod
def toStartOfMonth(d):
return F('toStartOfMonth', d)
@staticmethod
def toStartOfQuarter(d):
return F('toStartOfQuarter', d)
@staticmethod
def toStartOfYear(d):
return F('toStartOfYear', d)
@staticmethod
def toStartOfMinute(d):
return F('toStartOfMinute', d)
@staticmethod
def toStartOfFiveMinute(d):
return F('toStartOfFiveMinute', d)
@staticmethod
def toStartOfFifteenMinutes(d):
return F('toStartOfFifteenMinutes', d)
@staticmethod
def toStartOfHour(d):
return F('toStartOfHour', d)
@staticmethod
def toStartOfDay(d):
return F('toStartOfDay', d)
@staticmethod
def toTime(d):
return F('toTime', d)
@staticmethod
def toRelativeYearNum(d, timezone=''):
return F('toRelativeYearNum', d, timezone)
@staticmethod
def toRelativeMonthNum(d, timezone=''):
return F('toRelativeMonthNum', d, timezone)
@staticmethod
def toRelativeWeekNum(d, timezone=''):
return F('toRelativeWeekNum', d, timezone)
@staticmethod
def toRelativeDayNum(d, timezone=''):
return F('toRelativeDayNum', d, timezone)
@staticmethod
def toRelativeHourNum(d, timezone=''):
return F('toRelativeHourNum', d, timezone)
@staticmethod
def toRelativeMinuteNum(d, timezone=''):
return F('toRelativeMinuteNum', d, timezone)
@staticmethod
def toRelativeSecondNum(d, timezone=''):
return F('toRelativeSecondNum', d, timezone)
@staticmethod
def now():
return F('now')
@staticmethod
def today():
return F('today')
@staticmethod
def yesterday(d):
return F('yesterday')
@staticmethod
def timeSlot(d):
return F('timeSlot', d)
@staticmethod
def timeSlots(start_time, duration):
return F('timeSlots', start_time, duration)
@staticmethod
def formatDateTime(d, format, timezone=''):
return F('formatDateTime', d, format, timezone)
class Q(object): class Q(object):
AND_MODE = 'AND' AND_MODE = ' AND '
OR_MODE = 'OR' OR_MODE = ' OR '
def __init__(self, **filter_fields): def __init__(self, *filter_funcs, **filter_fields):
self._fovs = [self._build_fov(k, v) for k, v in six.iteritems(filter_fields)] self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in six.iteritems(filter_fields)]
self._l_child = None self._l_child = None
self._r_child = None self._r_child = None
self._negate = False self._negate = False
@ -194,16 +492,16 @@ class Q(object):
q._mode = mode # AND/OR q._mode = mode # AND/OR
return q return q
def _build_fov(self, key, value): def _build_cond(self, key, value):
if '__' in key: if '__' in key:
field_name, operator = key.rsplit('__', 1) field_name, operator = key.rsplit('__', 1)
else: else:
field_name, operator = key, 'eq' field_name, operator = key, 'eq'
return FOV(field_name, operator, value) return FieldCond(field_name, operator, value)
def to_sql(self, model_cls): def to_sql(self, model_cls):
if self._fovs: if self._conds:
sql = ' {} '.format(self._mode).join(fov.to_sql(model_cls) for fov in self._fovs) sql = self._mode.join(cond.to_sql(model_cls) for cond in self._conds)
else: else:
if self._l_child and self._r_child: if self._l_child and self._r_child:
sql = '({} {} {})'.format( sql = '({} {} {})'.format(
@ -353,10 +651,16 @@ class QuerySet(object):
Add q object to query if it specified. Add q object to query if it specified.
""" """
qs = copy(self) qs = copy(self)
if q: qs._q = list(self._q)
qs._q = list(self._q) + list(q) for arg in q:
else: if isinstance(arg, Q):
qs._q = list(self._q) + [Q(**filter_fields)] qs._q.append(arg)
elif isinstance(arg, F):
qs._q.append(Q(arg))
else:
raise TypeError('Invalid argument "%r" to queryset filter' % arg)
if filter_fields:
qs._q += [Q(**filter_fields)]
return qs return qs
def exclude(self, **filter_fields): def exclude(self, **filter_fields):
@ -519,3 +823,5 @@ class AggregateQuerySet(QuerySet):
sql = u'SELECT count() FROM (%s)' % self.as_sql() sql = u'SELECT count() FROM (%s)' % self.as_sql()
raw = self._database.raw(sql) raw = self._database.raw(sql)
return int(raw) if raw else 0 return int(raw) if raw else 0

View File

@ -98,3 +98,14 @@ def comma_join(items):
Joins an iterable of strings with commas. Joins an iterable of strings with commas.
""" """
return ', '.join(items) return ', '.join(items)
def is_iterable(obj):
"""
Checks if the given object is iterable.
"""
try:
iter(obj)
return True
except TypeError:
return False

View File

@ -46,7 +46,7 @@ class Person(Model):
data = [ data = [
{"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63",
"passport": 35052255}, "passport": 35052255},
{"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74",
"passport": 36052255}, "passport": 36052255},

View File

@ -3,11 +3,13 @@ from __future__ import unicode_literals, print_function
import unittest import unittest
from infi.clickhouse_orm.database import Database from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.query import Q from infi.clickhouse_orm.query import Q, F
from .base_test_with_data import * from .base_test_with_data import *
import logging
from datetime import date, datetime from datetime import date, datetime
from logging import getLogger
logger = getLogger('tests')
try: try:
Enum # exists in Python 3.4+ Enum # exists in Python 3.4+
except NameError: except NameError:
@ -21,11 +23,11 @@ class QuerySetTestCase(TestCaseWithData):
self.database.insert(self._sample_data()) self.database.insert(self._sample_data())
def _test_qs(self, qs, expected_count): def _test_qs(self, qs, expected_count):
logging.info(qs.as_sql()) logger.info(qs.as_sql())
count = 0 count = 0
for instance in qs: for instance in qs:
count += 1 count += 1
logging.info('\t[%d]\t%s' % (count, instance.to_dict())) logger.info('\t[%d]\t%s' % (count, instance.to_dict()))
self.assertEqual(count, expected_count) self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count) self.assertEqual(qs.count(), expected_count)
@ -290,6 +292,17 @@ class QuerySetTestCase(TestCaseWithData):
for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)): for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)):
self.assertEqual(exp_color, item.color) self.assertEqual(exp_color, item.color)
def test_mixed_filter(self):
qs = Person.objects_in(self.database)
qs = qs.filter(Q(first_name='a'), F('greater', Person.height, 1.7), last_name='b')
self.assertEqual(qs.conditions_as_sql(),
"first_name = 'a' AND greater(`height`, 1.7) AND last_name = 'b'")
def test_invalid_filter(self):
qs = Person.objects_in(self.database)
with self.assertRaises(TypeError):
qs.filter('foo')
class AggregateTestCase(TestCaseWithData): class AggregateTestCase(TestCaseWithData):
@ -419,6 +432,136 @@ class AggregateTestCase(TestCaseWithData):
self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1') self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1')
class FuncsTestCase(TestCaseWithData):
def setUp(self):
super(FuncsTestCase, self).setUp()
self.database.insert(self._sample_data())
def _test_qs(self, qs, expected_count):
logger.info(qs.as_sql())
count = 0
for instance in qs:
count += 1
logger.info('\t[%d]\t%s' % (count, instance.to_dict()))
self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=None):
sql = 'SELECT %s AS value' % func.to_sql()
logger.info(sql)
result = list(self.database.select(sql))
logger.info('\t==> %s', result[0].value)
if expected_value is not None:
self.assertEqual(result[0].value, expected_value)
def test_func_to_sql(self):
# No args
self.assertEqual(F('func').to_sql(), 'func()')
# String args
self.assertEqual(F('func', "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
# Numeric args
self.assertEqual(F('func', 1, 1.1, Decimal('3.3')).to_sql(), "func(1, 1.1, 3.3)")
# Date args
self.assertEqual(F('func', date(2018, 12, 31)).to_sql(), "func('2018-12-31')")
# Datetime args
self.assertEqual(F('func', datetime(2018, 12, 31)).to_sql(), "func('1546214400')")
# Boolean args
self.assertEqual(F('func', True, False).to_sql(), "func(1, 0)")
# Null args
self.assertEqual(F('func', None).to_sql(), "func(NULL)")
# Fields as args
self.assertEqual(F('func', SampleModel.color).to_sql(), "func(`color`)")
# Funcs as args
self.assertEqual(F('func', F('sqrt', 25)).to_sql(), 'func(sqrt(25))')
# Iterables as args
x = [1, 'z', F('foo', 17)]
for y in [x, tuple(x), iter(x)]:
self.assertEqual(F('func', y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
self.assertEqual(F('func', [(1, 2), (3, 4)]).to_sql(), "func([[1, 2], [3, 4]])")
def test_filter_float_field(self):
qs = Person.objects_in(self.database)
# Height > 2
self._test_qs(qs.filter(F.greater(Person.height, 2)), 0)
self._test_qs(qs.filter(Person.height > 2), 0)
# Height > 1.61
self._test_qs(qs.filter(F.greater(Person.height, 1.61)), 96)
self._test_qs(qs.filter(Person.height > 1.61), 96)
# Height < 1.61
self._test_qs(qs.filter(F.less(Person.height, 1.61)), 4)
self._test_qs(qs.filter(Person.height < 1.61), 4)
def test_filter_date_field(self):
qs = Person.objects_in(self.database)
# People born on the 30th
self._test_qs(qs.filter(F('equals', F('toDayOfMonth', Person.birthday), 30)), 3)
self._test_qs(qs.filter(F('toDayOfMonth', Person.birthday) == 30), 3)
self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3)
# People born on Sunday
self._test_qs(qs.filter(F('equals', F('toDayOfWeek', Person.birthday), 7)), 18)
self._test_qs(qs.filter(F('toDayOfWeek', Person.birthday) == 7), 18)
self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18)
# People born on 1976-10-01
self._test_qs(qs.filter(F('equals', Person.birthday, '1976-10-01')), 1)
self._test_qs(qs.filter(F('equals', Person.birthday, date(1976, 10, 01))), 1)
self._test_qs(qs.filter(Person.birthday == date(1976, 10, 01)), 1)
def test_func_as_field_value(self):
qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96)
self._test_qs(qs.exclude(birthday=F.today()), 100)
self._test_qs(qs.filter(birthday__between=['1970-01-01', F.today()]), 100)
def test_comparison_operators(self):
one = F.plus(1, 0)
two = F.plus(1, 1)
self._test_func(one > one, 0)
self._test_func(two > one, 1)
self._test_func(one >= two, 0)
self._test_func(one >= one, 1)
self._test_func(one < one, 0)
self._test_func(one < two, 1)
self._test_func(two <= one, 0)
self._test_func(one <= one, 1)
self._test_func(one == two, 0)
self._test_func(one == one, 1)
self._test_func(one != one, 0)
self._test_func(one != two, 1)
def test_arithmetic_operators(self):
one = F.plus(1, 0)
two = F.plus(1, 1)
# +
self._test_func(one + two, 3)
self._test_func(one + 2, 3)
self._test_func(2 + one, 3)
# -
self._test_func(one - two, -1)
self._test_func(one - 2, -1)
self._test_func(1 - two, -1)
# *
self._test_func(one * two, 2)
self._test_func(one * 2, 2)
self._test_func(1 * two, 2)
# /
self._test_func(one / two, 0.5)
self._test_func(one / 2, 0.5)
self._test_func(1 / two, 0.5)
# %
self._test_func(one % two, 1)
self._test_func(one % 2, 1)
self._test_func(1 % two, 1)
# sign
self._test_func(-one, -1)
self._test_func(--one, 1)
self._test_func(+one, 1)
Color = Enum('Color', u'red blue green yellow brown white black') Color = Enum('Color', u'red blue green yellow brown white black')