mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-02 19:20:14 +03:00
Enhancement: remove \N warning for utils
Also improves code quality and implementation details.
This commit is contained in:
parent
4a90eede16
commit
bd62a3c1de
|
@ -86,12 +86,12 @@ class MergeTree(Engine):
|
|||
# Let's check version and use new syntax if available
|
||||
if db.server_version >= (1, 1, 54310):
|
||||
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
|
||||
comma_join(self.partition_key, stringify=True),
|
||||
comma_join(self.order_by, stringify=True),
|
||||
comma_join(map(str, self.partition_key)),
|
||||
comma_join(map(str, self.order_by)),
|
||||
)
|
||||
|
||||
if self.primary_key:
|
||||
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
|
||||
partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
|
||||
|
||||
if self.sampling_expr:
|
||||
partition_sql += " SAMPLE BY %s" % self.sampling_expr
|
||||
|
@ -126,7 +126,7 @@ class MergeTree(Engine):
|
|||
params.append(self.date_col)
|
||||
if self.sampling_expr:
|
||||
params.append(self.sampling_expr)
|
||||
params.append("(%s)" % comma_join(self.order_by, stringify=True))
|
||||
params.append("(%s)" % comma_join(map(str(self.order_by))))
|
||||
params.append(str(self.index_granularity))
|
||||
|
||||
return params
|
||||
|
|
|
@ -147,7 +147,7 @@ class StringField(Field):
|
|||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("UTF-8")
|
||||
return value.decode("utf-8")
|
||||
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ class FixedStringField(StringField):
|
|||
|
||||
def validate(self, value):
|
||||
if isinstance(value, str):
|
||||
value = value.encode("UTF-8")
|
||||
value = value.encode("utf-8")
|
||||
if len(value) > self._length:
|
||||
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
|
||||
|
||||
|
@ -475,7 +475,7 @@ class BaseEnumField(Field):
|
|||
except Exception:
|
||||
return self.enum_cls(value)
|
||||
if isinstance(value, bytes):
|
||||
decoded = value.decode("UTF-8")
|
||||
decoded = value.decode("utf-8")
|
||||
try:
|
||||
return self.enum_cls[decoded]
|
||||
except Exception:
|
||||
|
@ -533,7 +533,7 @@ class ArrayField(Field):
|
|||
if isinstance(value, str):
|
||||
value = parse_array(value)
|
||||
elif isinstance(value, bytes):
|
||||
value = parse_array(value.decode("UTF-8"))
|
||||
value = parse_array(value.decode("utf-8"))
|
||||
elif not isinstance(value, (list, tuple)):
|
||||
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
|
||||
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
||||
|
|
|
@ -76,7 +76,7 @@ def parametric(func):
|
|||
def inner(*args, **kwargs):
|
||||
f = func(*args, **kwargs)
|
||||
# Append the parameter to the function name
|
||||
parameters_str = comma_join(parameters, stringify=True)
|
||||
parameters_str = comma_join(map(str, parameters))
|
||||
f.name = "%s(%s)" % (f.name, parameters_str)
|
||||
return f
|
||||
|
||||
|
|
|
@ -3,30 +3,27 @@ import importlib
|
|||
import pkgutil
|
||||
import re
|
||||
from datetime import date, datetime, timedelta, tzinfo
|
||||
|
||||
SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"}
|
||||
|
||||
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
|
||||
from inspect import isclass
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
|
||||
def escape(value, quote=True):
|
||||
def escape(value: str, quote: bool = True) -> str:
|
||||
"""
|
||||
If the value is a string, escapes any special characters and optionally
|
||||
surrounds it with single quotes. If the value is not a string (e.g. a number),
|
||||
converts it to one.
|
||||
"""
|
||||
value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
|
||||
if quote:
|
||||
value = "'" + value + "'"
|
||||
|
||||
def escape_one(match):
|
||||
return SPECIAL_CHARS[match.group(0)]
|
||||
|
||||
if isinstance(value, str):
|
||||
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
|
||||
if quote:
|
||||
value = "'" + value + "'"
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
def unescape(value):
|
||||
def unescape(value: str) -> Optional[str]:
|
||||
if value == "\\N":
|
||||
return None
|
||||
return codecs.escape_decode(value)[0].decode("utf-8")
|
||||
|
||||
|
||||
|
@ -34,7 +31,7 @@ def string_or_func(obj):
|
|||
return obj.to_sql() if hasattr(obj, "to_sql") else obj
|
||||
|
||||
|
||||
def arg_to_sql(arg):
|
||||
def arg_to_sql(arg: Any) -> str:
|
||||
"""
|
||||
Converts a function argument to SQL string according to its type.
|
||||
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
|
||||
|
@ -69,15 +66,15 @@ def arg_to_sql(arg):
|
|||
return str(arg)
|
||||
|
||||
|
||||
def parse_tsv(line):
|
||||
def parse_tsv(line: Union[bytes, str]) -> List[str]:
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode()
|
||||
if line and line[-1] == "\n":
|
||||
line = line[:-1]
|
||||
return [unescape(value) for value in line.split(str("\t"))]
|
||||
return [unescape(value) for value in line.split("\t")]
|
||||
|
||||
|
||||
def parse_array(array_string):
|
||||
def parse_array(array_string: str) -> List[Any]:
|
||||
"""
|
||||
Parse an array or tuple string as returned by clickhouse. For example:
|
||||
"['hello', 'world']" ==> ["hello", "world"]
|
||||
|
@ -111,7 +108,7 @@ def parse_array(array_string):
|
|||
array_string = array_string[match.end() - 1 :]
|
||||
|
||||
|
||||
def import_submodules(package_name):
|
||||
def import_submodules(package_name: str) -> Dict[str, ModuleType]:
|
||||
"""
|
||||
Import all submodules of a module.
|
||||
"""
|
||||
|
@ -122,17 +119,14 @@ def import_submodules(package_name):
|
|||
}
|
||||
|
||||
|
||||
def comma_join(items, stringify=False):
|
||||
def comma_join(items: Iterable[str]) -> str:
|
||||
"""
|
||||
Joins an iterable of strings with commas.
|
||||
"""
|
||||
if stringify:
|
||||
return ", ".join(str(item) for item in items)
|
||||
else:
|
||||
return ", ".join(items)
|
||||
return ", ".join(items)
|
||||
|
||||
|
||||
def is_iterable(obj):
|
||||
def is_iterable(obj: Any) -> bool:
|
||||
"""
|
||||
Checks if the given object is iterable.
|
||||
"""
|
||||
|
@ -143,9 +137,7 @@ def is_iterable(obj):
|
|||
return False
|
||||
|
||||
|
||||
def get_subclass_names(locals, base_class):
|
||||
from inspect import isclass
|
||||
|
||||
def get_subclass_names(locals: Dict[str, Any], base_class: Type):
|
||||
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
|
||||
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
__import__("pkg_resources").declare_namespace(__name__)
|
|
@ -3,8 +3,8 @@ import logging
|
|||
import unittest
|
||||
|
||||
from clickhouse_orm.database import Database
|
||||
from clickhouse_orm.engines import *
|
||||
from clickhouse_orm.fields import *
|
||||
from clickhouse_orm.engines import MergeTree
|
||||
from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
|
||||
from clickhouse_orm.models import Model
|
||||
|
||||
logging.getLogger("requests").setLevel(logging.WARNING)
|
||||
|
|
29
tests/test_utils.py
Normal file
29
tests/test_utils.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from clickhouse_orm.utils import escape, unescape
|
||||
|
||||
SPECIAL_CHARS = {"\b": "\\x08", "\f": "\\x0c", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\x00", "\\": "\\\\", "'": "\\'"}
|
||||
|
||||
|
||||
def test_unescape():
|
||||
|
||||
for input_, expected in (
|
||||
("π\\t", "π\t"),
|
||||
("\\new", "\new"),
|
||||
("cheeky 🐵", "cheeky 🐵"),
|
||||
("\\N", None),
|
||||
):
|
||||
assert unescape(input_) == expected
|
||||
|
||||
|
||||
def test_escape_special_chars():
|
||||
|
||||
initial = "".join(SPECIAL_CHARS.keys())
|
||||
expected = "".join(SPECIAL_CHARS.values())
|
||||
assert escape(initial, quote=False) == expected
|
||||
assert escape(initial) == "'" + expected + "'"
|
||||
|
||||
|
||||
def test_escape_unescape_parity():
|
||||
|
||||
for initial in ("π\t", "\new", "cheeky 🐵", "back \\ slash", "\\\\n"):
|
||||
assert unescape(escape(initial, quote=False)) == initial
|
Loading…
Reference in New Issue
Block a user