feat: Support MapField

This commit is contained in:
sw 2022-06-26 13:58:29 +08:00
parent 2a08fdcf94
commit 0ea078dd56
9 changed files with 224 additions and 38 deletions

View File

@ -24,7 +24,7 @@ dependencies = [
"iso8601 >= 0.1.12",
"setuptools"
]
version = "0.2.0"
version = "0.2.1"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,8 +1,7 @@
import inspect
from collections import namedtuple
DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
def _get_default_arg(args, defaults, arg_index):
@ -26,7 +25,7 @@ def _get_default_arg(args, defaults, arg_index):
return DefaultArgSpec(False, None)
else:
value = defaults[arg_index - args_with_no_defaults]
if (type(value) is str):
if type(value) is str:
value = '"%s"' % value
return DefaultArgSpec(True, value)
@ -59,16 +58,16 @@ def get_method_sig(method):
args.append(arg)
arg_index += 1
if argspec.varargs:
args.append('*' + argspec.varargs)
args.append("*" + argspec.varargs)
if argspec.varkw:
args.append('**' + argspec.varkw)
args.append("**" + argspec.varkw)
return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
def docstring(obj):
doc = (obj.__doc__ or '').rstrip()
doc = (obj.__doc__ or "").rstrip()
if doc:
lines = doc.split('\n')
lines = doc.split("\n")
# Find the length of the whitespace prefix common to all non-empty lines
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
# Output the lines without the indentation
@ -78,30 +77,32 @@ def docstring(obj):
def class_doc(cls, list_methods=True):
bases = ', '.join([b.__name__ for b in cls.__bases__])
print('###', cls.__name__)
bases = ", ".join([b.__name__ for b in cls.__bases__])
print("###", cls.__name__)
print()
if bases != 'object':
print('Extends', bases)
if bases != "object":
print("Extends", bases)
print()
docstring(cls)
for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
if name == '__init__':
for name, method in inspect.getmembers(
cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)
):
if name == "__init__":
# Initializer
print('####', get_method_sig(method).replace(name, cls.__name__))
elif name[0] == '_':
print("####", get_method_sig(method).replace(name, cls.__name__))
elif name[0] == "_":
# Private method
continue
elif hasattr(method, '__self__') and method.__self__ == cls:
elif hasattr(method, "__self__") and method.__self__ == cls:
# Class method
if not list_methods:
continue
print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
else:
# Regular method
if not list_methods:
continue
print('####', get_method_sig(method))
print("####", get_method_sig(method))
print()
docstring(method)
print()
@ -110,7 +111,7 @@ def class_doc(cls, list_methods=True):
def module_doc(classes, list_methods=True):
mdl = classes[0].__module__
print(mdl)
print('-' * len(mdl))
print("-" * len(mdl))
print()
for cls in classes:
class_doc(cls, list_methods)
@ -120,7 +121,7 @@ def all_subclasses(cls):
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
if __name__ == '__main__':
if __name__ == "__main__":
from clickhouse_orm import database
from clickhouse_orm import fields
@ -131,13 +132,24 @@ if __name__ == '__main__':
from clickhouse_orm import system_models
from clickhouse_orm.aio import database as aio_database
print('Class Reference')
print('===============')
print("Class Reference")
print("===============")
print()
module_doc([database.Database, database.DatabaseException])
module_doc([aio_database.AioDatabase])
module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
module_doc(
[
models.Model,
models.BufferModel,
models.MergeModel,
models.DistributedModel,
models.Constraint,
models.Index,
]
)
module_doc(
sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False
)
module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])

View File

