- move NO_VALUE to utils

- dynamic generation of func variants (...OrZero, ...OrNull)
This commit is contained in:
Itai Shirav 2020-02-10 10:06:21 +02:00
parent 25c4a6710e
commit 9f36b17fee
5 changed files with 196 additions and 106 deletions

View File

@ -25,7 +25,7 @@ class Event(models.Model):
engine = engines.Memory() engine = engines.Memory()
... ...
``` ```
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` instead. For example: When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` instead. For example:
```python ```python
>>> event = Event() >>> event = Event()
>>> print(event.to_dict()) >>> print(event.to_dict())
@ -63,7 +63,7 @@ db.select('SELECT created, created_date, username, name FROM $db.event', model_c
# created_date and username will contain a default value # created_date and username will contain a default value
db.select('SELECT * FROM $db.event', model_class=Event) db.select('SELECT * FROM $db.event', model_class=Event)
``` ```
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` since their real values can only be known after insertion to the database. When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database.
## codec ## codec

View File

@ -1,7 +1,9 @@
from datetime import date, datetime, tzinfo from datetime import date, datetime, tzinfo
import functools from functools import wraps
from inspect import signature, Parameter
from types import FunctionType
from .utils import is_iterable, comma_join from .utils import is_iterable, comma_join, NO_VALUE
from .query import Cond from .query import Cond
@ -9,7 +11,7 @@ def binary_operator(func):
""" """
Decorates a function to mark it as a binary operator. Decorates a function to mark it as a binary operator.
""" """
@functools.wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
ret.is_binary_operator = True ret.is_binary_operator = True
@ -17,6 +19,29 @@ def binary_operator(func):
return wrapper return wrapper
def type_conversion(func):
    """
    Tags the decorated function as a type conversion function.

    The function is wrapped in a pass-through wrapper that forwards every
    call unchanged; the wrapper carries an `f_type` attribute set to
    'type_conversion', which FMeta reads to pick the suffixed variants
    to generate.
    """
    @wraps(func)
    def tagged(*call_args, **call_kwargs):
        result = func(*call_args, **call_kwargs)
        return result

    setattr(tagged, 'f_type', 'type_conversion')
    return tagged
def aggregate(func):
    """
    Tags the decorated function as an aggregate function.

    Calls go straight through to the decorated function. The returned
    wrapper exposes `f_type == 'aggregate'`, the key FMeta uses to look up
    which combinator variants to generate.
    """
    def forward(*positional, **named):
        return func(*positional, **named)

    # Apply wraps first so that the marker below always wins over any
    # `f_type` copied across from func's own __dict__.
    forward = wraps(func)(forward)
    forward.f_type = 'aggregate'
    return forward
class FunctionOperatorsMixin(object): class FunctionOperatorsMixin(object):
""" """
A mixin for implementing Python operators using F objects. A mixin for implementing Python operators using F objects.
@ -104,7 +129,57 @@ class FunctionOperatorsMixin(object):
return F._not(self) return F._not(self)
class F(Cond, FunctionOperatorsMixin): class FMeta(type):
FUNCTION_COMBINATORS = {
'type_conversion': [
{'suffix': 'OrZero'},
{'suffix': 'OrNull'},
],
'aggregate': [
{'suffix': 'OrDefault'},
{'suffix': 'OrNull'},
{'suffix': 'If', 'args': ['cond']},
{'suffix': 'OrDefaultIf', 'args': ['cond']},
{'suffix': 'OrNullIf', 'args': ['cond']},
]
}
def __init__(cls, name, bases, dct):
    """
    Generates the combinator variants of every marked function in the class body.

    For each entry of the class namespace that wraps a function tagged with
    an `f_type` (see the `type_conversion` and `aggregate` decorators), one
    extra function per matching FUNCTION_COMBINATORS entry is attached to
    the class, named `<original name><suffix>`.

    cls   - the class being created.
    name  - the class name.
    bases - the base classes.
    dct   - the class namespace, as written in the class body.
    """
    super(FMeta, cls).__init__(name, bases, dct)
    # Use a dedicated loop variable so the `name` parameter is not clobbered.
    for attr_name, attr in dct.items():
        # In a class namespace, staticmethod objects expose the underlying
        # function as __func__; plain functions and other values do not.
        if not hasattr(attr, '__func__'):
            continue
        base_func = attr.__func__
        f_type = getattr(base_func, 'f_type', '')
        for combinator in FMeta.FUNCTION_COMBINATORS.get(f_type, []):
            new_name = attr_name + combinator['suffix']
            FMeta._add_func(cls, base_func, new_name, combinator.get('args'))
@staticmethod
def _add_func(cls, base_func, new_name, extra_args):
    """
    Adds a new func to the cls, based on the signature of the given base_func but with a new name.

    The generated function returns `F(new_name, <its arguments>)`.

    cls        - the class that receives the new function.
    base_func  - the function whose parameters (and defaults) are copied.
    new_name   - attribute name of the new function, also the name passed to F.
    extra_args - optional list of parameter names appended after the copied
                 parameters (e.g. ['cond']), or None.

    Note: when base_func takes `*args`, any extra_args come after it in the
    generated signature and are therefore keyword-only.
    """
    # Get the function's signature
    sig = signature(base_func)
    new_sig = str(sig)[1:-1]  # omit the parentheses
    # Arguments forwarded to F(...). A var-positional parameter has to be
    # unpacked again, otherwise F would receive one tuple instead of the
    # individual values (unlike the base function, which unpacks them).
    call_args = [
        '*' + p.name if p.kind is Parameter.VAR_POSITIONAL else p.name
        for p in sig.parameters.values()
    ]
    # Add extra args
    if extra_args:
        call_args.extend(extra_args)
        new_sig = comma_join([new_sig] + extra_args) if new_sig else comma_join(extra_args)
    args = comma_join(call_args)
    # Get default values for args. Identity test: `empty` is a marker class,
    # and a default object may overload the comparison operators.
    argdefs = tuple(p.default for p in sig.parameters.values() if p.default is not Parameter.empty)
    # Build the new function. Only the function's code object is taken from
    # the compiled module; the `def` statement itself is never executed, so
    # the defaults shown in the signature text only need to be valid syntax.
    # The real default values are supplied through argdefs.
    source = f'def {new_name}({new_sig}): return F("{new_name}", {args})'
    module_code = compile(source, __file__, 'exec')
    # The code object is not guaranteed to be the first constant (a literal
    # default can be folded into a constant that precedes it), so search for it.
    func_code = next(c for c in module_code.co_consts if isinstance(c, type(module_code)))
    new_func = FunctionType(code=func_code, globals=globals(), name=new_name, argdefs=argdefs)
    # Attach to class
    setattr(cls, new_name, new_func)
class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
""" """
Represents a database function call and its arguments. Represents a database function call and its arguments.
It doubles as a query condition when the function returns a boolean result. It doubles as a query condition when the function returns a boolean result.
@ -135,7 +210,7 @@ class F(Cond, FunctionOperatorsMixin):
else: else:
prefix = self.name prefix = self.name
sep = ', ' sep = ', '
arg_strs = (F._arg_to_sql(arg) for arg in self.args) arg_strs = (F._arg_to_sql(arg) for arg in self.args if arg != NO_VALUE)
return prefix + '(' + sep.join(arg_strs) + ')' return prefix + '(' + sep.join(arg_strs) + ')'
@staticmethod @staticmethod
@ -392,168 +467,143 @@ class F(Cond, FunctionOperatorsMixin):
return F('formatDateTime', d, format, timezone) return F('formatDateTime', d, format, timezone)
@staticmethod @staticmethod
def addDays(d, n, timezone=None): def addDays(d, n, timezone=NO_VALUE):
return F('addDays', d, n, timezone) if timezone else F('addDays', d, n) return F('addDays', d, n, timezone)
@staticmethod @staticmethod
def addHours(d, n, timezone=None): def addHours(d, n, timezone=NO_VALUE):
return F('addHours', d, n, timezone) if timezone else F('addHours', d, n) return F('addHours', d, n, timezone)
@staticmethod @staticmethod
def addMinutes(d, n, timezone=None): def addMinutes(d, n, timezone=NO_VALUE):
return F('addMinutes', d, n, timezone) if timezone else F('addMinutes', d, n) return F('addMinutes', d, n, timezone)
@staticmethod @staticmethod
def addMonths(d, n, timezone=None): def addMonths(d, n, timezone=NO_VALUE):
return F('addMonths', d, n, timezone) if timezone else F('addMonths', d, n) return F('addMonths', d, n, timezone)
@staticmethod @staticmethod
def addQuarters(d, n, timezone=None): def addQuarters(d, n, timezone=NO_VALUE):
return F('addQuarters', d, n, timezone) if timezone else F('addQuarters', d, n) return F('addQuarters', d, n, timezone)
@staticmethod @staticmethod
def addSeconds(d, n, timezone=None): def addSeconds(d, n, timezone=NO_VALUE):
return F('addSeconds', d, n, timezone) if timezone else F('addSeconds', d, n) return F('addSeconds', d, n, timezone)
@staticmethod @staticmethod
def addWeeks(d, n, timezone=None): def addWeeks(d, n, timezone=NO_VALUE):
return F('addWeeks', d, n, timezone) if timezone else F('addWeeks', d, n) return F('addWeeks', d, n, timezone)
@staticmethod @staticmethod
def addYears(d, n, timezone=None): def addYears(d, n, timezone=NO_VALUE):
return F('addYears', d, n, timezone) if timezone else F('addYears', d, n) return F('addYears', d, n, timezone)
@staticmethod @staticmethod
def subtractDays(d, n, timezone=None): def subtractDays(d, n, timezone=NO_VALUE):
return F('subtractDays', d, n, timezone) if timezone else F('subtractDays', d, n) return F('subtractDays', d, n, timezone)
@staticmethod @staticmethod
def subtractHours(d, n, timezone=None): def subtractHours(d, n, timezone=NO_VALUE):
return F('subtractHours', d, n, timezone) if timezone else F('subtractHours', d, n) return F('subtractHours', d, n, timezone)
@staticmethod @staticmethod
def subtractMinutes(d, n, timezone=None): def subtractMinutes(d, n, timezone=NO_VALUE):
return F('subtractMinutes', d, n, timezone) if timezone else F('subtractMinutes', d, n) return F('subtractMinutes', d, n, timezone)
@staticmethod @staticmethod
def subtractMonths(d, n, timezone=None): def subtractMonths(d, n, timezone=NO_VALUE):
return F('subtractMonths', d, n, timezone) if timezone else F('subtractMonths', d, n) return F('subtractMonths', d, n, timezone)
@staticmethod @staticmethod
def subtractQuarters(d, n, timezone=None): def subtractQuarters(d, n, timezone=NO_VALUE):
return F('subtractQuarters', d, n, timezone) if timezone else F('subtractQuarters', d, n) return F('subtractQuarters', d, n, timezone)
@staticmethod @staticmethod
def subtractSeconds(d, n, timezone=None): def subtractSeconds(d, n, timezone=NO_VALUE):
return F('subtractSeconds', d, n, timezone) if timezone else F('subtractSeconds', d, n) return F('subtractSeconds', d, n, timezone)
@staticmethod @staticmethod
def subtractWeeks(d, n, timezone=None): def subtractWeeks(d, n, timezone=NO_VALUE):
return F('subtractWeeks', d, n, timezone) if timezone else F('subtractWeeks', d, n) return F('subtractWeeks', d, n, timezone)
@staticmethod @staticmethod
def subtractYears(d, n, timezone=None): def subtractYears(d, n, timezone=NO_VALUE):
return F('subtractYears', d, n, timezone) if timezone else F('subtractYears', d, n) return F('subtractYears', d, n, timezone)
# Type conversion functions # Type conversion functions
@staticmethod @staticmethod
@type_conversion
def toUInt8(x): def toUInt8(x):
return F('toUInt8', x) return F('toUInt8', x)
@staticmethod @staticmethod
@type_conversion
def toUInt16(x): def toUInt16(x):
return F('toUInt16', x) return F('toUInt16', x)
@staticmethod @staticmethod
@type_conversion
def toUInt32(x): def toUInt32(x):
return F('toUInt32', x) return F('toUInt32', x)
@staticmethod @staticmethod
@type_conversion
def toUInt64(x): def toUInt64(x):
return F('toUInt64', x) return F('toUInt64', x)
@staticmethod @staticmethod
@type_conversion
def toInt8(x): def toInt8(x):
return F('toInt8', x) return F('toInt8', x)
@staticmethod @staticmethod
@type_conversion
def toInt16(x): def toInt16(x):
return F('toInt16', x) return F('toInt16', x)
@staticmethod @staticmethod
@type_conversion
def toInt32(x): def toInt32(x):
return F('toInt32', x) return F('toInt32', x)
@staticmethod @staticmethod
@type_conversion
def toInt64(x): def toInt64(x):
return F('toInt64', x) return F('toInt64', x)
@staticmethod @staticmethod
@type_conversion
def toFloat32(x): def toFloat32(x):
return F('toFloat32', x) return F('toFloat32', x)
@staticmethod @staticmethod
@type_conversion
def toFloat64(x): def toFloat64(x):
return F('toFloat64', x) return F('toFloat64', x)
@staticmethod @staticmethod
def toUInt8OrZero(x): @type_conversion
return F('toUInt8OrZero', x)
@staticmethod
def toUInt16OrZero(x):
return F('toUInt16OrZero', x)
@staticmethod
def toUInt32OrZero(x):
return F('toUInt32OrZero', x)
@staticmethod
def toUInt64OrZero(x):
return F('toUInt64OrZero', x)
@staticmethod
def toInt8OrZero(x):
return F('toInt8OrZero', x)
@staticmethod
def toInt16OrZero(x):
return F('toInt16OrZero', x)
@staticmethod
def toInt32OrZero(x):
return F('toInt32OrZero', x)
@staticmethod
def toInt64OrZero(x):
return F('toInt64OrZero', x)
@staticmethod
def toFloat32OrZero(x):
return F('toFloat32OrZero', x)
@staticmethod
def toFloat64OrZero(x):
return F('toFloat64OrZero', x)
@staticmethod
def toDecimal32(x, scale): def toDecimal32(x, scale):
return F('toDecimal32', x, scale) return F('toDecimal32', x, scale)
@staticmethod @staticmethod
@type_conversion
def toDecimal64(x, scale): def toDecimal64(x, scale):
return F('toDecimal64', x, scale) return F('toDecimal64', x, scale)
@staticmethod @staticmethod
@type_conversion
def toDecimal128(x, scale): def toDecimal128(x, scale):
return F('toDecimal128', x, scale) return F('toDecimal128', x, scale)
@staticmethod @staticmethod
@type_conversion
def toDate(x): def toDate(x):
return F('toDate', x) return F('toDate', x)
@staticmethod @staticmethod
@type_conversion
def toDateTime(x): def toDateTime(x):
return F('toDateTime', x) return F('toDateTime', x)
@ -574,16 +624,9 @@ class F(Cond, FunctionOperatorsMixin):
return F('CAST', x, type) return F('CAST', x, type)
@staticmethod @staticmethod
def parseDateTimeBestEffort(d, timezone=None): @type_conversion
return F('parseDateTimeBestEffort', d, timezone) if timezone else F('parseDateTimeBestEffort', d) def parseDateTimeBestEffort(d, timezone=NO_VALUE):
return F('parseDateTimeBestEffort', d, timezone)
@staticmethod
def parseDateTimeBestEffortOrNull(d, timezone=None):
return F('parseDateTimeBestEffortOrNull', d, timezone) if timezone else F('parseDateTimeBestEffortOrNull', d)
@staticmethod
def parseDateTimeBestEffortOrZero(d, timezone=None):
return F('parseDateTimeBestEffortOrZero', d, timezone) if timezone else F('parseDateTimeBestEffortOrZero', d)
# Functions for working with strings # Functions for working with strings
@ -1314,90 +1357,112 @@ class F(Cond, FunctionOperatorsMixin):
# Aggregate functions # Aggregate functions
@staticmethod @staticmethod
@aggregate
def any(x): def any(x):
return F('any', x) return F('any', x)
@staticmethod @staticmethod
@aggregate
def anyHeavy(x): def anyHeavy(x):
return F('anyHeavy', x) return F('anyHeavy', x)
@staticmethod @staticmethod
@aggregate
def anyLast(x): def anyLast(x):
return F('anyLast', x) return F('anyLast', x)
@staticmethod @staticmethod
@aggregate
def argMax(x, y): def argMax(x, y):
return F('argMax', x, y) return F('argMax', x, y)
@staticmethod @staticmethod
@aggregate
def argMin(x, y): def argMin(x, y):
return F('argMin', x, y) return F('argMin', x, y)
@staticmethod @staticmethod
@aggregate
def avg(x): def avg(x):
return F('avg', x) return F('avg', x)
@staticmethod @staticmethod
@aggregate
def corr(x, y): def corr(x, y):
return F('corr', x, y) return F('corr', x, y)
@staticmethod @staticmethod
@aggregate
def count(): def count():
return F('count') return F('count')
@staticmethod @staticmethod
@aggregate
def covarPop(x, y): def covarPop(x, y):
return F('covarPop', x, y) return F('covarPop', x, y)
@staticmethod @staticmethod
@aggregate
def covarSamp(x, y): def covarSamp(x, y):
return F('covarSamp', x, y) return F('covarSamp', x, y)
@staticmethod @staticmethod
@aggregate
def kurtPop(x): def kurtPop(x):
return F('kurtPop', x) return F('kurtPop', x)
@staticmethod @staticmethod
@aggregate
def kurtSamp(x): def kurtSamp(x):
return F('kurtSamp', x) return F('kurtSamp', x)
@staticmethod @staticmethod
@aggregate
def min(x): def min(x):
return F('min', x) return F('min', x)
@staticmethod @staticmethod
@aggregate
def max(x): def max(x):
return F('max', x) return F('max', x)
@staticmethod @staticmethod
@aggregate
def skewPop(x): def skewPop(x):
return F('skewPop', x) return F('skewPop', x)
@staticmethod @staticmethod
@aggregate
def skewSamp(x): def skewSamp(x):
return F('skewSamp', x) return F('skewSamp', x)
@staticmethod @staticmethod
@aggregate
def sum(x): def sum(x):
return F('sum', x) return F('sum', x)
@staticmethod @staticmethod
@aggregate
def uniq(*args): def uniq(*args):
return F('uniq', *args) return F('uniq', *args)
@staticmethod @staticmethod
@aggregate
def uniqExact(*args): def uniqExact(*args):
return F('uniqExact', *args) return F('uniqExact', *args)
@staticmethod @staticmethod
@aggregate
def uniqHLL12(*args): def uniqHLL12(*args):
return F('uniqHLL12', *args) return F('uniqHLL12', *args)
@staticmethod @staticmethod
@aggregate
def varPop(x): def varPop(x):
return F('varPop', x) return F('varPop', x)
@staticmethod @staticmethod
@aggregate
def varSamp(x): def varSamp(x):
return F('varSamp', x) return F('varSamp', x)

View File

@ -7,7 +7,7 @@ from six import reraise
import pytz import pytz
from .fields import Field, StringField from .fields import Field, StringField
from .utils import parse_tsv from .utils import parse_tsv, NO_VALUE
from .query import QuerySet from .query import QuerySet
from .funcs import F from .funcs import F
from .engines import Merge, Distributed from .engines import Merge, Distributed
@ -15,17 +15,6 @@ from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm') logger = getLogger('clickhouse_orm')
class NoValue:
'''
A sentinel for fields with an expression for a default value,
that were not assigned a value yet.
'''
def __repr__(self):
return '<NO_VALUE>'
NO_VALUE = NoValue()
class ModelBase(type): class ModelBase(type):
''' '''
A metaclass for ORM models. It adds the _fields list to model classes. A metaclass for ORM models. It adds the _fields list to model classes.

View File

@ -112,3 +112,14 @@ def is_iterable(obj):
return True return True
except TypeError: except TypeError:
return False return False
class NoValue(object):
    '''
    Sentinel type for a field whose default value is an expression
    and which has not been assigned a value yet.
    '''

    def __repr__(self):
        # NOTE(review): funcs.FMeta._add_func compiles text taken from
        # str(signature(...)), where a NO_VALUE default shows up through this
        # repr - presumably why it is a bare identifier; confirm before changing.
        return 'NO_VALUE'


# The single shared sentinel instance.
NO_VALUE = NoValue()

View File

@ -4,6 +4,7 @@ from .test_querysets import SampleModel
from datetime import date, datetime, tzinfo, timedelta from datetime import date, datetime, tzinfo, timedelta
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from infi.clickhouse_orm.database import ServerError from infi.clickhouse_orm.database import ServerError
from infi.clickhouse_orm.utils import NO_VALUE
class FuncsTestCase(TestCaseWithData): class FuncsTestCase(TestCaseWithData):
@ -21,21 +22,21 @@ class FuncsTestCase(TestCaseWithData):
self.assertEqual(count, expected_count) self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count) self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=None): def _test_func(self, func, expected_value=NO_VALUE):
sql = 'SELECT %s AS value' % func.to_sql() sql = 'SELECT %s AS value' % func.to_sql()
logger.info(sql) logger.info(sql)
result = list(self.database.select(sql)) result = list(self.database.select(sql))
logger.info('\t==> %s', result[0].value if result else '<empty>') logger.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value is not None: if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
def _test_aggr(self, func, expected_value=None): def _test_aggr(self, func, expected_value=NO_VALUE):
qs = Person.objects_in(self.database).aggregate(value=func) qs = Person.objects_in(self.database).aggregate(value=func)
logger.info(qs.as_sql()) logger.info(qs.as_sql())
result = list(qs) result = list(qs)
logger.info('\t==> %s', result[0].value if result else '<empty>') logger.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value is not None: if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value) self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None return result[0].value if result else None
@ -316,7 +317,7 @@ class FuncsTestCase(TestCaseWithData):
try: try:
self._test_func(F.base64Decode(F.base64Encode('Hello')), 'Hello') self._test_func(F.base64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(F.base64Encode('Hello')), 'Hello') self._test_func(F.tryBase64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(':-)'), None) self._test_func(F.tryBase64Decode(':-)'))
except ServerError as e: except ServerError as e:
# ClickHouse version that doesn't support these functions # ClickHouse version that doesn't support these functions
raise unittest.SkipTest(e.message) raise unittest.SkipTest(e.message)
@ -548,3 +549,27 @@ class FuncsTestCase(TestCaseWithData):
self._test_aggr(F.varPop(Person.height)) self._test_aggr(F.varPop(Person.height))
self._test_aggr(F.varSamp(Person.height)) self._test_aggr(F.varSamp(Person.height))
def test_aggregate_funcs__or_default(self):
    """The generated ...OrDefault aggregates are expected to yield 0 once the person table is truncated."""
    self.database.raw('TRUNCATE TABLE person')
    for expr in (F.countOrDefault(), F.maxOrDefault(Person.height)):
        self._test_aggr(expr, 0)
def test_aggregate_funcs__or_null(self):
    """The generated ...OrNull aggregates are expected to yield None once the person table is truncated."""
    self.database.raw('TRUNCATE TABLE person')
    for expr in (F.countOrNull(), F.maxOrNull(Person.height)):
        self._test_aggr(expr, None)
def test_aggregate_funcs__if(self):
    """Exercises the generated ...If aggregates, which take the condition as their last argument."""
    # No expected value: only checks that the query runs.
    self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > 'H'))
    checks = [
        (F.countIf(Person.last_name > 'H'), 57),
        (F.minIf(Person.height, Person.last_name > 'H'), 1.6),
    ]
    for expr, expected in checks:
        self._test_aggr(expr, expected)
def test_aggregate_funcs__or_default_if(self):
    """Exercises the generated ...OrDefaultIf aggregates; count and min are expected to yield 0 for last_name > 'Z'."""
    # No expected value: only checks that the query runs.
    self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > 'Z'))
    checks = [
        (F.countOrDefaultIf(Person.last_name > 'Z'), 0),
        (F.minOrDefaultIf(Person.height, Person.last_name > 'Z'), 0),
    ]
    for expr, expected in checks:
        self._test_aggr(expr, expected)
def test_aggregate_funcs__or_null_if(self):
    """Exercises the generated ...OrNullIf aggregates; count and min are expected to yield None for last_name > 'Z'."""
    # No expected value: only checks that the query runs.
    self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > 'Z'))
    checks = [
        (F.countOrNullIf(Person.last_name > 'Z'), None),
        (F.minOrNullIf(Person.height, Person.last_name > 'Z'), None),
    ]
    for expr, expected in checks:
        self._test_aggr(expr, expected)