From 9f36b17fee11f222e619278d9e9ccae06d5ee8b3 Mon Sep 17 00:00:00 2001 From: Itai Shirav Date: Mon, 10 Feb 2020 10:06:21 +0200 Subject: [PATCH] - move NO_VALUE to utils - dynamic generation of func variants (...OrZero, ...OrNull) --- docs/field_options.md | 4 +- src/infi/clickhouse_orm/funcs.py | 239 +++++++++++++++++++----------- src/infi/clickhouse_orm/models.py | 13 +- src/infi/clickhouse_orm/utils.py | 11 ++ tests/test_funcs.py | 35 ++++- 5 files changed, 196 insertions(+), 106 deletions(-) diff --git a/docs/field_options.md b/docs/field_options.md index db3e58f..3905afd 100644 --- a/docs/field_options.md +++ b/docs/field_options.md @@ -25,7 +25,7 @@ class Event(models.Model): engine = engines.Memory() ... ``` -When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` instead. For example: +When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` instead. For example: ```python >>> event = Event() >>> print(event.to_dict()) @@ -63,7 +63,7 @@ db.select('SELECT created, created_date, username, name FROM $db.event', model_c # created_date and username will contain a default value db.select('SELECT * FROM $db.event', model_class=Event) ``` -When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` since their real values can only be known after insertion to the database. +When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database. ## codec diff --git a/src/infi/clickhouse_orm/funcs.py b/src/infi/clickhouse_orm/funcs.py index 065e46e..9e02b0c 100644 --- a/src/infi/clickhouse_orm/funcs.py +++ b/src/infi/clickhouse_orm/funcs.py @@ -1,7 +1,9 @@ from datetime import date, datetime, tzinfo -import functools +from functools import wraps +from inspect import signature, Parameter +from types import FunctionType -from .utils import is_iterable, comma_join +from .utils import is_iterable, comma_join, NO_VALUE from .query import Cond @@ -9,7 +11,7 @@ def binary_operator(func): """ Decorates a function to mark it as a binary operator. """ - @functools.wraps(func) + @wraps(func) def wrapper(*args, **kwargs): ret = func(*args, **kwargs) ret.is_binary_operator = True @@ -17,6 +19,29 @@ def binary_operator(func): return wrapper +def type_conversion(func): + """ + Decorates a function to mark it as a type conversion function. + """ + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + wrapper.f_type = 'type_conversion' + return wrapper + + +def aggregate(func): + """ + Decorates a function to mark it as an aggregate function. + """ + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + wrapper.f_type = 'aggregate' + return wrapper + + + class FunctionOperatorsMixin(object): """ A mixin for implementing Python operators using F objects. @@ -104,7 +129,57 @@ class FunctionOperatorsMixin(object): return F._not(self) -class F(Cond, FunctionOperatorsMixin): +class FMeta(type): + + FUNCTION_COMBINATORS = { + 'type_conversion': [ + {'suffix': 'OrZero'}, + {'suffix': 'OrNull'}, + ], + 'aggregate': [ + {'suffix': 'OrDefault'}, + {'suffix': 'OrNull'}, + {'suffix': 'If', 'args': ['cond']}, + {'suffix': 'OrDefaultIf', 'args': ['cond']}, + {'suffix': 'OrNullIf', 'args': ['cond']}, + ] + } + + def __init__(cls, name, bases, dct): + for name, obj in dct.items(): + if hasattr(obj, '__func__'): + f_type = getattr(obj.__func__, 'f_type', '') + for combinator in FMeta.FUNCTION_COMBINATORS.get(f_type, []): + new_name = name + combinator['suffix'] + FMeta._add_func(cls, obj.__func__, new_name, combinator.get('args')) + + @staticmethod + def _add_func(cls, base_func, new_name, extra_args): + """ + Adds a new func to the cls, based on the signature of the given base_func but with a new name. + """ + # Get the function's signature + sig = signature(base_func) + new_sig = str(sig)[1 : -1] # omit the parentheses + args = comma_join(sig.parameters) + # Add extra args + if extra_args: + if args: + args = comma_join([args] + extra_args) + new_sig = comma_join([new_sig] + extra_args) + else: + args = comma_join(extra_args) + new_sig = comma_join(extra_args) + # Get default values for args + argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty) + # Build the new function + new_code = compile(f'def {new_name}({new_sig}): return F("{new_name}", {args})', __file__, 'exec') + new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs) + # Attach to class + setattr(cls, new_name, new_func) + + +class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): """ Represents a database function call and its arguments. It doubles as a query condition when the function returns a boolean result. @@ -135,7 +210,7 @@ class F(Cond, FunctionOperatorsMixin): else: prefix = self.name sep = ', ' - arg_strs = (F._arg_to_sql(arg) for arg in self.args) + arg_strs = (F._arg_to_sql(arg) for arg in self.args if arg != NO_VALUE) return prefix + '(' + sep.join(arg_strs) + ')' @staticmethod @@ -392,168 +467,143 @@ class F(Cond, FunctionOperatorsMixin): return F('formatDateTime', d, format, timezone) @staticmethod - def addDays(d, n, timezone=None): - return F('addDays', d, n, timezone) if timezone else F('addDays', d, n) + def addDays(d, n, timezone=NO_VALUE): + return F('addDays', d, n, timezone) @staticmethod - def addHours(d, n, timezone=None): - return F('addHours', d, n, timezone) if timezone else F('addHours', d, n) + def addHours(d, n, timezone=NO_VALUE): + return F('addHours', d, n, timezone) @staticmethod - def addMinutes(d, n, timezone=None): - return F('addMinutes', d, n, timezone) if timezone else F('addMinutes', d, n) + def addMinutes(d, n, timezone=NO_VALUE): + return F('addMinutes', d, n, timezone) @staticmethod - def addMonths(d, n, timezone=None): - return F('addMonths', d, n, timezone) if timezone else F('addMonths', d, n) + def addMonths(d, n, timezone=NO_VALUE): + return F('addMonths', d, n, timezone) @staticmethod - def addQuarters(d, n, timezone=None): - return F('addQuarters', d, n, timezone) if timezone else F('addQuarters', d, n) + def addQuarters(d, n, timezone=NO_VALUE): + return F('addQuarters', d, n, timezone) @staticmethod - def addSeconds(d, n, timezone=None): - return F('addSeconds', d, n, timezone) if timezone else F('addSeconds', d, n) + def addSeconds(d, n, timezone=NO_VALUE): + return F('addSeconds', d, n, timezone) @staticmethod - def addWeeks(d, n, timezone=None): - return F('addWeeks', d, n, timezone) if timezone else F('addWeeks', d, n) + def addWeeks(d, n, timezone=NO_VALUE): + return F('addWeeks', d, n, timezone) @staticmethod - def addYears(d, n, timezone=None): - return F('addYears', d, n, timezone) if timezone else F('addYears', d, n) + def addYears(d, n, timezone=NO_VALUE): + return F('addYears', d, n, timezone) @staticmethod - def subtractDays(d, n, timezone=None): - return F('subtractDays', d, n, timezone) if timezone else F('subtractDays', d, n) + def subtractDays(d, n, timezone=NO_VALUE): + return F('subtractDays', d, n, timezone) @staticmethod - def subtractHours(d, n, timezone=None): - return F('subtractHours', d, n, timezone) if timezone else F('subtractHours', d, n) + def subtractHours(d, n, timezone=NO_VALUE): + return F('subtractHours', d, n, timezone) @staticmethod - def subtractMinutes(d, n, timezone=None): - return F('subtractMinutes', d, n, timezone) if timezone else F('subtractMinutes', d, n) + def subtractMinutes(d, n, timezone=NO_VALUE): + return F('subtractMinutes', d, n, timezone) @staticmethod - def subtractMonths(d, n, timezone=None): - return F('subtractMonths', d, n, timezone) if timezone else F('subtractMonths', d, n) + def subtractMonths(d, n, timezone=NO_VALUE): + return F('subtractMonths', d, n, timezone) @staticmethod - def subtractQuarters(d, n, timezone=None): - return F('subtractQuarters', d, n, timezone) if timezone else F('subtractQuarters', d, n) + def subtractQuarters(d, n, timezone=NO_VALUE): + return F('subtractQuarters', d, n, timezone) @staticmethod - def subtractSeconds(d, n, timezone=None): - return F('subtractSeconds', d, n, timezone) if timezone else F('subtractSeconds', d, n) + def subtractSeconds(d, n, timezone=NO_VALUE): + return F('subtractSeconds', d, n, timezone) @staticmethod - def subtractWeeks(d, n, timezone=None): - return F('subtractWeeks', d, n, timezone) if timezone else F('subtractWeeks', d, n) + def subtractWeeks(d, n, timezone=NO_VALUE): + return F('subtractWeeks', d, n, timezone) @staticmethod - def subtractYears(d, n, timezone=None): - return F('subtractYears', d, n, timezone) if timezone else F('subtractYears', d, n) + def subtractYears(d, n, timezone=NO_VALUE): + return F('subtractYears', d, n, timezone) # Type conversion functions @staticmethod + @type_conversion def toUInt8(x): return F('toUInt8', x) @staticmethod + @type_conversion def toUInt16(x): return F('toUInt16', x) @staticmethod + @type_conversion def toUInt32(x): return F('toUInt32', x) @staticmethod + @type_conversion def toUInt64(x): return F('toUInt64', x) @staticmethod + @type_conversion def toInt8(x): return F('toInt8', x) @staticmethod + @type_conversion def toInt16(x): return F('toInt16', x) @staticmethod + @type_conversion def toInt32(x): return F('toInt32', x) @staticmethod + @type_conversion def toInt64(x): return F('toInt64', x) @staticmethod + @type_conversion def toFloat32(x): return F('toFloat32', x) @staticmethod + @type_conversion def toFloat64(x): return F('toFloat64', x) @staticmethod - def toUInt8OrZero(x): - return F('toUInt8OrZero', x) - - @staticmethod - def toUInt16OrZero(x): - return F('toUInt16OrZero', x) - - @staticmethod - def toUInt32OrZero(x): - return F('toUInt32OrZero', x) - - @staticmethod - def toUInt64OrZero(x): - return F('toUInt64OrZero', x) - - @staticmethod - def toInt8OrZero(x): - return F('toInt8OrZero', x) - - @staticmethod - def toInt16OrZero(x): - return F('toInt16OrZero', x) - - @staticmethod - def toInt32OrZero(x): - return F('toInt32OrZero', x) - - @staticmethod - def toInt64OrZero(x): - return F('toInt64OrZero', x) - - @staticmethod - def toFloat32OrZero(x): - return F('toFloat32OrZero', x) - - @staticmethod - def toFloat64OrZero(x): - return F('toFloat64OrZero', x) - - @staticmethod + @type_conversion def toDecimal32(x, scale): return F('toDecimal32', x, scale) @staticmethod + @type_conversion def toDecimal64(x, scale): return F('toDecimal64', x, scale) @staticmethod + @type_conversion def toDecimal128(x, scale): return F('toDecimal128', x, scale) @staticmethod + @type_conversion def toDate(x): return F('toDate', x) @staticmethod + @type_conversion def toDateTime(x): return F('toDateTime', x) @@ -574,16 +624,9 @@ class F(Cond, FunctionOperatorsMixin): return F('CAST', x, type) @staticmethod - def parseDateTimeBestEffort(d, timezone=None): - return F('parseDateTimeBestEffort', d, timezone) if timezone else F('parseDateTimeBestEffort', d) - - @staticmethod - def parseDateTimeBestEffortOrNull(d, timezone=None): - return F('parseDateTimeBestEffortOrNull', d, timezone) if timezone else F('parseDateTimeBestEffortOrNull', d) - - @staticmethod - def parseDateTimeBestEffortOrZero(d, timezone=None): - return F('parseDateTimeBestEffortOrZero', d, timezone) if timezone else F('parseDateTimeBestEffortOrZero', d) + @type_conversion + def parseDateTimeBestEffort(d, timezone=NO_VALUE): + return F('parseDateTimeBestEffort', d, timezone) # Functions for working with strings @@ -1314,90 +1357,112 @@ class F(Cond, FunctionOperatorsMixin): # Aggregate functions @staticmethod + @aggregate def any(x): return F('any', x) @staticmethod + @aggregate def anyHeavy(x): return F('anyHeavy', x) @staticmethod + @aggregate def anyLast(x): return F('anyLast', x) @staticmethod + @aggregate def argMax(x, y): return F('argMax', x, y) @staticmethod + @aggregate def argMin(x, y): return F('argMin', x, y) @staticmethod + @aggregate def avg(x): return F('avg', x) @staticmethod + @aggregate def corr(x, y): return F('corr', x, y) @staticmethod + @aggregate def count(): return F('count') @staticmethod + @aggregate def covarPop(x, y): return F('covarPop', x, y) @staticmethod + @aggregate def covarSamp(x, y): return F('covarSamp', x, y) @staticmethod + @aggregate def kurtPop(x): return F('kurtPop', x) @staticmethod + @aggregate def kurtSamp(x): return F('kurtSamp', x) @staticmethod + @aggregate def min(x): return F('min', x) @staticmethod + @aggregate def max(x): return F('max', x) @staticmethod + @aggregate def skewPop(x): return F('skewPop', x) @staticmethod + @aggregate def skewSamp(x): return F('skewSamp', x) @staticmethod + @aggregate def sum(x): return F('sum', x) @staticmethod + @aggregate def uniq(*args): return F('uniq', *args) @staticmethod + @aggregate def uniqExact(*args): return F('uniqExact', *args) @staticmethod + @aggregate def uniqHLL12(*args): return F('uniqHLL12', *args) @staticmethod + @aggregate def varPop(x): return F('varPop', x) @staticmethod + @aggregate def varSamp(x): return F('varSamp', x) diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index 939d240..aca3e95 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -7,7 +7,7 @@ from six import reraise import pytz from .fields import Field, StringField -from .utils import parse_tsv +from .utils import parse_tsv, NO_VALUE from .query import QuerySet from .funcs import F from .engines import Merge, Distributed @@ -15,17 +15,6 @@ from .engines import Merge, Distributed logger = getLogger('clickhouse_orm') -class NoValue: - ''' - A sentinel for fields with an expression for a default value, - that were not assigned a value yet. - ''' - def __repr__(self): - return '' - -NO_VALUE = NoValue() - - class ModelBase(type): ''' A metaclass for ORM models. It adds the _fields list to model classes. diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index eb895a4..9e678fb 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -112,3 +112,14 @@ def is_iterable(obj): return True except TypeError: return False + + +class NoValue: + ''' + A sentinel for fields with an expression for a default value, + that were not assigned a value yet. + ''' + def __repr__(self): + return 'NO_VALUE' + +NO_VALUE = NoValue() diff --git a/tests/test_funcs.py b/tests/test_funcs.py index d969836..77481f5 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -4,6 +4,7 @@ from .test_querysets import SampleModel from datetime import date, datetime, tzinfo, timedelta from ipaddress import IPv4Address, IPv6Address from infi.clickhouse_orm.database import ServerError +from infi.clickhouse_orm.utils import NO_VALUE class FuncsTestCase(TestCaseWithData): @@ -21,21 +22,21 @@ class FuncsTestCase(TestCaseWithData): self.assertEqual(count, expected_count) self.assertEqual(qs.count(), expected_count) - def _test_func(self, func, expected_value=None): + def _test_func(self, func, expected_value=NO_VALUE): sql = 'SELECT %s AS value' % func.to_sql() logger.info(sql) result = list(self.database.select(sql)) logger.info('\t==> %s', result[0].value if result else '') - if expected_value is not None: + if expected_value != NO_VALUE: self.assertEqual(result[0].value, expected_value) return result[0].value if result else None - def _test_aggr(self, func, expected_value=None): + def _test_aggr(self, func, expected_value=NO_VALUE): qs = Person.objects_in(self.database).aggregate(value=func) logger.info(qs.as_sql()) result = list(qs) logger.info('\t==> %s', result[0].value if result else '') - if expected_value is not None: + if expected_value != NO_VALUE: self.assertEqual(result[0].value, expected_value) return result[0].value if result else None @@ -316,7 +317,7 @@ class FuncsTestCase(TestCaseWithData): try: self._test_func(F.base64Decode(F.base64Encode('Hello')), 'Hello') self._test_func(F.tryBase64Decode(F.base64Encode('Hello')), 'Hello') - self._test_func(F.tryBase64Decode(':-)'), None) + self._test_func(F.tryBase64Decode(':-)')) except ServerError as e: # ClickHouse version that doesn't support these functions raise unittest.SkipTest(e.message) @@ -548,3 +549,27 @@ class FuncsTestCase(TestCaseWithData): self._test_aggr(F.varPop(Person.height)) self._test_aggr(F.varSamp(Person.height)) + def test_aggregate_funcs__or_default(self): + self.database.raw('TRUNCATE TABLE person') + self._test_aggr(F.countOrDefault(), 0) + self._test_aggr(F.maxOrDefault(Person.height), 0) + + def test_aggregate_funcs__or_null(self): + self.database.raw('TRUNCATE TABLE person') + self._test_aggr(F.countOrNull(), None) + self._test_aggr(F.maxOrNull(Person.height), None) + + def test_aggregate_funcs__if(self): + self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > 'H')) + self._test_aggr(F.countIf(Person.last_name > 'H'), 57) + self._test_aggr(F.minIf(Person.height, Person.last_name > 'H'), 1.6) + + def test_aggregate_funcs__or_default_if(self): + self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > 'Z')) + self._test_aggr(F.countOrDefaultIf(Person.last_name > 'Z'), 0) + self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > 'Z'), 0) + + def test_aggregate_funcs__or_null_if(self): + self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > 'Z')) + self._test_aggr(F.countOrNullIf(Person.last_name > 'Z'), None) + self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > 'Z'), None)