@ -2,13 +2,13 @@ from html.parser import HTMLParser
import sys
HEADER_TAGS = ('h1', 'h2', 'h3')
HEADER_TAGS = ("h1", "h2", "h3")
class HeadersToMarkdownParser(HTMLParser):
inside = None
text = ''
text = ""
def handle_starttag(self, tag, attrs):
if tag.lower() in HEADER_TAGS:
@ -16,11 +16,11 @@ class HeadersToMarkdownParser(HTMLParser):
def handle_endtag(self, tag):
if tag.lower() in HEADER_TAGS:
indent = ' ' * int(self.inside[1])
fragment = self.text.lower().replace(' ', '-').replace('.', '')
print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
indent = " " * int(self.inside[1])
fragment = self.text.lower().replace(" ", "-").replace(".", "")
print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
self.inside = None
self.text = ''
self.text = ""
def handle_data(self, data):
if self.inside:
@ -28,4 +28,4 @@ class HeadersToMarkdownParser(HTMLParser):
HeadersToMarkdownParser().feed(sys.stdin.read())
print('')
print("")

View File

@ -43,7 +43,7 @@ class Lazy(DatabaseEngine):
for which there is a long time interval between accesses.
"""
def __int__(self, expiration_time_in_seconds: int):
def __init__(self, expiration_time_in_seconds: int):
self.expiration_time_in_seconds = expiration_time_in_seconds
def create_database_sql(self) -> str:
@ -59,7 +59,7 @@ class MySQL(DatabaseEngine):
such as SHOW TABLES or SHOW CREATE TABLE.
"""
def __int__(self, host: str, port: int, database: str, user: str, password: str):
def __init__(self, host: str, port: int, database: str, user: str, password: str):
"""
- `host`: MySQL server address.
- `port`: MySQL server port.
@ -90,7 +90,7 @@ class PostgreSQL(DatabaseEngine):
but can be updated with DETACH and ATTACH queries.
"""
def __int__(
def __init__(
self,
host: str,
port: int,

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals, annotations
import json
import re
from enum import Enum
from uuid import UUID
@ -14,7 +15,7 @@ import iso8601
import pytz
from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names, parse_map
from .funcs import F, FunctionOperatorsMixin
if TYPE_CHECKING:
@ -870,5 +871,75 @@ class LowCardinalityField(Field):
return sql
# Field types that ClickHouse accepts as the key of a Map column.
# NOTE(review): this alias is used only as a type hint on MapField.__init__;
# it is not enforced at runtime (the constructor just asserts isinstance(key, Field)).
MapKey = Union[
    StringField,
    BaseIntField,
    LowCardinalityField,
    UUIDField,
    DateField,
    DateTimeField,
    BaseEnumField,
]
class MapField(Field):
    """
    A field for ClickHouse ``Map(key, value)`` columns.

    MapField is only experimental and is expected to be available in the next few releases
    """

    # NOTE(review): flagged as incorrect by the original author. __init__ always
    # substitutes {} for a falsy default, so this None is only seen if the base
    # Field consults class_default directly — confirm against Field's behavior.
    class_default = None  # incorrect value

    def __init__(
        self,
        key: MapKey,
        value: Field,
        default: dict = None,
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        """
        - `key`: Field instance used to convert/validate map keys.
        - `value`: Field instance used to convert/validate map values.
        - `default`: default dict value; a falsy default is replaced with {}.
        The remaining arguments are passed through to ``Field.__init__``.
        """
        # Runtime guard: MapKey is only a hint; both arguments must be Field instances.
        assert isinstance(key, Field)
        assert isinstance(value, Field)
        self.key = key
        self.value = value
        # Replace falsy defaults (None, {}) with a fresh dict per field instance.
        if not default:
            default = {}
        super().__init__(default, alias, materialized, readonly, codec, db_column)

    def to_python(self, value, timezone_in_use) -> dict:
        """
        Convert a DB/user-supplied value into a dict, converting each key and
        value through the respective sub-field. Raises ValueError for non-dicts.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        # Strings such as "{'k': 1}" or "{k=1}" are parsed via utils.parse_map.
        if isinstance(value, str):
            value = parse_map(value)
        if not isinstance(value, dict):
            raise ValueError("MapField expects dict, not %s" % type(value))
        value = {
            self.key.to_python(k, timezone_in_use): self.value.to_python(v, timezone_in_use)
            for k, v in value.items()
        }
        return value

    def validate(self, value):
        # Delegate per-entry validation to the key and value sub-fields.
        # NOTE(review): assumes `value` is already a dict (to_python runs first) —
        # a non-dict here would raise AttributeError, not ValueError.
        for k, v in value.items():
            self.key.validate(k)
            self.value.validate(v)

    def to_db_string(self, value, quote=True) -> str:
        """
        Render the dict as a ClickHouse map literal, e.g. {'k': 'v'}.

        Numeric keys/values are emitted unquoted; everything else goes through
        the sub-field's to_db_string.
        """
        value2 = {}
        int_key = isinstance(self.key, (BaseIntField, BaseFloatField))
        int_value = isinstance(self.value, (BaseIntField, BaseFloatField))
        for k, v in value.items():
            if not int_key:
                k = self.key.to_db_string(k, quote=False)
            if not int_value:
                v = self.value.to_db_string(v, quote=False)
            value2[k] = v
        # NOTE(review): blanket quote replacement corrupts data that itself
        # contains a double quote, and string values containing a single quote
        # will break the generated literal — confirm escaping requirements.
        return json.dumps(value2).replace("\"", "'")

    def get_sql(self, with_default_expression=True, db=None) -> str:
        # Column type, e.g. Map(String, Int32); optional CODEC when supported.
        sql = "Map(%s, %s)" % (self.key.get_sql(False), self.value.get_sql(False))
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += " CODEC(%s)" % self.codec
        return sql
# Expose only relevant classes in import *
__all__ = get_subclass_names(locals(), Field)

View File

@ -186,6 +186,12 @@ class FunctionOperatorsMixin:
def isNotIn(self, others):
return F._notIn(self, others) # pylint: disable=W0212
def isGlobalIn(self, others):
    """Build a ``GLOBAL IN`` condition (distributed-query variant of IN)."""
    return F._gin(self, others)  # pylint: disable=W0212
def isNotGlobalIn(self, others):
    """Build a ``NOT GLOBAL IN`` condition (negated distributed IN)."""
    return F._notGIn(self, others)  # pylint: disable=W0212
class FMeta(type):
@ -404,6 +410,20 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): # pylint: disable=R0904
b = tuple(b)
return F("NOT IN", a, b)
@staticmethod
@binary_operator
def _gin(a, b):
    """Return an F expression rendering as ``a GLOBAL IN b``."""
    # Materialize arbitrary iterables into a tuple; tuples are already fine
    # and QuerySets render as subqueries, so both pass through untouched.
    needs_tuple = is_iterable(b) and not isinstance(b, (tuple, QuerySet))
    operand = tuple(b) if needs_tuple else b
    return F("GLOBAL IN", a, operand)
@staticmethod
@binary_operator
def _notGIn(a, b):
    """Return an F expression rendering as ``a NOT GLOBAL IN b``."""
    # Same operand normalization as _gin: only plain iterables are converted,
    # tuples and QuerySet subqueries are used as-is.
    needs_tuple = is_iterable(b) and not isinstance(b, (tuple, QuerySet))
    operand = tuple(b) if needs_tuple else b
    return F("NOT GLOBAL IN", a, operand)
# Functions for working with dates and times
@staticmethod

View File

@ -253,6 +253,12 @@ class ModelBase(type):
else:
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
return orm_fields.TupleField(name_fields=name_fields)
# Map
if db_type.startswith("Map"):
types = [s.strip() for s in db_type[4:-1].split(",")]
key_filed = cls.create_ad_hoc_field(types[0])
value_filed = cls.create_ad_hoc_field(types[1])
return orm_fields.MapField(key_filed, value_filed)
# FixedString
if db_type.startswith("FixedString"):
length = int(db_type[12:-1])

View File

@ -1,4 +1,5 @@
import codecs
import json
import re
from datetime import date, datetime, tzinfo, timedelta
@ -16,6 +17,7 @@ SPECIAL_CHARS = {
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
MAP_REGEX = re.compile(r"(.+?)=(.+?),?")
def escape(value, quote=True):
@ -120,6 +122,22 @@ def parse_array(array_string):
array_string = array_string[match.end() - 1 :]
def parse_map(map_string: str) -> dict:
    """
    Parse a map string as returned by ClickHouse into a dict.

    Quoted maps such as ``"{'key1': 1}"`` are decoded as JSON (after swapping
    single quotes for double quotes), so value types are preserved:
    ``"{'key1': 1}"`` ==> ``{"key1": 1}``.
    Unquoted maps fall back to a regex scan, in which case keys AND values
    are plain strings: ``"{key1=1, key2=2}"`` ==> ``{"key1": "1", "key2": "2"}``.

    Raises ValueError if the input is empty or not wrapped in curly braces.
    """
    # Guard the empty string too, so the indexing below cannot raise IndexError.
    if not map_string or map_string[0] != "{" or map_string[-1] != "}":
        raise ValueError('Invalid map string: "%s"' % map_string)
    try:
        # Fast path: single-quoted ClickHouse output becomes valid JSON.
        return json.loads(map_string.replace("'", "\""))
    except json.JSONDecodeError:
        # Fallback: scan "key=value" pairs inside the braces.
        return {key: value for key, value in MAP_REGEX.findall(map_string[1:-1])}
def import_submodules(package_name):
"""
Import all submodules of a module.

59
tests/test_map_fields.py Normal file
View File

@ -0,0 +1,59 @@
import unittest
from datetime import date
from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import MapField, DateField, StringField, Int32Field, Float64Field
from clickhouse_orm.engines import MergeTree
class TupleFieldsTest(unittest.TestCase):
    """
    Integration tests for MapField: round-trip insert/select, string-to-dict
    conversion, and rejection of non-dict assignments.

    Requires a reachable ClickHouse server; each test runs against a throwaway
    database created in setUp and dropped in tearDown.
    """
    # NOTE(review): the class and model names say "Tuple" but everything here
    # exercises MapField — presumably copied from the tuple tests; consider renaming.

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(ModelWithTuple)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        # Insert one row covering all four map column types, then read it back
        # both as model instances and as ad-hoc (model_cls=None) rows.
        instance = ModelWithTuple(
            date_field='2022-06-26',
            map1={"k1": "v1", "k2": "v2"},
            map2={"v1": 1, "v2": 2},
            map3={"f1": 1.1, "f2": 2.0},
            map4={"2022-06-25": "ok", "2022-06-26": "today"}
        )
        self.database.insert([instance])
        query = 'SELECT * from $db.modelwithtuple ORDER BY date_field'
        for model_cls in (ModelWithTuple, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertIn("k1", results[0].map1)
            self.assertEqual(results[0].map1["k2"], "v2")
            self.assertEqual(results[0].map2["v1"], 1)
            self.assertEqual(results[0].map3["f2"], 2.0)
            # DateField keys come back as datetime.date objects, not strings.
            self.assertEqual(results[0].map4[date(2022, 6, 26)], "today")

    def test_conversion(self):
        # Assigning a map string triggers MapField.to_python: the string value
        # '2' is converted to int by the Int32Field value sub-field.
        instance = ModelWithTuple(
            map2="{'1': '2'}"
        )
        self.assertEqual(instance.map2['1'], 2)

    def test_assignment_error(self):
        # Anything that is not (convertible to) a dict must raise ValueError.
        instance = ModelWithTuple()
        for value in (7, 'x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.map1 = value
class ModelWithTuple(Model):
    """
    Fixture model exercising MapField with several key/value type combinations.
    """
    # NOTE(review): named "Tuple" but contains only Map columns — presumably
    # copied from the tuple-field tests. Field declaration order defines the
    # table schema, so do not reorder.

    date_field = DateField()
    map1 = MapField(StringField(), StringField())   # Map(String, String)
    map2 = MapField(StringField(), Int32Field())    # Map(String, Int32)
    map3 = MapField(StringField(), Float64Field())  # Map(String, Float64)
    map4 = MapField(DateField(), StringField())     # Map(Date, String)

    engine = MergeTree('date_field', ('date_field',))