Enhancement: remove \N warning for utils

Also improves code quality and implementation details.
This commit is contained in:
olliemath 2021-07-28 21:39:49 +01:00
parent 4a90eede16
commit bd62a3c1de
7 changed files with 60 additions and 40 deletions

View File

@ -86,12 +86,12 @@ class MergeTree(Engine):
# Let's check version and use new syntax if available # Let's check version and use new syntax if available
if db.server_version >= (1, 1, 54310): if db.server_version >= (1, 1, 54310):
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % ( partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
comma_join(self.partition_key, stringify=True), comma_join(map(str, self.partition_key)),
comma_join(self.order_by, stringify=True), comma_join(map(str, self.order_by)),
) )
if self.primary_key: if self.primary_key:
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True) partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
if self.sampling_expr: if self.sampling_expr:
partition_sql += " SAMPLE BY %s" % self.sampling_expr partition_sql += " SAMPLE BY %s" % self.sampling_expr
@ -126,7 +126,7 @@ class MergeTree(Engine):
params.append(self.date_col) params.append(self.date_col)
if self.sampling_expr: if self.sampling_expr:
params.append(self.sampling_expr) params.append(self.sampling_expr)
params.append("(%s)" % comma_join(self.order_by, stringify=True)) params.append("(%s)" % comma_join(map(str(self.order_by))))
params.append(str(self.index_granularity)) params.append(str(self.index_granularity))
return params return params

View File

@ -147,7 +147,7 @@ class StringField(Field):
if isinstance(value, str): if isinstance(value, str):
return value return value
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode("UTF-8") return value.decode("utf-8")
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value)) raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
@ -163,7 +163,7 @@ class FixedStringField(StringField):
def validate(self, value): def validate(self, value):
if isinstance(value, str): if isinstance(value, str):
value = value.encode("UTF-8") value = value.encode("utf-8")
if len(value) > self._length: if len(value) > self._length:
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length)) raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
@ -475,7 +475,7 @@ class BaseEnumField(Field):
except Exception: except Exception:
return self.enum_cls(value) return self.enum_cls(value)
if isinstance(value, bytes): if isinstance(value, bytes):
decoded = value.decode("UTF-8") decoded = value.decode("utf-8")
try: try:
return self.enum_cls[decoded] return self.enum_cls[decoded]
except Exception: except Exception:
@ -533,7 +533,7 @@ class ArrayField(Field):
if isinstance(value, str): if isinstance(value, str):
value = parse_array(value) value = parse_array(value)
elif isinstance(value, bytes): elif isinstance(value, bytes):
value = parse_array(value.decode("UTF-8")) value = parse_array(value.decode("utf-8"))
elif not isinstance(value, (list, tuple)): elif not isinstance(value, (list, tuple)):
raise ValueError("ArrayField expects list or tuple, not %s" % type(value)) raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
return [self.inner_field.to_python(v, timezone_in_use) for v in value] return [self.inner_field.to_python(v, timezone_in_use) for v in value]

View File

@ -76,7 +76,7 @@ def parametric(func):
def inner(*args, **kwargs): def inner(*args, **kwargs):
f = func(*args, **kwargs) f = func(*args, **kwargs)
# Append the parameter to the function name # Append the parameter to the function name
parameters_str = comma_join(parameters, stringify=True) parameters_str = comma_join(map(str, parameters))
f.name = "%s(%s)" % (f.name, parameters_str) f.name = "%s(%s)" % (f.name, parameters_str)
return f return f

View File

