diff --git a/pyproject.toml b/pyproject.toml
index e6f5743..262f1d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "iso8601 >= 0.1.12",
     "setuptools"
 ]
-version = "0.2.0"
+version = "0.2.1"
 
 [tool.setuptools.packages.find]
 where = ["src"]
diff --git a/scripts/generate_ref.py b/scripts/generate_ref.py
index eaad4b0..acccc9c 100644
--- a/scripts/generate_ref.py
+++ b/scripts/generate_ref.py
@@ -1,12 +1,11 @@
-
 import inspect
 from collections import namedtuple
 
-DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
+DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
 
 
 def _get_default_arg(args, defaults, arg_index):
-    """ Method that determines if an argument has default value or not,
+    """Method that determines if an argument has default value or not,
     and if yes what is the default value for the argument
 
     :param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
@@ -26,13 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
         return DefaultArgSpec(False, None)
     else:
         value = defaults[arg_index - args_with_no_defaults]
-        if (type(value) is str):
+        if type(value) is str:
             value = '"%s"' % value
         return DefaultArgSpec(True, value)
 
 
 def get_method_sig(method):
-    """ Given a function, it returns a string that pretty much looks how the
+    """Given a function, it returns a string that pretty much looks how the
     function signature would be written in python.
 
     :param method: a python method
@@ -59,16 +58,16 @@ def get_method_sig(method):
         args.append(arg)
         arg_index += 1
     if argspec.varargs:
-        args.append('*' + argspec.varargs)
+        args.append("*" + argspec.varargs)
     if argspec.varkw:
-        args.append('**' + argspec.varkw)
+        args.append("**" + argspec.varkw)
     return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
 
 
 def docstring(obj):
-    doc = (obj.__doc__ or '').rstrip()
+    doc = (obj.__doc__ or "").rstrip()
     if doc:
-        lines = doc.split('\n')
+        lines = doc.split("\n")
         # Find the length of the whitespace prefix common to all non-empty lines
         indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
         # Output the lines without the indentation
@@ -78,30 +77,32 @@ def docstring(obj):
 
 
 def class_doc(cls, list_methods=True):
-    bases = ', '.join([b.__name__ for b in cls.__bases__])
-    print('###', cls.__name__)
+    bases = ", ".join([b.__name__ for b in cls.__bases__])
+    print("###", cls.__name__)
     print()
-    if bases != 'object':
-        print('Extends', bases)
+    if bases != "object":
+        print("Extends", bases)
         print()
     docstring(cls)
-    for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
-        if name == '__init__':
+    for name, method in inspect.getmembers(
+        cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)
+    ):
+        if name == "__init__":
             # Initializer
-            print('####', get_method_sig(method).replace(name, cls.__name__))
-        elif name[0] == '_':
+            print("####", get_method_sig(method).replace(name, cls.__name__))
+        elif name[0] == "_":
             # Private method
             continue
-        elif hasattr(method, '__self__') and method.__self__ == cls:
+        elif hasattr(method, "__self__") and method.__self__ == cls:
             # Class method
             if not list_methods:
                 continue
-            print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
+            print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
         else:
             # Regular method
             if not list_methods:
                 continue
-            print('####', get_method_sig(method))
+            print("####", get_method_sig(method))
         print()
         docstring(method)
         print()
@@ -110,7 +111,7 @@ def class_doc(cls, list_methods=True):
 def module_doc(classes, list_methods=True):
     mdl = classes[0].__module__
     print(mdl)
-    print('-' * len(mdl))
+    print("-" * len(mdl))
     print()
     for cls in classes:
         class_doc(cls, list_methods)
@@ -120,7 +121,7 @@ def all_subclasses(cls):
     return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     from clickhouse_orm import database
     from clickhouse_orm import fields
 
@@ -131,13 +132,24 @@ if __name__ == '__main__':
     from clickhouse_orm import system_models
     from clickhouse_orm.aio import database as aio_database
 
-    print('Class Reference')
-    print('===============')
+    print("Class Reference")
+    print("===============")
     print()
     module_doc([database.Database, database.DatabaseException])
     module_doc([aio_database.AioDatabase])
-    module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
-    module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
+    module_doc(
+        [
+            models.Model,
+            models.BufferModel,
+            models.MergeModel,
+            models.DistributedModel,
+            models.Constraint,
+            models.Index,
+        ]
+    )
+    module_doc(
+        sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False
+    )
     module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
     module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
     module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])
diff --git a/scripts/html_to_markdown_toc.py b/scripts/html_to_markdown_toc.py
index 552137f..70e0c5d 100644
--- a/scripts/html_to_markdown_toc.py
+++ b/scripts/html_to_markdown_toc.py
@@ -2,13 +2,13 @@ from html.parser import HTMLParser
 import sys
 
 
-HEADER_TAGS = ('h1', 'h2', 'h3')
+HEADER_TAGS = ("h1", "h2", "h3")
 
 
 class HeadersToMarkdownParser(HTMLParser):
 
     inside = None
-    text = ''
+    text = ""
 
     def handle_starttag(self, tag, attrs):
         if tag.lower() in HEADER_TAGS:
@@ -16,11 +16,11 @@ class HeadersToMarkdownParser(HTMLParser):
 
     def handle_endtag(self, tag):
         if tag.lower() in HEADER_TAGS:
-            indent = ' ' * int(self.inside[1])
-            fragment = self.text.lower().replace(' ', '-').replace('.', '')
-            print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
+            indent = " " * int(self.inside[1])
+            fragment = self.text.lower().replace(" ", "-").replace(".", "")
+            print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
             self.inside = None
-            self.text = ''
+            self.text = ""
 
     def handle_data(self, data):
         if self.inside:
@@ -28,4 +28,4 @@ class HeadersToMarkdownParser(HTMLParser):
 
 
 HeadersToMarkdownParser().feed(sys.stdin.read())
-print('')
+print("")
diff --git a/src/clickhouse_orm/engines.py b/src/clickhouse_orm/engines.py
index 27513f0..4c83db6 100644
--- a/src/clickhouse_orm/engines.py
+++ b/src/clickhouse_orm/engines.py
@@ -43,7 +43,7 @@ class Lazy(DatabaseEngine):
     for which there is a long time interval between accesses.
     """
 
-    def __int__(self, expiration_time_in_seconds: int):
+    def __init__(self, expiration_time_in_seconds: int):
        self.expiration_time_in_seconds = expiration_time_in_seconds
 
     def create_database_sql(self) -> str:
@@ -59,7 +59,7 @@ class MySQL(DatabaseEngine):
     such as SHOW TABLES or SHOW CREATE TABLE.
     """
 
-    def __int__(self, host: str, port: int, database: str, user: str, password: str):
+    def __init__(self, host: str, port: int, database: str, user: str, password: str):
         """
         - `host`: MySQL server address.
         - `port`: MySQL server port.
@@ -90,7 +90,7 @@ class PostgreSQL(DatabaseEngine):
     but can be updated with DETACH and ATTACH queries.
     """
 
-    def __int__(
+    def __init__(
         self,
         host: str,
         port: int,
diff --git a/src/clickhouse_orm/fields.py b/src/clickhouse_orm/fields.py
index b0025e6..ea3fec3 100644
--- a/src/clickhouse_orm/fields.py
+++ b/src/clickhouse_orm/fields.py
@@ -1,5 +1,6 @@
 from __future__ import unicode_literals, annotations
 
+import json
 import re
 from enum import Enum
 from uuid import UUID
@@ -14,7 +15,7 @@ import iso8601
 import pytz
 from pytz import BaseTzInfo
 
-from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
+from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names, parse_map
 from .funcs import F, FunctionOperatorsMixin
 
 if TYPE_CHECKING:
@@ -870,5 +871,75 @@ class LowCardinalityField(Field):
         return sql
 
+MapKey = Union[
+    StringField,
+    BaseIntField,
+    LowCardinalityField,
+    UUIDField,
+    DateField,
+    DateTimeField,
+    BaseEnumField,
+]
+
+
+class MapField(Field):
+    """MapField is experimental and may change until ClickHouse's Map type becomes generally available."""
+
+    class_default = None  # placeholder; the real default ({}) is set per instance in __init__
+
+    def __init__(
+        self,
+        key: MapKey,
+        value: Field,
+        default: dict = None,
+        alias: Optional[Union[F, str]] = None,
+        materialized: Optional[Union[F, str]] = None,
+        readonly: bool = None,
+        codec: Optional[str] = None,
+        db_column: Optional[str] = None,
+    ):
+        assert isinstance(key, Field)
+        assert isinstance(value, Field)
+        self.key = key
+        self.value = value
+        if not default:
+            default = {}
+        super().__init__(default, alias, materialized, readonly, codec, db_column)
+
+    def to_python(self, value, timezone_in_use) -> dict:
+        if isinstance(value, bytes):
+            value = value.decode("utf-8")
+        if isinstance(value, str):
+            value = parse_map(value)
+        if not isinstance(value, dict):
+            raise ValueError("MapField expects dict, not %s" % type(value))
+        value = {
+            self.key.to_python(k, timezone_in_use): self.value.to_python(v, timezone_in_use)
+            for k, v in value.items()
+        }
+        return value
+
+    def validate(self, value):
+        for k, v in value.items():
+            self.key.validate(k)
+            self.value.validate(v)
+
+    def to_db_string(self, value, quote=True) -> str:
+        value2 = {}
+        int_key = isinstance(self.key, (BaseIntField, BaseFloatField))
+        int_value = isinstance(self.value, (BaseIntField, BaseFloatField))
+        for k, v in value.items():
+            if not int_key:
+                k = self.key.to_db_string(k, quote=False)
+            if not int_value:
+                v = self.value.to_db_string(v, quote=False)
+            value2[k] = v
+        return json.dumps(value2).replace('"', "'")
+
+    def get_sql(self, with_default_expression=True, db=None) -> str:
+        sql = "Map(%s, %s)" % (self.key.get_sql(False), self.value.get_sql(False))
+        if with_default_expression and self.codec and db and db.has_codec_support:
+            sql += " CODEC(%s)" % self.codec
+        return sql
 
 # Expose only relevant classes in import *
 __all__ = get_subclass_names(locals(), Field)
diff --git a/src/clickhouse_orm/funcs.py b/src/clickhouse_orm/funcs.py
index db7f7c7..772fc56 100644
--- a/src/clickhouse_orm/funcs.py
+++ b/src/clickhouse_orm/funcs.py
@@ -186,6 +186,12 @@ class FunctionOperatorsMixin:
     def isNotIn(self, others):
         return F._notIn(self, others)  # pylint: disable=W0212
 
+    def isGlobalIn(self, others):
+        return F._gin(self, others)  # pylint: disable=W0212
+
+    def isNotGlobalIn(self, others):
+        return F._notGIn(self, others)  # pylint: disable=W0212
+
 
 class FMeta(type):
 
@@ -404,6 +410,20 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):  # pylint: disable=R0904
             b = tuple(b)
         return F("NOT IN", a, b)
 
+    @staticmethod
+    @binary_operator
+    def _gin(a, b):
+        if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
+            b = tuple(b)
+        return F("GLOBAL IN", a, b)
+
+    @staticmethod
+    @binary_operator
+    def _notGIn(a, b):
+        if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
+            b = tuple(b)
+        return F("NOT GLOBAL IN", a, b)
+
     # Functions for working with dates and times
 
     @staticmethod
diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py
index 27b1a93..a53aaf9 100644
--- a/src/clickhouse_orm/models.py
+++ b/src/clickhouse_orm/models.py
@@ -253,6 +253,12 @@ class ModelBase(type):
                 else:
                     name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
             return orm_fields.TupleField(name_fields=name_fields)
+        # Map
+        if db_type.startswith("Map"):
+            types = [s.strip() for s in db_type[4:-1].split(",")]
+            key_field = cls.create_ad_hoc_field(types[0])
+            value_field = cls.create_ad_hoc_field(types[1])
+            return orm_fields.MapField(key_field, value_field)
         # FixedString
         if db_type.startswith("FixedString"):
             length = int(db_type[12:-1])
diff --git a/src/clickhouse_orm/utils.py b/src/clickhouse_orm/utils.py
index 441ee6b..401647f 100644
--- a/src/clickhouse_orm/utils.py
+++ b/src/clickhouse_orm/utils.py
@@ -1,4 +1,5 @@
 import codecs
+import json
 import re
 from datetime import date, datetime, tzinfo, timedelta
 
@@ -16,6 +17,7 @@ SPECIAL_CHARS = {
 SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
 POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
 RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
+MAP_REGEX = re.compile(r"([^=]+)=([^,]+),?\s*")
 
 
 def escape(value, quote=True):
@@ -120,6 +122,22 @@ def parse_array(array_string):
         array_string = array_string[match.end() - 1 :]
 
 
+def parse_map(map_string: str) -> dict:
+    """
+    Parse a map string as returned by ClickHouse. For example:
+    "{key1=1, key2=2, key3=3}" ==> {"key1": "1", "key2": "2", "key3": "3"}
+    """
+    if not map_string or map_string[0] != "{" or map_string[-1] != "}":
+        raise ValueError('Invalid map string: "%s"' % map_string)
+    ret = {}
+    try:
+        ret = json.loads(map_string.replace("'", '"'))
+    except json.JSONDecodeError:
+        for key, value in MAP_REGEX.findall(map_string[1:-1]):
+            ret[key] = value
+    return ret
+
+
 def import_submodules(package_name):
     """
     Import all submodules of a module.
diff --git a/tests/test_map_fields.py b/tests/test_map_fields.py
new file mode 100644
index 0000000..0619254
--- /dev/null
+++ b/tests/test_map_fields.py
@@ -0,0 +1,59 @@
+import unittest
+from datetime import date
+
+from clickhouse_orm.database import Database
+from clickhouse_orm.models import Model
+from clickhouse_orm.fields import MapField, DateField, StringField, Int32Field, Float64Field
+from clickhouse_orm.engines import MergeTree
+
+
+class MapFieldsTest(unittest.TestCase):
+
+    def setUp(self):
+        self.database = Database('test-db', log_statements=True)
+        self.database.create_table(ModelWithMap)
+
+    def tearDown(self):
+        self.database.drop_database()
+
+    def test_insert_and_select(self):
+        instance = ModelWithMap(
+            date_field='2022-06-26',
+            map1={"k1": "v1", "k2": "v2"},
+            map2={"v1": 1, "v2": 2},
+            map3={"f1": 1.1, "f2": 2.0},
+            map4={"2022-06-25": "ok", "2022-06-26": "today"}
+        )
+        self.database.insert([instance])
+        query = 'SELECT * from $db.modelwithmap ORDER BY date_field'
+        for model_cls in (ModelWithMap, None):
+            results = list(self.database.select(query, model_cls))
+            self.assertEqual(len(results), 1)
+            self.assertIn("k1", results[0].map1)
+            self.assertEqual(results[0].map1["k2"], "v2")
+            self.assertEqual(results[0].map2["v1"], 1)
+            self.assertEqual(results[0].map3["f2"], 2.0)
+            self.assertEqual(results[0].map4[date(2022, 6, 26)], "today")
+
+    def test_conversion(self):
+        instance = ModelWithMap(
+            map2="{'1': '2'}"
+        )
+        self.assertEqual(instance.map2['1'], 2)
+
+    def test_assignment_error(self):
+        instance = ModelWithMap()
+        for value in (7, 'x', [date.today()], ['aaa'], [None]):
+            with self.assertRaises(ValueError):
+                instance.map1 = value
+
+
+class ModelWithMap(Model):
+
+    date_field = DateField()
+    map1 = MapField(StringField(), StringField())
+    map2 = MapField(StringField(), Int32Field())
+    map3 = MapField(StringField(), Float64Field())
+    map4 = MapField(DateField(), StringField())
+
+    engine = MergeTree('date_field', ('date_field',))
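
Usage sketch, not part of the patch: a minimal example of how the MapField and the new
GLOBAL IN operators added above might be used together. The PageHit model, the
"example_db" database name and the sample values are illustrative assumptions, and a
reachable ClickHouse server is assumed; the query style follows the existing isIn/isNotIn
operators.

    from clickhouse_orm.database import Database
    from clickhouse_orm.engines import MergeTree
    from clickhouse_orm.fields import DateField, Int32Field, MapField, StringField
    from clickhouse_orm.models import Model


    class PageHit(Model):
        # Hypothetical model: "counters" is stored as Map(String, Int32)
        date_field = DateField()
        counters = MapField(StringField(), Int32Field())

        engine = MergeTree("date_field", ("date_field",))


    db = Database("example_db")
    db.create_table(PageHit)
    db.insert([PageHit(date_field="2022-06-26", counters={"home": 3, "about": 1})])

    # isGlobalIn / isNotGlobalIn render as GLOBAL IN / NOT GLOBAL IN in the generated SQL
    hits = PageHit.objects_in(db).filter(PageHit.date_field.isGlobalIn(["2022-06-26"]))
    for hit in hits:
        print(hit.counters)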