Patch summary
  * comma_join() loses its `stringify` flag; call sites in engines.py and
    funcs.py pass map(str, ...) instead.
  * utils.escape() is rebuilt on codecs.escape_encode(); utils.unescape()
    returns None for the "\N" marker.
  * Type hints are added to utils.py, "UTF-8" codec names are lower-cased in
    fields.py, wildcard imports are removed from tests/base_test_with_data.py,
    the pkg_resources namespace line is dropped from tests/__init__.py, and
    tests/test_utils.py is added.

Review changes relative to the submitted patch
  * engines.py: `map(str(self.order_by))` -> `map(str, self.order_by)`
    (map() with a single argument raises TypeError).
  * utils.escape(): non-string values are stringified again, as the docstring
    states and as the removed implementation did.
  * utils.parse_tsv(): return annotation is List[Optional[str]] because
    unescape() may return None.
  * utils.get_subclass_names(): return annotation added.
  * Hunk headers of utils.py are adjusted for the three added lines; the
    `index` blob ids of engines.py and utils.py are kept from the submitted
    patch and no longer describe the edited post-image.

diff --git a/clickhouse_orm/engines.py b/clickhouse_orm/engines.py
index 8dfda54..af63af7 100644
--- a/clickhouse_orm/engines.py
+++ b/clickhouse_orm/engines.py
@@ -86,12 +86,12 @@ class MergeTree(Engine):
         # Let's check version and use new syntax if available
         if db.server_version >= (1, 1, 54310):
             partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
-                comma_join(self.partition_key, stringify=True),
-                comma_join(self.order_by, stringify=True),
+                comma_join(map(str, self.partition_key)),
+                comma_join(map(str, self.order_by)),
             )
 
             if self.primary_key:
-                partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
+                partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
 
             if self.sampling_expr:
                 partition_sql += " SAMPLE BY %s" % self.sampling_expr
@@ -126,7 +126,7 @@ class MergeTree(Engine):
             params.append(self.date_col)
             if self.sampling_expr:
                 params.append(self.sampling_expr)
-            params.append("(%s)" % comma_join(self.order_by, stringify=True))
+            params.append("(%s)" % comma_join(map(str, self.order_by)))
             params.append(str(self.index_granularity))
 
         return params
diff --git a/clickhouse_orm/fields.py b/clickhouse_orm/fields.py
index 31c36eb..8ead5a2 100644
--- a/clickhouse_orm/fields.py
+++ b/clickhouse_orm/fields.py
@@ -147,7 +147,7 @@ class StringField(Field):
         if isinstance(value, str):
             return value
         if isinstance(value, bytes):
-            return value.decode("UTF-8")
+            return value.decode("utf-8")
         raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
 
 
@@ -163,7 +163,7 @@ class FixedStringField(StringField):
 
     def validate(self, value):
         if isinstance(value, str):
-            value = value.encode("UTF-8")
+            value = value.encode("utf-8")
         if len(value) > self._length:
             raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
 
@@ -475,7 +475,7 @@ class BaseEnumField(Field):
                 except Exception:
                     return self.enum_cls(value)
             if isinstance(value, bytes):
-                decoded = value.decode("UTF-8")
+                decoded = value.decode("utf-8")
                 try:
                     return self.enum_cls[decoded]
                 except Exception:
@@ -533,7 +533,7 @@ class ArrayField(Field):
         if isinstance(value, str):
             value = parse_array(value)
         elif isinstance(value, bytes):
-            value = parse_array(value.decode("UTF-8"))
+            value = parse_array(value.decode("utf-8"))
         elif not isinstance(value, (list, tuple)):
             raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
         return [self.inner_field.to_python(v, timezone_in_use) for v in value]
diff --git a/clickhouse_orm/funcs.py b/clickhouse_orm/funcs.py
index 52be002..4293c9b 100644
--- a/clickhouse_orm/funcs.py
+++ b/clickhouse_orm/funcs.py
@@ -76,7 +76,7 @@ def parametric(func):
         def inner(*args, **kwargs):
             f = func(*args, **kwargs)
             # Append the parameter to the function name
-            parameters_str = comma_join(parameters, stringify=True)
+            parameters_str = comma_join(map(str, parameters))
             f.name = "%s(%s)" % (f.name, parameters_str)
             return f
 
diff --git a/clickhouse_orm/utils.py b/clickhouse_orm/utils.py
index 8be51f3..c4526f4 100644
--- a/clickhouse_orm/utils.py
+++ b/clickhouse_orm/utils.py
@@ -3,30 +3,30 @@ import importlib
 import pkgutil
 import re
 from datetime import date, datetime, timedelta, tzinfo