@ -3,30 +3,27 @@ import importlib
import pkgutil import pkgutil
import re import re
from datetime import date, datetime, timedelta, tzinfo from datetime import date, datetime, timedelta, tzinfo
from inspect import isclass
SPECIAL_CHARS = {"\b": "\\b", "\f": "\\f", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\0", "\\": "\\\\", "'": "\\'"} from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Type, Union
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
def escape(value, quote=True): def escape(value: str, quote: bool = True) -> str:
""" """
If the value is a string, escapes any special characters and optionally If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number), surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one. converts it to one.
""" """
value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
if quote:
value = "'" + value + "'"
def escape_one(match): return value
return SPECIAL_CHARS[match.group(0)]
if isinstance(value, str):
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
if quote:
value = "'" + value + "'"
return str(value)
def unescape(value): def unescape(value: str) -> Optional[str]:
if value == "\\N":
return None
return codecs.escape_decode(value)[0].decode("utf-8") return codecs.escape_decode(value)[0].decode("utf-8")
@ -34,7 +31,7 @@ def string_or_func(obj):
return obj.to_sql() if hasattr(obj, "to_sql") else obj return obj.to_sql() if hasattr(obj, "to_sql") else obj
def arg_to_sql(arg): def arg_to_sql(arg: Any) -> str:
""" """
Converts a function argument to SQL string according to its type. Converts a function argument to SQL string according to its type.
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans, Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
@ -69,15 +66,15 @@ def arg_to_sql(arg):
return str(arg) return str(arg)
def parse_tsv(line): def parse_tsv(line: Union[bytes, str]) -> List[str]:
if isinstance(line, bytes): if isinstance(line, bytes):
line = line.decode() line = line.decode()
if line and line[-1] == "\n": if line and line[-1] == "\n":
line = line[:-1] line = line[:-1]
return [unescape(value) for value in line.split(str("\t"))] return [unescape(value) for value in line.split("\t")]
def parse_array(array_string): def parse_array(array_string: str) -> List[Any]:
""" """
Parse an array or tuple string as returned by clickhouse. For example: Parse an array or tuple string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"] "['hello', 'world']" ==> ["hello", "world"]
@ -111,7 +108,7 @@ def parse_array(array_string):
array_string = array_string[match.end() - 1 :] array_string = array_string[match.end() - 1 :]
def import_submodules(package_name): def import_submodules(package_name: str) -> Dict[str, ModuleType]:
""" """
Import all submodules of a module. Import all submodules of a module.
""" """
@ -122,17 +119,14 @@ def import_submodules(package_name):
} }
def comma_join(items, stringify=False): def comma_join(items: Iterable[str]) -> str:
""" """
Joins an iterable of strings with commas. Joins an iterable of strings with commas.
""" """
if stringify: return ", ".join(items)
return ", ".join(str(item) for item in items)
else:
return ", ".join(items)
def is_iterable(obj): def is_iterable(obj: Any) -> bool:
""" """
Checks if the given object is iterable. Checks if the given object is iterable.
""" """
@ -143,9 +137,7 @@ def is_iterable(obj):
return False return False
def get_subclass_names(locals, base_class): def get_subclass_names(locals: Dict[str, Any], base_class: Type):
from inspect import isclass
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)] return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]

View File

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -3,8 +3,8 @@ import logging
import unittest import unittest
from clickhouse_orm.database import Database from clickhouse_orm.database import Database
from clickhouse_orm.engines import * from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import * from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
from clickhouse_orm.models import Model from clickhouse_orm.models import Model
logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("requests").setLevel(logging.WARNING)

29
tests/test_utils.py Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
from clickhouse_orm.utils import escape, unescape
SPECIAL_CHARS = {"\b": "\\x08", "\f": "\\x0c", "\r": "\\r", "\n": "\\n", "\t": "\\t", "\0": "\\x00", "\\": "\\\\", "'": "\\'"}
def test_unescape():
for input_, expected in (
("π\\t", "π\t"),
("\\new", "\new"),
("cheeky 🐵", "cheeky 🐵"),
("\\N", None),
):
assert unescape(input_) == expected
def test_escape_special_chars():
initial = "".join(SPECIAL_CHARS.keys())
expected = "".join(SPECIAL_CHARS.values())
assert escape(initial, quote=False) == expected
assert escape(initial) == "'" + expected + "'"
def test_escape_unescape_parity():
for initial in ("π\t", "\new", "cheeky 🐵", "back \\ slash", "\\\\n"):
assert unescape(escape(initial, quote=False)) == initial