Enhancement: remove \N warning for utils

Also improves code quality and implementation details.
This commit is contained in:
olliemath 2021-07-28 21:39:49 +01:00
parent 4a90eede16
commit bd62a3c1de
7 changed files with 60 additions and 40 deletions

View File

@ -86,12 +86,12 @@ class MergeTree(Engine):
# Let's check version and use new syntax if available
if db.server_version >= (1, 1, 54310):
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
comma_join(self.partition_key, stringify=True),
comma_join(self.order_by, stringify=True),
comma_join(map(str, self.partition_key)),
comma_join(map(str, self.order_by)),
)
if self.primary_key:
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
if self.sampling_expr:
partition_sql += " SAMPLE BY %s" % self.sampling_expr
@ -126,7 +126,7 @@ class MergeTree(Engine):
params.append(self.date_col)
if self.sampling_expr:
params.append(self.sampling_expr)
params.append("(%s)" % comma_join(self.order_by, stringify=True))
params.append("(%s)" % comma_join(map(str(self.order_by))))
params.append(str(self.index_granularity))
return params

View File

@ -147,7 +147,7 @@ class StringField(Field):
if isinstance(value, str):
return value
if isinstance(value, bytes):
return value.decode("UTF-8")
return value.decode("utf-8")
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
@ -163,7 +163,7 @@ class FixedStringField(StringField):
def validate(self, value):
if isinstance(value, str):
value = value.encode("UTF-8")
value = value.encode("utf-8")
if len(value) > self._length:
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
@ -475,7 +475,7 @@ class BaseEnumField(Field):
except Exception:
return self.enum_cls(value)
if isinstance(value, bytes):
decoded = value.decode("UTF-8")
decoded = value.decode("utf-8")
try:
return self.enum_cls[decoded]
except Exception:
@ -533,7 +533,7 @@ class ArrayField(Field):
if isinstance(value, str):
value = parse_array(value)
elif isinstance(value, bytes):
value = parse_array(value.decode("UTF-8"))
value = parse_array(value.decode("utf-8"))
elif not isinstance(value, (list, tuple)):
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
return [self.inner_field.to_python(v, timezone_in_use) for v in value]

View File

@ -76,7 +76,7 @@ def parametric(func):
def inner(*args, **kwargs):
f = func(*args, **kwargs)
# Append the parameter to the function name
parameters_str = comma_join(parameters, stringify=True)
parameters_str = comma_join(map(str, parameters))
f.name = "%s(%s)" % (f.name, parameters_str)
return f

View File

@ -3,30 +3,27 @@ import importlib
import pkgutil
import re
from datetime import date, datetime, timedelta, tzinfo
SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"}
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
from inspect import isclass
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Type, Union
def escape(value, quote=True):
def escape(value: str, quote: bool = True) -> str:
"""
If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one.
"""
value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
if quote:
value = "'" + value + "'"
def escape_one(match):
return SPECIAL_CHARS[match.group(0)]
if isinstance(value, str):
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
if quote:
value = "'" + value + "'"
return str(value)
return value
def unescape(value):
def unescape(value: str) -> Optional[str]:
if value == "\\N":
return None
return codecs.escape_decode(value)[0].decode("utf-8")
@ -34,7 +31,7 @@ def string_or_func(obj):
return obj.to_sql() if hasattr(obj, "to_sql") else obj
def arg_to_sql(arg):
def arg_to_sql(arg: Any) -> str:
"""
Converts a function argument to SQL string according to its type.
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
@ -69,15 +66,15 @@ def arg_to_sql(arg):
return str(arg)
def parse_tsv(line):
def parse_tsv(line: Union[bytes, str]) -> List[str]:
if isinstance(line, bytes):
line = line.decode()
if line and line[-1] == "\n":
line = line[:-1]
return [unescape(value) for value in line.split(str("\t"))]
return [unescape(value) for value in line.split("\t")]
def parse_array(array_string):
def parse_array(array_string: str) -> List[Any]:
"""
Parse an array or tuple string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"]
@ -111,7 +108,7 @@ def parse_array(array_string):
array_string = array_string[match.end() - 1 :]
def import_submodules(package_name):
def import_submodules(package_name: str) -> Dict[str, ModuleType]:
"""
Import all submodules of a module.
"""
@ -122,17 +119,14 @@ def import_submodules(package_name):
}
def comma_join(items, stringify=False):
def comma_join(items: Iterable[str]) -> str:
"""
Joins an iterable of strings with commas.
"""
if stringify:
return ", ".join(str(item) for item in items)
else:
return ", ".join(items)
return ", ".join(items)
def is_iterable(obj):
def is_iterable(obj: Any) -> bool:
"""
Checks if the given object is iterable.
"""
@ -143,9 +137,7 @@ def is_iterable(obj):
return False
def get_subclass_names(locals, base_class):
from inspect import isclass
def get_subclass_names(locals: Dict[str, Any], base_class: Type):
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]

View File

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -3,8 +3,8 @@ import logging
import unittest
from clickhouse_orm.database import Database
from clickhouse_orm.engines import *
from clickhouse_orm.fields import *
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
from clickhouse_orm.models import Model
logging.getLogger("requests").setLevel(logging.WARNING)

29
tests/test_utils.py Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
from clickhouse_orm.utils import escape, unescape
SPECIAL_CHARS = {"\b": "\\x08", "\f": "\\x0c", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\x00", "\\": "\\\\", "'": "\\'"}
def test_unescape():
for input_, expected in (
("π\\t", "π\t"),
("\\new", "\new"),
("cheeky 🐵", "cheeky 🐵"),
("\\N", None),
):
assert unescape(input_) == expected
def test_escape_special_chars():
initial = "".join(SPECIAL_CHARS.keys())
expected = "".join(SPECIAL_CHARS.values())
assert escape(initial, quote=False) == expected
assert escape(initial) == "'" + expected + "'"
def test_escape_unescape_parity():
for initial in ("π\t", "\new", "cheeky 🐵", "back \\ slash", "\\\\n"):
assert unescape(escape(initial, quote=False)) == initial