feat: Support MapField

This commit is contained in:
sw 2022-06-26 13:58:29 +08:00
parent 2a08fdcf94
commit 0ea078dd56
9 changed files with 224 additions and 38 deletions

View File

@ -24,7 +24,7 @@ dependencies = [
"iso8601 >= 0.1.12",
"setuptools"
]
version = "0.2.0"
version = "0.2.1"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,12 +1,11 @@
import inspect
from collections import namedtuple
DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
def _get_default_arg(args, defaults, arg_index):
""" Method that determines if an argument has default value or not,
"""Method that determines if an argument has default value or not,
and if yes what is the default value for the argument
:param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
@ -26,13 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
return DefaultArgSpec(False, None)
else:
value = defaults[arg_index - args_with_no_defaults]
if (type(value) is str):
if type(value) is str:
value = '"%s"' % value
return DefaultArgSpec(True, value)
def get_method_sig(method):
""" Given a function, it returns a string that pretty much looks how the
"""Given a function, it returns a string that pretty much looks how the
function signature would be written in python.
:param method: a python method
@ -59,16 +58,16 @@ def get_method_sig(method):
args.append(arg)
arg_index += 1
if argspec.varargs:
args.append('*' + argspec.varargs)
args.append("*" + argspec.varargs)
if argspec.varkw:
args.append('**' + argspec.varkw)
args.append("**" + argspec.varkw)
return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
def docstring(obj):
doc = (obj.__doc__ or '').rstrip()
doc = (obj.__doc__ or "").rstrip()
if doc:
lines = doc.split('\n')
lines = doc.split("\n")
# Find the length of the whitespace prefix common to all non-empty lines
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
# Output the lines without the indentation
@ -78,30 +77,32 @@ def docstring(obj):
def class_doc(cls, list_methods=True):
bases = ', '.join([b.__name__ for b in cls.__bases__])
print('###', cls.__name__)
bases = ", ".join([b.__name__ for b in cls.__bases__])
print("###", cls.__name__)
print()
if bases != 'object':
print('Extends', bases)
if bases != "object":
print("Extends", bases)
print()
docstring(cls)
for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
if name == '__init__':
for name, method in inspect.getmembers(
cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)
):
if name == "__init__":
# Initializer
print('####', get_method_sig(method).replace(name, cls.__name__))
elif name[0] == '_':
print("####", get_method_sig(method).replace(name, cls.__name__))
elif name[0] == "_":
# Private method
continue
elif hasattr(method, '__self__') and method.__self__ == cls:
elif hasattr(method, "__self__") and method.__self__ == cls:
# Class method
if not list_methods:
continue
print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
else:
# Regular method
if not list_methods:
continue
print('####', get_method_sig(method))
print("####", get_method_sig(method))
print()
docstring(method)
print()
@ -110,7 +111,7 @@ def class_doc(cls, list_methods=True):
def module_doc(classes, list_methods=True):
mdl = classes[0].__module__
print(mdl)
print('-' * len(mdl))
print("-" * len(mdl))
print()
for cls in classes:
class_doc(cls, list_methods)
@ -120,7 +121,7 @@ def all_subclasses(cls):
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
if __name__ == '__main__':
if __name__ == "__main__":
from clickhouse_orm import database
from clickhouse_orm import fields
@ -131,13 +132,24 @@ if __name__ == '__main__':
from clickhouse_orm import system_models
from clickhouse_orm.aio import database as aio_database
print('Class Reference')
print('===============')
print("Class Reference")
print("===============")
print()
module_doc([database.Database, database.DatabaseException])
module_doc([aio_database.AioDatabase])
module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
module_doc(
[
models.Model,
models.BufferModel,
models.MergeModel,
models.DistributedModel,
models.Constraint,
models.Index,
]
)
module_doc(
sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False
)
module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])

View File

@ -2,13 +2,13 @@ from html.parser import HTMLParser
import sys
HEADER_TAGS = ('h1', 'h2', 'h3')
HEADER_TAGS = ("h1", "h2", "h3")
class HeadersToMarkdownParser(HTMLParser):
inside = None
text = ''
text = ""
def handle_starttag(self, tag, attrs):
if tag.lower() in HEADER_TAGS:
@ -16,11 +16,11 @@ class HeadersToMarkdownParser(HTMLParser):
def handle_endtag(self, tag):
if tag.lower() in HEADER_TAGS:
indent = ' ' * int(self.inside[1])
fragment = self.text.lower().replace(' ', '-').replace('.', '')
print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
indent = " " * int(self.inside[1])
fragment = self.text.lower().replace(" ", "-").replace(".", "")
print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
self.inside = None
self.text = ''
self.text = ""
def handle_data(self, data):
if self.inside:
@ -28,4 +28,4 @@ class HeadersToMarkdownParser(HTMLParser):
HeadersToMarkdownParser().feed(sys.stdin.read())
print('')
print("")

View File

