mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-03 03:30:17 +03:00
Enhancement: remove \N warning for utils
Also improves code quality and implementation details.
This commit is contained in:
parent
4a90eede16
commit
bd62a3c1de
|
@ -86,12 +86,12 @@ class MergeTree(Engine):
|
||||||
# Let's check version and use new syntax if available
|
# Let's check version and use new syntax if available
|
||||||
if db.server_version >= (1, 1, 54310):
|
if db.server_version >= (1, 1, 54310):
|
||||||
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
|
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
|
||||||
comma_join(self.partition_key, stringify=True),
|
comma_join(map(str, self.partition_key)),
|
||||||
comma_join(self.order_by, stringify=True),
|
comma_join(map(str, self.order_by)),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.primary_key:
|
if self.primary_key:
|
||||||
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
|
partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
|
||||||
|
|
||||||
if self.sampling_expr:
|
if self.sampling_expr:
|
||||||
partition_sql += " SAMPLE BY %s" % self.sampling_expr
|
partition_sql += " SAMPLE BY %s" % self.sampling_expr
|
||||||
|
@ -126,7 +126,7 @@ class MergeTree(Engine):
|
||||||
params.append(self.date_col)
|
params.append(self.date_col)
|
||||||
if self.sampling_expr:
|
if self.sampling_expr:
|
||||||
params.append(self.sampling_expr)
|
params.append(self.sampling_expr)
|
||||||
params.append("(%s)" % comma_join(self.order_by, stringify=True))
|
# map() takes the callable and the iterable as two separate arguments
params.append("(%s)" % comma_join(map(str, self.order_by)))
|
||||||
params.append(str(self.index_granularity))
|
params.append(str(self.index_granularity))
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
|
@ -147,7 +147,7 @@ class StringField(Field):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
return value.decode("UTF-8")
|
return value.decode("utf-8")
|
||||||
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
|
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ class FixedStringField(StringField):
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = value.encode("UTF-8")
|
value = value.encode("utf-8")
|
||||||
if len(value) > self._length:
|
if len(value) > self._length:
|
||||||
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
|
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
|
||||||
|
|
||||||
|
@ -475,7 +475,7 @@ class BaseEnumField(Field):
|
||||||
except Exception:
|
except Exception:
|
||||||
return self.enum_cls(value)
|
return self.enum_cls(value)
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
decoded = value.decode("UTF-8")
|
decoded = value.decode("utf-8")
|
||||||
try:
|
try:
|
||||||
return self.enum_cls[decoded]
|
return self.enum_cls[decoded]
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -533,7 +533,7 @@ class ArrayField(Field):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = parse_array(value)
|
value = parse_array(value)
|
||||||
elif isinstance(value, bytes):
|
elif isinstance(value, bytes):
|
||||||
value = parse_array(value.decode("UTF-8"))
|
value = parse_array(value.decode("utf-8"))
|
||||||
elif not isinstance(value, (list, tuple)):
|
elif not isinstance(value, (list, tuple)):
|
||||||
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
|
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
|
||||||
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
||||||
|
|
|
@ -76,7 +76,7 @@ def parametric(func):
|
||||||
def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
f = func(*args, **kwargs)
|
f = func(*args, **kwargs)
|
||||||
# Append the parameter to the function name
|
# Append the parameter to the function name
|
||||||
parameters_str = comma_join(parameters, stringify=True)
|
parameters_str = comma_join(map(str, parameters))
|
||||||
f.name = "%s(%s)" % (f.name, parameters_str)
|
f.name = "%s(%s)" % (f.name, parameters_str)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
|
@ -3,30 +3,27 @@ import importlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import re
|
import re
|
||||||
from datetime import date, datetime, timedelta, tzinfo
|
from datetime import date, datetime, timedelta, tzinfo
|
||||||
|
from inspect import isclass
|
||||||
SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"}
|
from types import ModuleType
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||||
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
|
|
||||||
|
|
||||||
|
|
||||||
def escape(value, quote=True):
|
def escape(value: str, quote: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
If the value is a string, escapes any special characters and optionally
|
If the value is a string, escapes any special characters and optionally
|
||||||
surrounds it with single quotes. If the value is not a string (e.g. a number),
|
surrounds it with single quotes. The value must be a str: non-string input
|
||||||
converts it to one.
|
(e.g. a number) is no longer converted and fails on the .encode() call.
|
||||||
"""
|
"""
|
||||||
|
value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
|
||||||
|
if quote:
|
||||||
|
value = "'" + value + "'"
|
||||||
|
|
||||||
def escape_one(match):
|
return value
|
||||||
return SPECIAL_CHARS[match.group(0)]
|
|
||||||
|
|
||||||
if isinstance(value, str):
|
|
||||||
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
|
|
||||||
if quote:
|
|
||||||
value = "'" + value + "'"
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
|
|
||||||
def unescape(value):
|
def unescape(value: str) -> Optional[str]:
|
||||||
|
if value == "\\N":
|
||||||
|
return None
|
||||||
return codecs.escape_decode(value)[0].decode("utf-8")
|
return codecs.escape_decode(value)[0].decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,7 +31,7 @@ def string_or_func(obj):
|
||||||
return obj.to_sql() if hasattr(obj, "to_sql") else obj
|
return obj.to_sql() if hasattr(obj, "to_sql") else obj
|
||||||
|
|
||||||
|
|
||||||
def arg_to_sql(arg):
|
def arg_to_sql(arg: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Converts a function argument to SQL string according to its type.
|
Converts a function argument to SQL string according to its type.
|
||||||
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
|
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
|
||||||
|
@ -69,15 +66,15 @@ def arg_to_sql(arg):
|
||||||
return str(arg)
|
return str(arg)
|
||||||
|
|
||||||
|
|
||||||
def parse_tsv(line):
|
def parse_tsv(line: Union[bytes, str]) -> List[str]:
|
||||||
if isinstance(line, bytes):
|
if isinstance(line, bytes):
|
||||||
line = line.decode()
|
line = line.decode()
|
||||||
if line and line[-1] == "\n":
|
if line and line[-1] == "\n":
|
||||||
line = line[:-1]
|
line = line[:-1]
|
||||||
return [unescape(value) for value in line.split(str("\t"))]
|
return [unescape(value) for value in line.split("\t")]
|
||||||
|
|
||||||
|
|
||||||
def parse_array(array_string):
|
def parse_array(array_string: str) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Parse an array or tuple string as returned by clickhouse. For example:
|
Parse an array or tuple string as returned by clickhouse. For example:
|
||||||
"['hello', 'world']" ==> ["hello", "world"]
|
"['hello', 'world']" ==> ["hello", "world"]
|
||||||
|
@ -111,7 +108,7 @@ def parse_array(array_string):
|
||||||
array_string = array_string[match.end() - 1 :]
|
array_string = array_string[match.end() - 1 :]
|
||||||
|
|
||||||
|
|
||||||
def import_submodules(package_name):
|
def import_submodules(package_name: str) -> Dict[str, ModuleType]:
|
||||||
"""
|
"""
|
||||||
Import all submodules of a module.
|
Import all submodules of a module.
|
||||||
"""
|
"""
|
||||||
|
@ -122,17 +119,14 @@ def import_submodules(package_name):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def comma_join(items, stringify=False):
|
def comma_join(items: Iterable[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Joins an iterable of strings with commas.
|
Joins an iterable of strings with commas.
|
||||||
"""
|
"""
|
||||||
if stringify:
|
return ", ".join(items)
|
||||||
return ", ".join(str(item) for item in items)
|
|
||||||
else:
|
|
||||||
return ", ".join(items)
|
|
||||||
|
|
||||||
|
|
||||||
def is_iterable(obj):
|
def is_iterable(obj: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks if the given object is iterable.
|
Checks if the given object is iterable.
|
||||||
"""
|
"""
|
||||||
|
@ -143,9 +137,7 @@ def is_iterable(obj):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_subclass_names(locals, base_class):
|
def get_subclass_names(locals: Dict[str, Any], base_class: Type):
|
||||||
from inspect import isclass
|
|
||||||
|
|
||||||
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
|
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
|
|
@ -3,8 +3,8 @@ import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from clickhouse_orm.database import Database
|
from clickhouse_orm.database import Database
|
||||||
from clickhouse_orm.engines import *
|
from clickhouse_orm.engines import MergeTree
|
||||||
from clickhouse_orm.fields import *
|
from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
|
||||||
from clickhouse_orm.models import Model
|
from clickhouse_orm.models import Model
|
||||||
|
|
||||||
logging.getLogger("requests").setLevel(logging.WARNING)
|
logging.getLogger("requests").setLevel(logging.WARNING)
|
||||||
|
|
29
tests/test_utils.py
Normal file
29
tests/test_utils.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from clickhouse_orm.utils import escape, unescape
|
||||||
|
|
||||||
|
# Maps each special character to the escaped text the tests below expect escape() to produce for it.
SPECIAL_CHARS = {"\b": "\\x08", "\f": "\\x0c", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\x00", "\\": "\\\\", "'": "\\'"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_unescape():
    """Check that unescape() decodes backslash sequences and turns the ``\\N`` marker into None."""
    cases = {
        "π\\t": "π\t",
        "\\new": "\new",
        "cheeky 🐵": "cheeky 🐵",
        "\\N": None,
    }
    for raw, decoded in cases.items():
        assert unescape(raw) == decoded
|
||||||
|
|
||||||
|
|
||||||
|
def test_escape_special_chars():
    """Check escape() on every key of SPECIAL_CHARS at once, both unquoted and quoted."""
    raw = "".join(SPECIAL_CHARS)
    escaped = "".join(SPECIAL_CHARS.values())

    # Without quoting, the result is exactly the concatenated escape sequences.
    assert escape(raw, quote=False) == escaped
    # With the default quoting, the same text is wrapped in single quotes.
    assert escape(raw) == "'%s'" % escaped
|
||||||
|
|
||||||
|
|
||||||
|
def test_escape_unescape_parity():
    """Check that unescape() reverses escape() when quoting is disabled."""
    samples = ("π\t", "\new", "cheeky 🐵", "back \\ slash", "\\\\n")
    for text in samples:
        escaped = escape(text, quote=False)
        assert unescape(escaped) == text
|
Loading…
Reference in New Issue
Block a user