diff --git a/src/infi/clickhouse_orm/fields.py b/src/infi/clickhouse_orm/fields.py index 8c11805..83b738e 100644 --- a/src/infi/clickhouse_orm/fields.py +++ b/src/infi/clickhouse_orm/fields.py @@ -8,15 +8,18 @@ from calendar import timegm from decimal import Decimal, localcontext from .utils import escape, parse_array, comma_join +from .query import F class Field(object): ''' Abstract base class for all field types. ''' - creation_counter = 0 - class_default = 0 - db_type = None + name = None # this is set by the parent model + parent = None # this is set by the parent model + creation_counter = 0 # used for keeping the model fields ordered + class_default = 0 # should be overridden by concrete subclasses + db_type = None # should be overridden by concrete subclasses def __init__(self, default=None, alias=None, materialized=None, readonly=None): assert (None, None) in {(default, alias), (alias, materialized), (default, materialized)}, \ @@ -96,6 +99,26 @@ class Field(object): inner_field = getattr(inner_field, 'inner_field', None) return False + # Support comparison operators (for use in querysets) + + def __lt__(self, other): + return F.less(self, other) + + def __le__(self, other): + return F.lessOrEquals(self, other) + + def __eq__(self, other): + return F.equals(self, other) + + def __ne__(self, other): + return F.notEquals(self, other) + + def __gt__(self, other): + return F.greater(self, other) + + def __ge__(self, other): + return F.greaterOrEquals(self, other) + class StringField(Field): diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index d008513..8a69949 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -43,7 +43,14 @@ class ModelBase(type): _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]), _defaults=defaults ) - return super(ModelBase, cls).__new__(cls, str(name), bases, attrs) + model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs) + + # Let each field know its parent and its own name + for n, f in fields: + setattr(f, 'parent', model) + setattr(f, 'name', n) + + return model @classmethod def create_ad_hoc_model(cls, fields, model_name='AdHocModel'): diff --git a/src/infi/clickhouse_orm/query.py b/src/infi/clickhouse_orm/query.py index 47bb3bf..73a45d1 100644 --- a/src/infi/clickhouse_orm/query.py +++ b/src/infi/clickhouse_orm/query.py @@ -4,9 +4,9 @@ import six import pytz from copy import copy from math import ceil - from .engines import CollapsingMergeTree -from .utils import comma_join +from datetime import date, datetime +from .utils import comma_join, is_iterable # TODO @@ -25,6 +25,11 @@ class Operator(object): """ raise NotImplementedError # pragma: no cover + def _value_to_sql(self, field, value, quote=True): + if isinstance(value, F): + return value.to_sql() + return field.to_db_string(field.to_python(value, pytz.utc), quote) + class SimpleOperator(Operator): """ @@ -37,7 +42,7 @@ class SimpleOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) - value = field.to_db_string(field.to_python(value, pytz.utc)) + value = self._value_to_sql(field, value) if value == '\\N' and self._sql_for_null is not None: return ' '.join([field_name, self._sql_for_null]) return ' '.join([field_name, self._sql_operator, value]) @@ -59,7 +64,7 @@ class InOperator(Operator): elif isinstance(value, six.string_types): pass else: - value = comma_join([field.to_db_string(field.to_python(v, pytz.utc)) for v in value]) + value = comma_join([self._value_to_sql(field, v) for v in value]) return '%s IN (%s)' % (field_name, value) @@ -75,7 +80,7 @@ class LikeOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) - value = field.to_db_string(field.to_python(value, pytz.utc), quote=False) + value = self._value_to_sql(field, value, quote=False) value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') pattern = self._pattern.format(value) if self._case_sensitive: @@ -91,7 +96,7 @@ class IExactOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) - value = field.to_db_string(field.to_python(value, pytz.utc)) + value = self._value_to_sql(field, value) return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) @@ -120,10 +125,8 @@ class BetweenOperator(Operator): def to_sql(self, model_cls, field_name, value): field = getattr(model_cls, field_name) - value0 = field.to_db_string( - field.to_python(value[0], pytz.utc)) if value[0] is not None or len(str(value[0])) > 0 else None - value1 = field.to_db_string( - field.to_python(value[1], pytz.utc)) if value[1] is not None or len(str(value[1])) > 0 else None + value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None + value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None if value0 and value1: return '%s BETWEEN %s AND %s' % (field_name, value0, value1) if value0 and not value1: @@ -156,11 +159,19 @@ register_operator('iendswith', LikeOperator('%{}', False)) register_operator('iexact', IExactOperator()) -class FOV(object): +class Cond(object): """ - An object for storing Field + Operator + Value. + An abstract object for storing a single query condition Field + Operator + Value. """ + def to_sql(self, model_cls): + raise NotImplementedError + + +class FieldCond(Cond): + """ + A single query condition made up of Field + Operator + Value. + """ def __init__(self, field_name, operator, value): self._field_name = field_name self._operator = _operators.get(operator) @@ -174,13 +185,300 @@ class FOV(object): return self._operator.to_sql(model_cls, self._field_name, self._value) +class F(Cond): + """ + Represents a database function call and its arguments. + It doubles as a query condition when the function returns a boolean result. + """ + + def __init__(self, name, *args): + self.name = name + self.args = args + + def to_sql(self, *args): + args_sql = comma_join(self.arg_to_sql(arg) for arg in self.args) + return self.name + '(' + args_sql + ')' + + def arg_to_sql(self, arg): + from .fields import Field, StringField, DateTimeField, DateField + if isinstance(arg, F): + return arg.to_sql() + if isinstance(arg, Field): + return "`%s`" % arg.name + if isinstance(arg, six.string_types): + return StringField().to_db_string(arg) + if isinstance(arg, datetime): + return DateTimeField().to_db_string(arg) + if isinstance(arg, date): + return DateField().to_db_string(arg) + if isinstance(arg, bool): + return six.text_type(int(arg)) + if arg is None: + return 'NULL' + if is_iterable(arg): + return '[' + comma_join(self.arg_to_sql(x) for x in arg) + ']' + return six.text_type(arg) + + # Support comparison operators with F objects + + def __lt__(self, other): + return F.less(self, other) + + def __le__(self, other): + return F.lessOrEquals(self, other) + + def __eq__(self, other): + return F.equals(self, other) + + def __ne__(self, other): + return F.notEquals(self, other) + + def __gt__(self, other): + return F.greater(self, other) + + def __ge__(self, other): + return F.greaterOrEquals(self, other) + + # Support arithmetic operations on F objects + + def __add__(self, other): + return F.plus(self, other) + + def __radd__(self, other): + return F.plus(other, self) + + def __sub__(self, other): + return F.minus(self, other) + + def __rsub__(self, other): + return F.minus(other, self) + + def __mul__(self, other): + return F.multiply(self, other) + + def __rmul__(self, other): + return F.multiply(other, self) + + def __div__(self, other): + return F.divide(self, other) + + def __rdiv__(self, other): + return F.divide(other, self) + + def __mod__(self, other): + return F.modulo(self, other) + + def __rmod__(self, other): + return F.modulo(other, self) + + def __neg__(self): + return F.negate(self) + + def __pos__(self): + return self + + # Arithmetic functions + + @staticmethod + def plus(a, b): + return F('plus', a, b) + + @staticmethod + def minus(a, b): + return F('minus', a, b) + + @staticmethod + def multiply(a, b): + return F('multiply', a, b) + + @staticmethod + def divide(a, b): + return F('divide', a, b) + + @staticmethod + def intDiv(a, b): + return F('intDiv', a, b) + + @staticmethod + def intDivOrZero(a, b): + return F('intDivOrZero', a, b) + + @staticmethod + def modulo(a, b): + return F('modulo', a, b) + + @staticmethod + def negate(a): + return F('negate', a) + + @staticmethod + def abs(a): + return F('abs', a) + + @staticmethod + def gcd(a, b): + return F('gcd',a, b) + + @staticmethod + def lcm(a, b): + return F('lcm', a, b) + + # Comparison functions + + @staticmethod + def equals(a, b): + return F('equals', a, b) + + @staticmethod + def notEquals(a, b): + return F('notEquals', a, b) + + @staticmethod + def less(a, b): + return F('less', a, b) + + @staticmethod + def greater(a, b): + return F('greater', a, b) + + @staticmethod + def lessOrEquals(a, b): + return F('lessOrEquals', a, b) + + @staticmethod + def greaterOrEquals(a, b): + return F('greaterOrEquals', a, b) + + # Functions for working with dates and times + + @staticmethod + def toYear(d): + return F('toYear', d) + + @staticmethod + def toMonth(d): + return F('toMonth', d) + + @staticmethod + def toDayOfMonth(d): + return F('toDayOfMonth', d) + + @staticmethod + def toDayOfWeek(d): + return F('toDayOfWeek', d) + + @staticmethod + def toHour(d): + return F('toHour', d) + + @staticmethod + def toMinute(d): + return F('toMinute', d) + + @staticmethod + def toSecond(d): + return F('toSecond', d) + + @staticmethod + def toMonday(d): + return F('toMonday', d) + + @staticmethod + def toStartOfMonth(d): + return F('toStartOfMonth', d) + + @staticmethod + def toStartOfQuarter(d): + return F('toStartOfQuarter', d) + + @staticmethod + def toStartOfYear(d): + return F('toStartOfYear', d) + + @staticmethod + def toStartOfMinute(d): + return F('toStartOfMinute', d) + + @staticmethod + def toStartOfFiveMinute(d): + return F('toStartOfFiveMinute', d) + + @staticmethod + def toStartOfFifteenMinutes(d): + return F('toStartOfFifteenMinutes', d) + + @staticmethod + def toStartOfHour(d): + return F('toStartOfHour', d) + + @staticmethod + def toStartOfDay(d): + return F('toStartOfDay', d) + + @staticmethod + def toTime(d): + return F('toTime', d) + + @staticmethod + def toRelativeYearNum(d, timezone=''): + return F('toRelativeYearNum', d, timezone) + + @staticmethod + def toRelativeMonthNum(d, timezone=''): + return F('toRelativeMonthNum', d, timezone) + + @staticmethod + def toRelativeWeekNum(d, timezone=''): + return F('toRelativeWeekNum', d, timezone) + + @staticmethod + def toRelativeDayNum(d, timezone=''): + return F('toRelativeDayNum', d, timezone) + + @staticmethod + def toRelativeHourNum(d, timezone=''): + return F('toRelativeHourNum', d, timezone) + + @staticmethod + def toRelativeMinuteNum(d, timezone=''): + return F('toRelativeMinuteNum', d, timezone) + + @staticmethod + def toRelativeSecondNum(d, timezone=''): + return F('toRelativeSecondNum', d, timezone) + + @staticmethod + def now(): + return F('now') + + @staticmethod + def today(): + return F('today') + + @staticmethod + def yesterday(d): + return F('yesterday') + + @staticmethod + def timeSlot(d): + return F('timeSlot', d) + + @staticmethod + def timeSlots(start_time, duration): + return F('timeSlots', start_time, duration) + + @staticmethod + def formatDateTime(d, format, timezone=''): + return F('formatDateTime', d, format, timezone) + + class Q(object): - AND_MODE = 'AND' - OR_MODE = 'OR' + AND_MODE = ' AND ' + OR_MODE = ' OR ' - def __init__(self, **filter_fields): - self._fovs = [self._build_fov(k, v) for k, v in six.iteritems(filter_fields)] + def __init__(self, *filter_funcs, **filter_fields): + self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in six.iteritems(filter_fields)] self._l_child = None self._r_child = None self._negate = False @@ -194,16 +492,16 @@ class Q(object): q._mode = mode # AND/OR return q - def _build_fov(self, key, value): + def _build_cond(self, key, value): if '__' in key: field_name, operator = key.rsplit('__', 1) else: field_name, operator = key, 'eq' - return FOV(field_name, operator, value) + return FieldCond(field_name, operator, value) def to_sql(self, model_cls): - if self._fovs: - sql = ' {} '.format(self._mode).join(fov.to_sql(model_cls) for fov in self._fovs) + if self._conds: + sql = self._mode.join(cond.to_sql(model_cls) for cond in self._conds) else: if self._l_child and self._r_child: sql = '({} {} {})'.format( @@ -353,10 +651,16 @@ class QuerySet(object): Add q object to query if it specified. """ qs = copy(self) - if q: - qs._q = list(self._q) + list(q) - else: - qs._q = list(self._q) + [Q(**filter_fields)] + qs._q = list(self._q) + for arg in q: + if isinstance(arg, Q): + qs._q.append(arg) + elif isinstance(arg, F): + qs._q.append(Q(arg)) + else: + raise TypeError('Invalid argument "%r" to queryset filter' % arg) + if filter_fields: + qs._q += [Q(**filter_fields)] return qs def exclude(self, **filter_fields): @@ -519,3 +823,5 @@ class AggregateQuerySet(QuerySet): sql = u'SELECT count() FROM (%s)' % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 + + diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index e3eb4bb..8d4f7ee 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -98,3 +98,14 @@ def comma_join(items): Joins an iterable of strings with commas. """ return ', '.join(items) + + +def is_iterable(obj): + """ + Checks if the given object is iterable. + """ + try: + iter(obj) + return True + except TypeError: + return False diff --git a/tests/base_test_with_data.py b/tests/base_test_with_data.py index c3fc376..eaacd7d 100644 --- a/tests/base_test_with_data.py +++ b/tests/base_test_with_data.py @@ -46,7 +46,7 @@ class Person(Model): data = [ {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", "passport": 35052255}, - + {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", "passport": 36052255}, diff --git a/tests/test_querysets.py b/tests/test_querysets.py index 4c76b68..4858e03 100644 --- a/tests/test_querysets.py +++ b/tests/test_querysets.py @@ -3,11 +3,13 @@ from __future__ import unicode_literals, print_function import unittest from infi.clickhouse_orm.database import Database -from infi.clickhouse_orm.query import Q +from infi.clickhouse_orm.query import Q, F from .base_test_with_data import * -import logging from datetime import date, datetime +from logging import getLogger +logger = getLogger('tests') + try: Enum # exists in Python 3.4+ except NameError: @@ -21,11 +23,11 @@ class QuerySetTestCase(TestCaseWithData): self.database.insert(self._sample_data()) def _test_qs(self, qs, expected_count): - logging.info(qs.as_sql()) + logger.info(qs.as_sql()) count = 0 for instance in qs: count += 1 - logging.info('\t[%d]\t%s' % (count, instance.to_dict())) + logger.info('\t[%d]\t%s' % (count, instance.to_dict())) self.assertEqual(count, expected_count) self.assertEqual(qs.count(), expected_count) @@ -290,6 +292,17 @@ class QuerySetTestCase(TestCaseWithData): for item, exp_color in zip(res, (Color.red, Color.green, Color.white, Color.blue)): self.assertEqual(exp_color, item.color) + def test_mixed_filter(self): + qs = Person.objects_in(self.database) + qs = qs.filter(Q(first_name='a'), F('greater', Person.height, 1.7), last_name='b') + self.assertEqual(qs.conditions_as_sql(), + "first_name = 'a' AND greater(`height`, 1.7) AND last_name = 'b'") + + def test_invalid_filter(self): + qs = Person.objects_in(self.database) + with self.assertRaises(TypeError): + qs.filter('foo') + class AggregateTestCase(TestCaseWithData): @@ -419,6 +432,136 @@ class AggregateTestCase(TestCaseWithData): self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1') +class FuncsTestCase(TestCaseWithData): + + def setUp(self): + super(FuncsTestCase, self).setUp() + self.database.insert(self._sample_data()) + + def _test_qs(self, qs, expected_count): + logger.info(qs.as_sql()) + count = 0 + for instance in qs: + count += 1 + logger.info('\t[%d]\t%s' % (count, instance.to_dict())) + self.assertEqual(count, expected_count) + self.assertEqual(qs.count(), expected_count) + + def _test_func(self, func, expected_value=None): + sql = 'SELECT %s AS value' % func.to_sql() + logger.info(sql) + result = list(self.database.select(sql)) + logger.info('\t==> %s', result[0].value) + if expected_value is not None: + self.assertEqual(result[0].value, expected_value) + + def test_func_to_sql(self): + # No args + self.assertEqual(F('func').to_sql(), 'func()') + # String args + self.assertEqual(F('func', "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')") + # Numeric args + self.assertEqual(F('func', 1, 1.1, Decimal('3.3')).to_sql(), "func(1, 1.1, 3.3)") + # Date args + self.assertEqual(F('func', date(2018, 12, 31)).to_sql(), "func('2018-12-31')") + # Datetime args + self.assertEqual(F('func', datetime(2018, 12, 31)).to_sql(), "func('1546214400')") + # Boolean args + self.assertEqual(F('func', True, False).to_sql(), "func(1, 0)") + # Null args + self.assertEqual(F('func', None).to_sql(), "func(NULL)") + # Fields as args + self.assertEqual(F('func', SampleModel.color).to_sql(), "func(`color`)") + # Funcs as args + self.assertEqual(F('func', F('sqrt', 25)).to_sql(), 'func(sqrt(25))') + # Iterables as args + x = [1, 'z', F('foo', 17)] + for y in [x, tuple(x), iter(x)]: + self.assertEqual(F('func', y, 5).to_sql(), "func([1, 'z', foo(17)], 5)") + self.assertEqual(F('func', [(1, 2), (3, 4)]).to_sql(), "func([[1, 2], [3, 4]])") + + def test_filter_float_field(self): + qs = Person.objects_in(self.database) + # Height > 2 + self._test_qs(qs.filter(F.greater(Person.height, 2)), 0) + self._test_qs(qs.filter(Person.height > 2), 0) + # Height > 1.61 + self._test_qs(qs.filter(F.greater(Person.height, 1.61)), 96) + self._test_qs(qs.filter(Person.height > 1.61), 96) + # Height < 1.61 + self._test_qs(qs.filter(F.less(Person.height, 1.61)), 4) + self._test_qs(qs.filter(Person.height < 1.61), 4) + + def test_filter_date_field(self): + qs = Person.objects_in(self.database) + # People born on the 30th + self._test_qs(qs.filter(F('equals', F('toDayOfMonth', Person.birthday), 30)), 3) + self._test_qs(qs.filter(F('toDayOfMonth', Person.birthday) == 30), 3) + self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3) + # People born on Sunday + self._test_qs(qs.filter(F('equals', F('toDayOfWeek', Person.birthday), 7)), 18) + self._test_qs(qs.filter(F('toDayOfWeek', Person.birthday) == 7), 18) + self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18) + # People born on 1976-10-01 + self._test_qs(qs.filter(F('equals', Person.birthday, '1976-10-01')), 1) + self._test_qs(qs.filter(F('equals', Person.birthday, date(1976, 10, 01))), 1) + self._test_qs(qs.filter(Person.birthday == date(1976, 10, 01)), 1) + + def test_func_as_field_value(self): + qs = Person.objects_in(self.database) + self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96) + self._test_qs(qs.exclude(birthday=F.today()), 100) + self._test_qs(qs.filter(birthday__between=['1970-01-01', F.today()]), 100) + + def test_comparison_operators(self): + one = F.plus(1, 0) + two = F.plus(1, 1) + self._test_func(one > one, 0) + self._test_func(two > one, 1) + self._test_func(one >= two, 0) + self._test_func(one >= one, 1) + self._test_func(one < one, 0) + self._test_func(one < two, 1) + self._test_func(two <= one, 0) + self._test_func(one <= one, 1) + self._test_func(one == two, 0) + self._test_func(one == one, 1) + self._test_func(one != one, 0) + self._test_func(one != two, 1) + + def test_arithmetic_operators(self): + one = F.plus(1, 0) + two = F.plus(1, 1) + # + + self._test_func(one + two, 3) + self._test_func(one + 2, 3) + self._test_func(2 + one, 3) + # - + self._test_func(one - two, -1) + self._test_func(one - 2, -1) + self._test_func(1 - two, -1) + # * + self._test_func(one * two, 2) + self._test_func(one * 2, 2) + self._test_func(1 * two, 2) + # / + self._test_func(one / two, 0.5) + self._test_func(one / 2, 0.5) + self._test_func(1 / two, 0.5) + # % + self._test_func(one % two, 1) + self._test_func(one % 2, 1) + self._test_func(1 % two, 1) + # sign + self._test_func(-one, -1) + self._test_func(--one, 1) + self._test_func(+one, 1) + + + + + + Color = Enum('Color', u'red blue green yellow brown white black')