@ -43,7 +43,7 @@ class Lazy(DatabaseEngine):
for which there is a long time interval between accesses.
"""
def __int__(self, expiration_time_in_seconds: int):
def __init__(self, expiration_time_in_seconds: int):
self.expiration_time_in_seconds = expiration_time_in_seconds
def create_database_sql(self) -> str:
@ -59,7 +59,7 @@ class MySQL(DatabaseEngine):
such as SHOW TABLES or SHOW CREATE TABLE.
"""
def __int__(self, host: str, port: int, database: str, user: str, password: str):
def __init__(self, host: str, port: int, database: str, user: str, password: str):
"""
- `host`: MySQL server address.
- `port`: MySQL server port.
@ -90,7 +90,7 @@ class PostgreSQL(DatabaseEngine):
but can be updated with DETACH and ATTACH queries.
"""
def __int__(
def __init__(
self,
host: str,
port: int,

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals, annotations
import json
import re
from enum import Enum
from uuid import UUID
@ -14,7 +15,7 @@ import iso8601
import pytz
from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names, parse_map
from .funcs import F, FunctionOperatorsMixin
if TYPE_CHECKING:
@ -870,5 +871,75 @@ class LowCardinalityField(Field):
return sql
# Field types accepted as the `key` argument of MapField (see MapField.__init__).
MapKey = Union[
    StringField,
    BaseIntField,
    LowCardinalityField,
    UUIDField,
    DateField,
    DateTimeField,
    BaseEnumField,
]
class MapField(Field):
    """
    Field for the ClickHouse Map(key, value) data type.

    Experimental: expected to stabilize over the next few releases.
    """

    # NOTE(review): flagged as "incorrect value" by the author. __init__ always
    # substitutes an empty dict for a falsy default, so this attribute may never
    # be consulted — confirm against the Field base class before relying on it.
    class_default = None  # incorrect value

    def __init__(
        self,
        key: MapKey,
        value: Field,
        default: dict = None,
        alias: Optional[Union[F, str]] = None,
        materialized: Optional[Union[F, str]] = None,
        readonly: bool = None,
        codec: Optional[str] = None,
        db_column: Optional[str] = None,
    ):
        """
        - `key`: field instance describing the type of the map's keys.
        - `value`: field instance describing the type of the map's values.
        - `default`: default dict value; an empty dict when omitted or falsy.
        Remaining parameters are forwarded to Field unchanged.
        """
        assert isinstance(key, Field)
        assert isinstance(value, Field)
        self.key = key
        self.value = value
        if not default:
            default = {}
        super().__init__(default, alias, materialized, readonly, codec, db_column)

    def to_python(self, value, timezone_in_use) -> dict:
        """
        Normalize a database value (bytes, str or dict) to a dict whose keys and
        values have been converted by the key/value sub-fields.
        Raises ValueError when the input cannot be interpreted as a dict.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        if isinstance(value, str):
            # String form as returned by ClickHouse, e.g. "{'k1': 'v1'}"
            value = parse_map(value)
        if not isinstance(value, dict):
            raise ValueError("MapField expects dict, not %s" % type(value))
        # Convert each entry with the corresponding sub-field.
        value = {
            self.key.to_python(k, timezone_in_use): self.value.to_python(v, timezone_in_use)
            for k, v in value.items()
        }
        return value

    def validate(self, value):
        # Delegate validation of every entry to the key/value sub-fields.
        for k, v in value.items():
            self.key.validate(k)
            self.value.validate(v)

    def to_db_string(self, value, quote=True) -> str:
        """
        Render the dict as a ClickHouse map literal such as {'k': 'v'}.
        Int/float keys and values are passed through unconverted so they stay
        unquoted in the JSON output.
        """
        value2 = {}
        int_key = isinstance(self.key, (BaseIntField, BaseFloatField))
        int_value = isinstance(self.value, (BaseIntField, BaseFloatField))
        for k, v in value.items():
            if not int_key:
                k = self.key.to_db_string(k, quote=False)
            if not int_value:
                v = self.value.to_db_string(v, quote=False)
            value2[k] = v
        # NOTE(review): replacing every double quote with a single quote will
        # corrupt keys/values that themselves contain a '"' character — TODO
        # consider proper escaping instead of a blanket replace.
        return json.dumps(value2).replace("\"", "'")

    def get_sql(self, with_default_expression=True, db=None) -> str:
        # e.g. Map(String, Int32), optionally followed by a CODEC clause when
        # the connected database supports codecs.
        sql = "Map(%s, %s)" % (self.key.get_sql(False), self.value.get_sql(False))
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += " CODEC(%s)" % self.codec
        return sql
# Expose only relevant classes in import *
__all__ = get_subclass_names(locals(), Field)

View File

@ -186,6 +186,12 @@ class FunctionOperatorsMixin:
def isNotIn(self, others):
return F._notIn(self, others) # pylint: disable=W0212
def isGlobalIn(self, others):
return F._gin(self, others) # pylint: disable=W0212
def isNotGlobalIn(self, others):
return F._notGIn(self, others) # pylint: disable=W0212
class FMeta(type):
@ -404,6 +410,20 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): # pylint: disable=R0904
b = tuple(b)
return F("NOT IN", a, b)
@staticmethod
@binary_operator
def _gin(a, b):
if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
b = tuple(b)
return F("GLOBAL IN", a, b)
@staticmethod
@binary_operator
def _notGIn(a, b):
if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
b = tuple(b)
return F("NOT GLOBAL IN", a, b)
# Functions for working with dates and times
@staticmethod

View File

@ -253,6 +253,12 @@ class ModelBase(type):
else:
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
return orm_fields.TupleField(name_fields=name_fields)
# Map
if db_type.startswith("Map"):
types = [s.strip() for s in db_type[4:-1].split(",")]
key_filed = cls.create_ad_hoc_field(types[0])
value_filed = cls.create_ad_hoc_field(types[1])
return orm_fields.MapField(key_filed, value_filed)
# FixedString
if db_type.startswith("FixedString"):
length = int(db_type[12:-1])

View File

@ -1,4 +1,5 @@
import codecs
import json
import re
from datetime import date, datetime, tzinfo, timedelta
@ -16,6 +17,7 @@ SPECIAL_CHARS = {
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
MAP_REGEX = re.compile(r"(.+?)=(.+?),?")
def escape(value, quote=True):
@ -120,6 +122,22 @@ def parse_array(array_string):
array_string = array_string[match.end() - 1 :]
def parse_map(map_string: str) -> dict:
    """
    Parse a map string as returned by clickhouse. For example:
    "{'key1': 1, 'key2': 2}" ==> {"key1": 1, "key2": 2}
    "{key1=1, key2=2}"       ==> {"key1": "1", "key2": "2"}

    Raises ValueError if the string is empty or not wrapped in braces.
    """
    # Guard against empty input before indexing (would otherwise IndexError).
    if not map_string or map_string[0] != "{" or map_string[-1] != "}":
        raise ValueError('Invalid map string: "%s"' % map_string)
    try:
        # Fast path: ClickHouse usually serializes maps as near-JSON with single
        # quotes. NOTE: the blanket quote replacement corrupts keys/values that
        # themselves contain quote characters.
        return json.loads(map_string.replace("'", '"'))
    except json.JSONDecodeError:
        # Fallback for "key=value" pairs. Anchor each value at a comma or the
        # end of the string so multi-character values are not truncated, and
        # strip surrounding whitespace from keys and values.
        ret = {}
        for key, value in re.findall(r"\s*(.+?)\s*=\s*(.+?)\s*(?:,|$)", map_string[1:-1]):
            ret[key] = value
        return ret
def import_submodules(package_name):
"""
Import all submodules of a module.

59
tests/test_map_fields.py Normal file
View File

@ -0,0 +1,59 @@
import unittest
from datetime import date
from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import MapField, DateField, StringField, Int32Field, Float64Field
from clickhouse_orm.engines import MergeTree
class TupleFieldsTest(unittest.TestCase):
    # Integration tests for MapField: round-trip through a live ClickHouse
    # database, string-to-dict conversion, and rejection of non-dict values.
    # NOTE(review): despite the name, these tests exercise MapField, not tuples —
    # consider renaming to MapFieldsTest (likely copied from the tuple tests).

    def setUp(self):
        # Requires a reachable ClickHouse server with default credentials.
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(ModelWithTuple)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        # Insert one row and read it back both via the model class and via a
        # model-less (ad-hoc) SELECT; every map field must survive the trip.
        instance = ModelWithTuple(
            date_field='2022-06-26',
            map1={"k1": "v1", "k2": "v2"},
            map2={"v1": 1, "v2": 2},
            map3={"f1": 1.1, "f2": 2.0},
            map4={"2022-06-25": "ok", "2022-06-26": "today"}
        )
        self.database.insert([instance])
        query = 'SELECT * from $db.modelwithtuple ORDER BY date_field'
        for model_cls in (ModelWithTuple, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertIn("k1", results[0].map1)
            self.assertEqual(results[0].map1["k2"], "v2")
            self.assertEqual(results[0].map2["v1"], 1)
            self.assertEqual(results[0].map3["f2"], 2.0)
            # Keys are converted by the key sub-field: DateField yields date objects.
            self.assertEqual(results[0].map4[date(2022, 6, 26)], "today")

    def test_conversion(self):
        # A string assigned to a MapField is parsed into a dict, and values are
        # coerced by the value sub-field (Int32Field turns '2' into 2).
        instance = ModelWithTuple(
            map2="{'1': '2'}"
        )
        self.assertEqual(instance.map2['1'], 2)

    def test_assignment_error(self):
        # Anything that cannot be interpreted as a dict must raise ValueError.
        instance = ModelWithTuple()
        for value in (7, 'x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.map1 = value
class ModelWithTuple(Model):
    # Test model covering MapField with several key/value type combinations.
    # NOTE(review): name says Tuple but the model exists to test MapField —
    # consider renaming to ModelWithMap (and updating the table name in queries).
    date_field = DateField()
    map1 = MapField(StringField(), StringField())   # str -> str
    map2 = MapField(StringField(), Int32Field())    # str -> int
    map3 = MapField(StringField(), Float64Field())  # str -> float
    map4 = MapField(DateField(), StringField())     # date -> str
    engine = MergeTree('date_field', ('date_field',))