diff --git a/CHANGELOG.md b/CHANGELOG.md index 90f49fe..1cfd096 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Unreleased - Support for model constraints - Support for data skipping indexes - Support for mutations: `QuerySet.update` and `QuerySet.delete` +- Added functions for working with external dictionaries - Support FINAL for `ReplacingMergeTree` (chripede) - Added `DateTime64Field` (NiyazNz) - Make `DateTimeField` and `DateTime64Field` timezone-aware (NiyazNz) diff --git a/docs/class_reference.md b/docs/class_reference.md index cb718c5..08716a5 100644 --- a/docs/class_reference.md +++ b/docs/class_reference.md @@ -1913,6 +1913,21 @@ Initializer. #### covarSampOrNullIf(y, cond) +#### dictGet(dict_name, attr_name, id_expr) + + +#### dictGetHierarchy(dict_name, id_expr) + + +#### dictGetOrDefault(dict_name, attr_name, id_expr, default) + + +#### dictHas(dict_name, id_expr) + + +#### dictIsIn(dict_name, child_id_expr, ancestor_id_expr) + + #### divide(**kwargs) diff --git a/src/infi/clickhouse_orm/funcs.py b/src/infi/clickhouse_orm/funcs.py index 2763fa7..d84c761 100644 --- a/src/infi/clickhouse_orm/funcs.py +++ b/src/infi/clickhouse_orm/funcs.py @@ -1789,6 +1789,28 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): def greatest(x, y): return F('greatest', x, y) + # Dictionary functions + + @staticmethod + def dictGet(dict_name, attr_name, id_expr): + return F('dictGet', dict_name, attr_name, id_expr) + + @staticmethod + def dictGetOrDefault(dict_name, attr_name, id_expr, default): + return F('dictGetOrDefault', dict_name, attr_name, id_expr, default) + + @staticmethod + def dictHas(dict_name, id_expr): + return F('dictHas', dict_name, id_expr) + + @staticmethod + def dictGetHierarchy(dict_name, id_expr): + return F('dictGetHierarchy', dict_name, id_expr) + + @staticmethod + def dictIsIn(dict_name, child_id_expr, ancestor_id_expr): + return F('dictIsIn', dict_name, child_id_expr, ancestor_id_expr) + # Expose only relevant classes in import * __all__ = ['F'] diff --git 
a/tests/test_dictionaries.py b/tests/test_dictionaries.py new file mode 100644 index 0000000..7da4160 --- /dev/null +++ b/tests/test_dictionaries.py @@ -0,0 +1,131 @@ +import unittest +import logging + +from infi.clickhouse_orm import * + + +class DictionaryTestMixin: + + def setUp(self): + self.database = Database('test-db', log_statements=True) + if self.database.server_version < (20, 1, 11, 73): + raise unittest.SkipTest('ClickHouse version too old') + self._create_dictionary() + + def tearDown(self): + self.database.drop_database() + + def _test_func(self, func, expected_value): + sql = 'SELECT %s AS value' % func.to_sql() + logging.info(sql) + result = list(self.database.select(sql)) + logging.info('\t==> %s', result[0].value if result else '') + print('Comparing %s to %s' % (result[0].value, expected_value)) + self.assertEqual(result[0].value, expected_value) + return result[0].value if result else None + + +class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase): + + def _create_dictionary(self): + # Create a table to be used as source for the dictionary + self.database.create_table(NumberName) + self.database.insert( + NumberName(number=i, name=name) + for i, name in enumerate('Zero One Two Three Four Five Six Seven Eight Nine Ten'.split()) + ) + # Create the dictionary + self.database.raw(""" + CREATE DICTIONARY numbers_dict( + number UInt64, + name String DEFAULT '?' 
+ ) + PRIMARY KEY number + SOURCE(CLICKHOUSE( + HOST 'localhost' PORT 9000 USER 'default' PASSWORD '' DB 'test-db' TABLE 'numbername' + )) + LIFETIME(100) + LAYOUT(HASHED()); + """) + self.dict_name = 'test-db.numbers_dict' + + def test_dictget(self): + self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(3)), 'Three') + self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(99)), '?') + + def test_dictgetordefault(self): + self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(3), 'n/a'), 'Three') + self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(99), 'n/a'), 'n/a') + + def test_dicthas(self): + self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1) + self._test_func(F.dictHas(self.dict_name, F.toUInt64(99)), 0) + + +class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase): + + def _create_dictionary(self): + # Create a table to be used as source for the dictionary + self.database.create_table(Region) + self.database.insert([ + Region(region_id=1, parent_region=0, region_name='Russia'), + Region(region_id=2, parent_region=1, region_name='Moscow'), + Region(region_id=3, parent_region=2, region_name='Center'), + Region(region_id=4, parent_region=0, region_name='Great Britain'), + Region(region_id=5, parent_region=4, region_name='London'), + ]) + # Create the dictionary + self.database.raw(""" + CREATE DICTIONARY regions_dict( + region_id UInt64, + parent_region UInt64 HIERARCHICAL, + region_name String DEFAULT '?' 
+ ) + PRIMARY KEY region_id + SOURCE(CLICKHOUSE( + HOST 'localhost' PORT 9000 USER 'default' PASSWORD '' DB 'test-db' TABLE 'region' + )) + LIFETIME(100) + LAYOUT(HASHED()); + """) + self.dict_name = 'test-db.regions_dict' + + def test_dictget(self): + self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(3)), 'Center') + self._test_func(F.dictGet(self.dict_name, 'parent_region', F.toUInt64(3)), 2) + self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(99)), '?') + + def test_dictgetordefault(self): + self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(3), 'n/a'), 'Center') + self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(99), 'n/a'), 'n/a') + + def test_dicthas(self): + self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1) + self._test_func(F.dictHas(self.dict_name, F.toUInt64(99)), 0) + + def test_dictgethierarchy(self): + self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1]) + self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [99]) + + def test_dictisin(self): + self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1) + self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(4)), 0) + self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(99), F.toUInt64(4)), 0) + + +class NumberName(Model): + ''' A table to act as a source for the dictionary ''' + + number = UInt64Field() + name = StringField() + + engine = Memory() + + +class Region(Model): + + region_id = UInt64Field() + parent_region = UInt64Field() + region_name = StringField() + + engine = Memory()