-
-SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"}
-
-SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
+from inspect import isclass
+from types import ModuleType
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 
-def escape(value, quote=True):
+def escape(value: Any, quote: bool = True) -> str:
     """
     If the value is a string, escapes any special characters and optionally
     surrounds it with single quotes. If the value is not a string (e.g. a number),
     converts it to one.
     """
+    if not isinstance(value, str):
+        # Keep the documented contract: non-string values are only stringified.
+        return str(value)
+    value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
+    if quote:
+        value = "'" + value + "'"
 
-    def escape_one(match):
-        return SPECIAL_CHARS[match.group(0)]
-
-    if isinstance(value, str):
-        value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
-        if quote:
-            value = "'" + value + "'"
-    return str(value)
+    return value
 
 
-def unescape(value):
+def unescape(value: str) -> Optional[str]:
+    if value == "\\N":
+        return None
     return codecs.escape_decode(value)[0].decode("utf-8")
 
 
@@ -34,7 +34,7 @@ def string_or_func(obj):
     return obj.to_sql() if hasattr(obj, "to_sql") else obj
 
 
-def arg_to_sql(arg):
+def arg_to_sql(arg: Any) -> str:
     """
     Converts a function argument to SQL string according to its type.
     Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
@@ -69,15 +69,15 @@ def arg_to_sql(arg):
     return str(arg)
 
 
-def parse_tsv(line):
+def parse_tsv(line: Union[bytes, str]) -> List[Optional[str]]:
     if isinstance(line, bytes):
         line = line.decode()
     if line and line[-1] == "\n":
         line = line[:-1]
-    return [unescape(value) for value in line.split(str("\t"))]
+    return [unescape(value) for value in line.split("\t")]
 
 
-def parse_array(array_string):
+def parse_array(array_string: str) -> List[Any]:
     """
     Parse an array or tuple string as returned by clickhouse. For example:
         "['hello', 'world']" ==> ["hello", "world"]
@@ -111,7 +111,7 @@ def parse_array(array_string):
             array_string = array_string[match.end() - 1 :]
 
 
-def import_submodules(package_name):
+def import_submodules(package_name: str) -> Dict[str, ModuleType]:
     """
     Import all submodules of a module.
     """
@@ -122,17 +122,14 @@ def import_submodules(package_name):
     }
 
 
-def comma_join(items, stringify=False):
+def comma_join(items: Iterable[str]) -> str:
     """
     Joins an iterable of strings with commas.
     """
-    if stringify:
-        return ", ".join(str(item) for item in items)
-    else:
-        return ", ".join(items)
+    return ", ".join(items)
 
 
-def is_iterable(obj):
+def is_iterable(obj: Any) -> bool:
     """
     Checks if the given object is iterable.
     """
@@ -143,9 +140,7 @@ def is_iterable(obj):
         return False
 
 
-def get_subclass_names(locals, base_class):
-    from inspect import isclass
-
+def get_subclass_names(locals: Dict[str, Any], base_class: Type) -> List[str]:
     return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
 
 
diff --git a/tests/__init__.py b/tests/__init__.py
index 5284146..e69de29 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1 +0,0 @@
-__import__("pkg_resources").declare_namespace(__name__)
diff --git a/tests/base_test_with_data.py b/tests/base_test_with_data.py
index 17da512..a026f5a 100644
--- a/tests/base_test_with_data.py
+++ b/tests/base_test_with_data.py
@@ -3,8 +3,8 @@ import logging
 import unittest
 
 from clickhouse_orm.database import Database
-from clickhouse_orm.engines import *
-from clickhouse_orm.fields import *
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
 from clickhouse_orm.models import Model
 
 logging.getLogger("requests").setLevel(logging.WARNING)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..74722f8
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+from clickhouse_orm.utils import escape, unescape
+
+SPECIAL_CHARS = {"\b": "\\x08", "\f": "\\x0c", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\x00", "\\": "\\\\", "'": "\\'"}
+
+
+def test_unescape():
+
+    for input_, expected in (
+        ("π\\t", "π\t"),
+        ("\\new", "\new"),
+        ("cheeky 🐵", "cheeky 🐵"),
+        ("\\N", None),
+    ):
+        assert unescape(input_) == expected
+
+
+def test_escape_special_chars():
+
+    initial = "".join(SPECIAL_CHARS.keys())
+    expected = "".join(SPECIAL_CHARS.values())
+    assert escape(initial, quote=False) == expected
+    assert escape(initial) == "'" + expected + "'"
+
+
+def test_escape_unescape_parity():
+
+    for initial in ("π\t", "\new", "cheeky 🐵", "back \\ slash", "\\\\n"):
+        assert unescape(escape(initial, quote=False)) == initial