- move NO_VALUE to utils

- dynamic generation of func variants (...OrZero, ...OrNull)
This commit is contained in:
Itai Shirav 2020-02-10 10:06:21 +02:00
parent 25c4a6710e
commit 9f36b17fee
5 changed files with 196 additions and 106 deletions

View File

@ -25,7 +25,7 @@ class Event(models.Model):
engine = engines.Memory()
...
```
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` instead. For example:
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` instead. For example:
```python
>>> event = Event()
>>> print(event.to_dict())
@ -63,7 +63,7 @@ db.select('SELECT created, created_date, username, name FROM $db.event', model_c
# created_date and username will contain a default value
db.select('SELECT * FROM $db.event', model_class=Event)
```
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.models.NO_VALUE` since their real values can only be known after insertion to the database.
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database.
## codec

View File

@ -1,7 +1,9 @@
from datetime import date, datetime, tzinfo
import functools
from functools import wraps
from inspect import signature, Parameter
from types import FunctionType
from .utils import is_iterable, comma_join
from .utils import is_iterable, comma_join, NO_VALUE
from .query import Cond
@ -9,7 +11,7 @@ def binary_operator(func):
"""
Decorates a function to mark it as a binary operator.
"""
@functools.wraps(func)
@wraps(func)
def wrapper(*args, **kwargs):
ret = func(*args, **kwargs)
ret.is_binary_operator = True
@ -17,6 +19,29 @@ def binary_operator(func):
return wrapper
def type_conversion(func):
"""
Decorates a function to mark it as a type conversion function.
"""
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper.f_type = 'type_conversion'
return wrapper
def aggregate(func):
"""
Decorates a function to mark it as an aggregate function.
"""
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper.f_type = 'aggregate'
return wrapper
class FunctionOperatorsMixin(object):
"""
A mixin for implementing Python operators using F objects.
@ -104,7 +129,57 @@ class FunctionOperatorsMixin(object):
return F._not(self)
class F(Cond, FunctionOperatorsMixin):
class FMeta(type):
FUNCTION_COMBINATORS = {
'type_conversion': [
{'suffix': 'OrZero'},
{'suffix': 'OrNull'},
],
'aggregate': [
{'suffix': 'OrDefault'},
{'suffix': 'OrNull'},
{'suffix': 'If', 'args': ['cond']},
{'suffix': 'OrDefaultIf', 'args': ['cond']},
{'suffix': 'OrNullIf', 'args': ['cond']},
]
}
def __init__(cls, name, bases, dct):
for name, obj in dct.items():
if hasattr(obj, '__func__'):
f_type = getattr(obj.__func__, 'f_type', '')
for combinator in FMeta.FUNCTION_COMBINATORS.get(f_type, []):
new_name = name + combinator['suffix']
FMeta._add_func(cls, obj.__func__, new_name, combinator.get('args'))
@staticmethod
def _add_func(cls, base_func, new_name, extra_args):
"""
Adds a new func to the cls, based on the signature of the given base_func but with a new name.
"""
# Get the function's signature
sig = signature(base_func)
new_sig = str(sig)[1 : -1] # omit the parentheses
args = comma_join(sig.parameters)
# Add extra args
if extra_args:
if args:
args = comma_join([args] + extra_args)
new_sig = comma_join([new_sig] + extra_args)
else:
args = comma_join(extra_args)
new_sig = comma_join(extra_args)
# Get default values for args
argdefs = tuple(p.default for p in sig.parameters.values() if p.default != Parameter.empty)
# Build the new function
new_code = compile(f'def {new_name}({new_sig}): return F("{new_name}", {args})', __file__, 'exec')
new_func = FunctionType(code=new_code.co_consts[0], globals=globals(), name=new_name, argdefs=argdefs)
# Attach to class
setattr(cls, new_name, new_func)
class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):
"""
Represents a database function call and its arguments.
It doubles as a query condition when the function returns a boolean result.
@ -135,7 +210,7 @@ class F(Cond, FunctionOperatorsMixin):
else:
prefix = self.name
sep = ', '
arg_strs = (F._arg_to_sql(arg) for arg in self.args)
arg_strs = (F._arg_to_sql(arg) for arg in self.args if arg != NO_VALUE)
return prefix + '(' + sep.join(arg_strs) + ')'
@staticmethod
@ -392,168 +467,143 @@ class F(Cond, FunctionOperatorsMixin):
return F('formatDateTime', d, format, timezone)
@staticmethod
def addDays(d, n, timezone=None):
return F('addDays', d, n, timezone) if timezone else F('addDays', d, n)
def addDays(d, n, timezone=NO_VALUE):
return F('addDays', d, n, timezone)
@staticmethod
def addHours(d, n, timezone=None):
return F('addHours', d, n, timezone) if timezone else F('addHours', d, n)
def addHours(d, n, timezone=NO_VALUE):
return F('addHours', d, n, timezone)
@staticmethod
def addMinutes(d, n, timezone=None):
return F('addMinutes', d, n, timezone) if timezone else F('addMinutes', d, n)
def addMinutes(d, n, timezone=NO_VALUE):
return F('addMinutes', d, n, timezone)
@staticmethod
def addMonths(d, n, timezone=None):
return F('addMonths', d, n, timezone) if timezone else F('addMonths', d, n)
def addMonths(d, n, timezone=NO_VALUE):
return F('addMonths', d, n, timezone)
@staticmethod
def addQuarters(d, n, timezone=None):
return F('addQuarters', d, n, timezone) if timezone else F('addQuarters', d, n)
def addQuarters(d, n, timezone=NO_VALUE):
return F('addQuarters', d, n, timezone)
@staticmethod
def addSeconds(d, n, timezone=None):
return F('addSeconds', d, n, timezone) if timezone else F('addSeconds', d, n)
def addSeconds(d, n, timezone=NO_VALUE):
return F('addSeconds', d, n, timezone)
@staticmethod
def addWeeks(d, n, timezone=None):
return F('addWeeks', d, n, timezone) if timezone else F('addWeeks', d, n)
def addWeeks(d, n, timezone=NO_VALUE):
return F('addWeeks', d, n, timezone)
@staticmethod
def addYears(d, n, timezone=None):
return F('addYears', d, n, timezone) if timezone else F('addYears', d, n)
def addYears(d, n, timezone=NO_VALUE):
return F('addYears', d, n, timezone)
@staticmethod
def subtractDays(d, n, timezone=None):
return F('subtractDays', d, n, timezone) if timezone else F('subtractDays', d, n)
def subtractDays(d, n, timezone=NO_VALUE):
return F('subtractDays', d, n, timezone)
@staticmethod
def subtractHours(d, n, timezone=None):
return F('subtractHours', d, n, timezone) if timezone else F('subtractHours', d, n)
def subtractHours(d, n, timezone=NO_VALUE):
return F('subtractHours', d, n, timezone)
@staticmethod
def subtractMinutes(d, n, timezone=None):
return F('subtractMinutes', d, n, timezone) if timezone else F('subtractMinutes', d, n)
def subtractMinutes(d, n, timezone=NO_VALUE):
return F('subtractMinutes', d, n, timezone)
@staticmethod
def subtractMonths(d, n, timezone=None):
return F('subtractMonths', d, n, timezone) if timezone else F('subtractMonths', d, n)
def subtractMonths(d, n, timezone=NO_VALUE):
return F('subtractMonths', d, n, timezone)
@staticmethod
def subtractQuarters(d, n, timezone=None):
return F('subtractQuarters', d, n, timezone) if timezone else F('subtractQuarters', d, n)
def subtractQuarters(d, n, timezone=NO_VALUE):
return F('subtractQuarters', d, n, timezone)
@staticmethod
def subtractSeconds(d, n, timezone=None):
return F('subtractSeconds', d, n, timezone) if timezone else F('subtractSeconds', d, n)
def subtractSeconds(d, n, timezone=NO_VALUE):
return F('subtractSeconds', d, n, timezone)
@staticmethod
def subtractWeeks(d, n, timezone=None):
return F('subtractWeeks', d, n, timezone) if timezone else F('subtractWeeks', d, n)
def subtractWeeks(d, n, timezone=NO_VALUE):
return F('subtractWeeks', d, n, timezone)
@staticmethod
def subtractYears(d, n, timezone=None):
return F('subtractYears', d, n, timezone) if timezone else F('subtractYears', d, n)
def subtractYears(d, n, timezone=NO_VALUE):
return F('subtractYears', d, n, timezone)
# Type conversion functions
@staticmethod
@type_conversion
def toUInt8(x):
return F('toUInt8', x)
@staticmethod
@type_conversion
def toUInt16(x):
return F('toUInt16', x)
@staticmethod
@type_conversion
def toUInt32(x):
return F('toUInt32', x)
@staticmethod
@type_conversion
def toUInt64(x):
return F('toUInt64', x)
@staticmethod
@type_conversion
def toInt8(x):
return F('toInt8', x)
@staticmethod
@type_conversion
def toInt16(x):
return F('toInt16', x)
@staticmethod
@type_conversion
def toInt32(x):
return F('toInt32', x)
@staticmethod
@type_conversion
def toInt64(x):
return F('toInt64', x)
@staticmethod
@type_conversion
def toFloat32(x):
return F('toFloat32', x)
@staticmethod
@type_conversion
def toFloat64(x):
return F('toFloat64', x)
@staticmethod
def toUInt8OrZero(x):
return F('toUInt8OrZero', x)
@staticmethod
def toUInt16OrZero(x):
return F('toUInt16OrZero', x)
@staticmethod
def toUInt32OrZero(x):
return F('toUInt32OrZero', x)
@staticmethod
def toUInt64OrZero(x):
return F('toUInt64OrZero', x)
@staticmethod
def toInt8OrZero(x):
return F('toInt8OrZero', x)
@staticmethod
def toInt16OrZero(x):
return F('toInt16OrZero', x)
@staticmethod
def toInt32OrZero(x):
return F('toInt32OrZero', x)
@staticmethod
def toInt64OrZero(x):
return F('toInt64OrZero', x)
@staticmethod
def toFloat32OrZero(x):
return F('toFloat32OrZero', x)
@staticmethod
def toFloat64OrZero(x):
return F('toFloat64OrZero', x)
@staticmethod
@type_conversion
def toDecimal32(x, scale):
return F('toDecimal32', x, scale)
@staticmethod
@type_conversion
def toDecimal64(x, scale):
return F('toDecimal64', x, scale)
@staticmethod
@type_conversion
def toDecimal128(x, scale):
return F('toDecimal128', x, scale)
@staticmethod
@type_conversion
def toDate(x):
return F('toDate', x)
@staticmethod
@type_conversion
def toDateTime(x):
return F('toDateTime', x)
@ -574,16 +624,9 @@ class F(Cond, FunctionOperatorsMixin):
return F('CAST', x, type)
@staticmethod
def parseDateTimeBestEffort(d, timezone=None):
return F('parseDateTimeBestEffort', d, timezone) if timezone else F('parseDateTimeBestEffort', d)
@staticmethod
def parseDateTimeBestEffortOrNull(d, timezone=None):
return F('parseDateTimeBestEffortOrNull', d, timezone) if timezone else F('parseDateTimeBestEffortOrNull', d)
@staticmethod
def parseDateTimeBestEffortOrZero(d, timezone=None):
return F('parseDateTimeBestEffortOrZero', d, timezone) if timezone else F('parseDateTimeBestEffortOrZero', d)
@type_conversion
def parseDateTimeBestEffort(d, timezone=NO_VALUE):
return F('parseDateTimeBestEffort', d, timezone)
# Functions for working with strings
@ -1314,90 +1357,112 @@ class F(Cond, FunctionOperatorsMixin):
# Aggregate functions
@staticmethod
@aggregate
def any(x):
return F('any', x)
@staticmethod
@aggregate
def anyHeavy(x):
return F('anyHeavy', x)
@staticmethod
@aggregate
def anyLast(x):
return F('anyLast', x)
@staticmethod
@aggregate
def argMax(x, y):
return F('argMax', x, y)
@staticmethod
@aggregate
def argMin(x, y):
return F('argMin', x, y)
@staticmethod
@aggregate
def avg(x):
return F('avg', x)
@staticmethod
@aggregate
def corr(x, y):
return F('corr', x, y)
@staticmethod
@aggregate
def count():
return F('count')
@staticmethod
@aggregate
def covarPop(x, y):
return F('covarPop', x, y)
@staticmethod
@aggregate
def covarSamp(x, y):
return F('covarSamp', x, y)
@staticmethod
@aggregate
def kurtPop(x):
return F('kurtPop', x)
@staticmethod
@aggregate
def kurtSamp(x):
return F('kurtSamp', x)
@staticmethod
@aggregate
def min(x):
return F('min', x)
@staticmethod
@aggregate
def max(x):
return F('max', x)
@staticmethod
@aggregate
def skewPop(x):
return F('skewPop', x)
@staticmethod
@aggregate
def skewSamp(x):
return F('skewSamp', x)
@staticmethod
@aggregate
def sum(x):
return F('sum', x)
@staticmethod
@aggregate
def uniq(*args):
return F('uniq', *args)
@staticmethod
@aggregate
def uniqExact(*args):
return F('uniqExact', *args)
@staticmethod
@aggregate
def uniqHLL12(*args):
return F('uniqHLL12', *args)
@staticmethod
@aggregate
def varPop(x):
return F('varPop', x)
@staticmethod
@aggregate
def varSamp(x):
return F('varSamp', x)

View File

@ -7,7 +7,7 @@ from six import reraise
import pytz
from .fields import Field, StringField
from .utils import parse_tsv
from .utils import parse_tsv, NO_VALUE
from .query import QuerySet
from .funcs import F
from .engines import Merge, Distributed
@ -15,17 +15,6 @@ from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm')
class NoValue:
'''
A sentinel for fields with an expression for a default value,
that were not assigned a value yet.
'''
def __repr__(self):
return '<NO_VALUE>'
NO_VALUE = NoValue()
class ModelBase(type):
'''
A metaclass for ORM models. It adds the _fields list to model classes.

View File

@ -112,3 +112,14 @@ def is_iterable(obj):
return True
except TypeError:
return False
class NoValue:
'''
A sentinel for fields with an expression for a default value,
that were not assigned a value yet.
'''
def __repr__(self):
return 'NO_VALUE'
NO_VALUE = NoValue()

View File

@ -4,6 +4,7 @@ from .test_querysets import SampleModel
from datetime import date, datetime, tzinfo, timedelta
from ipaddress import IPv4Address, IPv6Address
from infi.clickhouse_orm.database import ServerError
from infi.clickhouse_orm.utils import NO_VALUE
class FuncsTestCase(TestCaseWithData):
@ -21,21 +22,21 @@ class FuncsTestCase(TestCaseWithData):
self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=None):
def _test_func(self, func, expected_value=NO_VALUE):
sql = 'SELECT %s AS value' % func.to_sql()
logger.info(sql)
result = list(self.database.select(sql))
logger.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value is not None:
if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
def _test_aggr(self, func, expected_value=None):
def _test_aggr(self, func, expected_value=NO_VALUE):
qs = Person.objects_in(self.database).aggregate(value=func)
logger.info(qs.as_sql())
result = list(qs)
logger.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value is not None:
if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
@ -316,7 +317,7 @@ class FuncsTestCase(TestCaseWithData):
try:
self._test_func(F.base64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(':-)'), None)
self._test_func(F.tryBase64Decode(':-)'))
except ServerError as e:
# ClickHouse version that doesn't support these functions
raise unittest.SkipTest(e.message)
@ -548,3 +549,27 @@ class FuncsTestCase(TestCaseWithData):
self._test_aggr(F.varPop(Person.height))
self._test_aggr(F.varSamp(Person.height))
def test_aggregate_funcs__or_default(self):
self.database.raw('TRUNCATE TABLE person')
self._test_aggr(F.countOrDefault(), 0)
self._test_aggr(F.maxOrDefault(Person.height), 0)
def test_aggregate_funcs__or_null(self):
self.database.raw('TRUNCATE TABLE person')
self._test_aggr(F.countOrNull(), None)
self._test_aggr(F.maxOrNull(Person.height), None)
def test_aggregate_funcs__if(self):
self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > 'H'))
self._test_aggr(F.countIf(Person.last_name > 'H'), 57)
self._test_aggr(F.minIf(Person.height, Person.last_name > 'H'), 1.6)
def test_aggregate_funcs__or_default_if(self):
self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > 'Z'))
self._test_aggr(F.countOrDefaultIf(Person.last_name > 'Z'), 0)
self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > 'Z'), 0)
def test_aggregate_funcs__or_null_if(self):
self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > 'Z'))
self._test_aggr(F.countOrNullIf(Person.last_name > 'Z'), None)
self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > 'Z'), None)