Mirror of https://github.com/Infinidat/infi.clickhouse_orm.git (synced 2025-08-02 03:00:09 +03:00)

feat: Support MapField

This commit is contained in:
parent 2a08fdcf94
commit 0ea078dd56
@@ -24,7 +24,7 @@ dependencies = [
     "iso8601 >= 0.1.12",
     "setuptools"
 ]
-version = "0.2.0"
+version = "0.2.1"

 [tool.setuptools.packages.find]
 where = ["src"]
@@ -1,12 +1,11 @@
-
 import inspect
 from collections import namedtuple

-DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
+DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")


 def _get_default_arg(args, defaults, arg_index):
-    """ Method that determines if an argument has default value or not,
+    """Method that determines if an argument has default value or not,
     and if yes what is the default value for the argument

     :param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
@@ -26,13 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
         return DefaultArgSpec(False, None)
     else:
         value = defaults[arg_index - args_with_no_defaults]
-        if (type(value) is str):
+        if type(value) is str:
             value = '"%s"' % value
         return DefaultArgSpec(True, value)


 def get_method_sig(method):
-    """ Given a function, it returns a string that pretty much looks how the
+    """Given a function, it returns a string that pretty much looks how the
     function signature would be written in python.

     :param method: a python method
@@ -59,16 +58,16 @@ def get_method_sig(method):
         args.append(arg)
         arg_index += 1
     if argspec.varargs:
-        args.append('*' + argspec.varargs)
+        args.append("*" + argspec.varargs)
     if argspec.varkw:
-        args.append('**' + argspec.varkw)
+        args.append("**" + argspec.varkw)
     return "%s(%s)" % (method.__name__, ", ".join(args[1:]))


 def docstring(obj):
-    doc = (obj.__doc__ or '').rstrip()
+    doc = (obj.__doc__ or "").rstrip()
     if doc:
-        lines = doc.split('\n')
+        lines = doc.split("\n")
         # Find the length of the whitespace prefix common to all non-empty lines
         indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
         # Output the lines without the indentation
@@ -78,30 +77,32 @@ def docstring(obj):


 def class_doc(cls, list_methods=True):
-    bases = ', '.join([b.__name__ for b in cls.__bases__])
-    print('###', cls.__name__)
+    bases = ", ".join([b.__name__ for b in cls.__bases__])
+    print("###", cls.__name__)
     print()
-    if bases != 'object':
-        print('Extends', bases)
+    if bases != "object":
+        print("Extends", bases)
     print()
     docstring(cls)
-    for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
-        if name == '__init__':
+    for name, method in inspect.getmembers(
+        cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)
+    ):
+        if name == "__init__":
             # Initializer
-            print('####', get_method_sig(method).replace(name, cls.__name__))
-        elif name[0] == '_':
+            print("####", get_method_sig(method).replace(name, cls.__name__))
+        elif name[0] == "_":
             # Private method
             continue
-        elif hasattr(method, '__self__') and method.__self__ == cls:
+        elif hasattr(method, "__self__") and method.__self__ == cls:
             # Class method
             if not list_methods:
                 continue
-            print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
+            print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
         else:
             # Regular method
             if not list_methods:
                 continue
-            print('####', get_method_sig(method))
+            print("####", get_method_sig(method))
         print()
         docstring(method)
         print()
@@ -110,7 +111,7 @@ def class_doc(cls, list_methods=True):
 def module_doc(classes, list_methods=True):
     mdl = classes[0].__module__
     print(mdl)
-    print('-' * len(mdl))
+    print("-" * len(mdl))
     print()
     for cls in classes:
         class_doc(cls, list_methods)
@@ -120,7 +121,7 @@ def all_subclasses(cls):
     return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]


-if __name__ == '__main__':
+if __name__ == "__main__":

     from clickhouse_orm import database
     from clickhouse_orm import fields
@@ -131,13 +132,24 @@ if __name__ == "__main__":
     from clickhouse_orm import system_models
     from clickhouse_orm.aio import database as aio_database

-    print('Class Reference')
-    print('===============')
+    print("Class Reference")
+    print("===============")
     print()
     module_doc([database.Database, database.DatabaseException])
     module_doc([aio_database.AioDatabase])
-    module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
-    module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
+    module_doc(
+        [
+            models.Model,
+            models.BufferModel,
+            models.MergeModel,
+            models.DistributedModel,
+            models.Constraint,
+            models.Index,
+        ]
+    )
+    module_doc(
+        sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False
+    )
     module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
     module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
     module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])
@@ -2,13 +2,13 @@ from html.parser import HTMLParser
 import sys


-HEADER_TAGS = ('h1', 'h2', 'h3')
+HEADER_TAGS = ("h1", "h2", "h3")


 class HeadersToMarkdownParser(HTMLParser):

     inside = None
-    text = ''
+    text = ""

     def handle_starttag(self, tag, attrs):
         if tag.lower() in HEADER_TAGS:
@@ -16,11 +16,11 @@ class HeadersToMarkdownParser(HTMLParser):

     def handle_endtag(self, tag):
         if tag.lower() in HEADER_TAGS:
-            indent = ' ' * int(self.inside[1])
-            fragment = self.text.lower().replace(' ', '-').replace('.', '')
-            print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
+            indent = " " * int(self.inside[1])
+            fragment = self.text.lower().replace(" ", "-").replace(".", "")
+            print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
             self.inside = None
-            self.text = ''
+            self.text = ""

     def handle_data(self, data):
         if self.inside:
@@ -28,4 +28,4 @@ class HeadersToMarkdownParser(HTMLParser):


 HeadersToMarkdownParser().feed(sys.stdin.read())
-print('')
+print("")
@@ -43,7 +43,7 @@ class Lazy(DatabaseEngine):
     for which there is a long time interval between accesses.
     """

-    def __int__(self, expiration_time_in_seconds: int):
+    def __init__(self, expiration_time_in_seconds: int):
         self.expiration_time_in_seconds = expiration_time_in_seconds

     def create_database_sql(self) -> str:
@@ -59,7 +59,7 @@ class MySQL(DatabaseEngine):
     such as SHOW TABLES or SHOW CREATE TABLE.
     """

-    def __int__(self, host: str, port: int, database: str, user: str, password: str):
+    def __init__(self, host: str, port: int, database: str, user: str, password: str):
         """
         - `host`: MySQL server address.
         - `port`: MySQL server port.
@@ -90,7 +90,7 @@ class PostgreSQL(DatabaseEngine):
     but can be updated with DETACH and ATTACH queries.
    """

-    def __int__(
+    def __init__(
         self,
         host: str,
         port: int,
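The three `__int__` to `__init__` fixes above matter because Python never calls a method named `__int__` during construction, so these engine constructors previously never ran. A minimal sketch of the intended use (illustrative, not part of the commit; the import path follows the `clickhouse_orm.engines` module referenced elsewhere in this commit, and the exact DDL emitted by `create_database_sql` lives outside these hunks):

    # Illustrative sketch: with __init__ fixed, the constructor argument is stored.
    from clickhouse_orm.engines import Lazy  # assumed import path for the engine classes

    engine = Lazy(expiration_time_in_seconds=600)
    assert engine.expiration_time_in_seconds == 600  # __int__ was a typo Python never invoked
    print(engine.create_database_sql())  # renders the ENGINE = Lazy(...) clause defined above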
@@ -1,5 +1,6 @@
 from __future__ import unicode_literals, annotations

+import json
 import re
 from enum import Enum
 from uuid import UUID
@@ -14,7 +15,7 @@ import iso8601
 import pytz
 from pytz import BaseTzInfo

-from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
+from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names, parse_map
 from .funcs import F, FunctionOperatorsMixin

 if TYPE_CHECKING:
@@ -870,5 +871,75 @@ class LowCardinalityField(Field):
         return sql


+MapKey = Union[
+    StringField,
+    BaseIntField,
+    LowCardinalityField,
+    UUIDField,
+    DateField,
+    DateTimeField,
+    BaseEnumField,
+]
+
+
+class MapField(Field):
+    """MapField is experimental and may change in the next few releases"""
+    class_default = None  # placeholder; a usable default dict is set per instance in __init__
+
+    def __init__(
+        self,
+        key: MapKey,
+        value: Field,
+        default: Optional[dict] = None,
+        alias: Optional[Union[F, str]] = None,
+        materialized: Optional[Union[F, str]] = None,
+        readonly: Optional[bool] = None,
+        codec: Optional[str] = None,
+        db_column: Optional[str] = None,
+    ):
+        assert isinstance(key, Field)
+        assert isinstance(value, Field)
+        self.key = key
+        self.value = value
+        if not default:
+            default = {}
+        super().__init__(default, alias, materialized, readonly, codec, db_column)
+
+    def to_python(self, value, timezone_in_use) -> dict:
+        if isinstance(value, bytes):
+            value = value.decode("utf-8")
+        if isinstance(value, str):
+            value = parse_map(value)
+        if not isinstance(value, dict):
+            raise ValueError("MapField expects dict, not %s" % type(value))
+        value = {
+            self.key.to_python(k, timezone_in_use): self.value.to_python(v, timezone_in_use)
+            for k, v in value.items()
+        }
+        return value
+
+    def validate(self, value):
+        for k, v in value.items():
+            self.key.validate(k)
+            self.value.validate(v)
+
+    def to_db_string(self, value, quote=True) -> str:
+        value2 = {}
+        int_key = isinstance(self.key, (BaseIntField, BaseFloatField))
+        int_value = isinstance(self.value, (BaseIntField, BaseFloatField))
+        for k, v in value.items():
+            if not int_key:
+                k = self.key.to_db_string(k, quote=False)
+            if not int_value:
+                v = self.value.to_db_string(v, quote=False)
+            value2[k] = v
+        return json.dumps(value2).replace("\"", "'")  # ClickHouse map literals use single quotes
+
+    def get_sql(self, with_default_expression=True, db=None) -> str:
+        sql = "Map(%s, %s)" % (self.key.get_sql(False), self.value.get_sql(False))
+        if with_default_expression and self.codec and db and db.has_codec_support:
+            sql += " CODEC(%s)" % self.codec
+        return sql
+
 # Expose only relevant classes in import *
 __all__ = get_subclass_names(locals(), Field)
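For orientation, a minimal usage sketch of the new field (illustrative, not part of the commit; the model shape mirrors the test file added below, and "demo-db" assumes a reachable ClickHouse server):

    from clickhouse_orm.database import Database
    from clickhouse_orm.engines import MergeTree
    from clickhouse_orm.fields import DateField, Int32Field, MapField, StringField
    from clickhouse_orm.models import Model

    class PageHits(Model):  # hypothetical model for illustration
        date_field = DateField()
        counters = MapField(StringField(), Int32Field())  # get_sql() renders Map(String, Int32)

        engine = MergeTree("date_field", ("date_field",))

    db = Database("demo-db")
    db.create_table(PageHits)
    db.insert([PageHits(date_field="2022-06-26", counters={"home": 3, "about": 1})])
    for row in db.select("SELECT * FROM $db.pagehits", PageHits):
        print(row.counters["home"])  # 3, deserialized via MapField.to_python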
@@ -186,6 +186,12 @@ class FunctionOperatorsMixin:
     def isNotIn(self, others):
         return F._notIn(self, others)  # pylint: disable=W0212

+    def isGlobalIn(self, others):
+        return F._gin(self, others)  # pylint: disable=W0212
+
+    def isNotGlobalIn(self, others):
+        return F._notGIn(self, others)  # pylint: disable=W0212
+

 class FMeta(type):

@@ -404,6 +410,20 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta):  # pylint: disable=R0904
             b = tuple(b)
         return F("NOT IN", a, b)

+    @staticmethod
+    @binary_operator
+    def _gin(a, b):
+        if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
+            b = tuple(b)
+        return F("GLOBAL IN", a, b)
+
+    @staticmethod
+    @binary_operator
+    def _notGIn(a, b):
+        if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
+            b = tuple(b)
+        return F("NOT GLOBAL IN", a, b)
+
     # Functions for working with dates and times

     @staticmethod
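`_gin` and `_notGIn` mirror the existing `_in`/`_notIn` operators and wrap ClickHouse's GLOBAL IN / NOT GLOBAL IN, which evaluate the right-hand set once on the initiating server when querying distributed tables. A hedged sketch of the new mixin methods on a column (the model is hypothetical; fields inherit FunctionOperatorsMixin, so the methods are available on them):

    from clickhouse_orm.engines import MergeTree
    from clickhouse_orm.fields import DateField, StringField
    from clickhouse_orm.models import Model

    class Event(Model):  # hypothetical model for illustration
        date_field = DateField()
        name = StringField()
        engine = MergeTree("date_field", ("date_field",))

    cond = Event.name.isGlobalIn(["a", "b"])         # builds F("GLOBAL IN", <name>, ("a", "b"))
    not_cond = Event.name.isNotGlobalIn(["a", "b"])  # builds F("NOT GLOBAL IN", ...)
    # Rendered into SQL, the first condition reads roughly: name GLOBAL IN ('a', 'b')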
@@ -253,6 +253,12 @@ class ModelBase(type):
             else:
                 name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
             return orm_fields.TupleField(name_fields=name_fields)
+        # Map
+        if db_type.startswith("Map"):
+            types = [s.strip() for s in db_type[4:-1].split(",")]  # naive split; simple types only
+            key_field = cls.create_ad_hoc_field(types[0])
+            value_field = cls.create_ad_hoc_field(types[1])
+            return orm_fields.MapField(key_field, value_field)
         # FixedString
         if db_type.startswith("FixedString"):
             length = int(db_type[12:-1])
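This branch lets ad-hoc models (a `SELECT` with no model class, as exercised by the `model_cls in (ModelWithMap, None)` loop in the test file below) turn a server-reported `Map(...)` column type back into a `MapField`. An illustrative trace, assuming simple key and value types, which is all the comma split above supports (nested types such as `Map(String, Array(Int32))` would be split incorrectly):

    db_type = "Map(String, Int32)"
    types = [s.strip() for s in db_type[4:-1].split(",")]  # -> ["String", "Int32"]
    # create_ad_hoc_field("String") -> StringField(); create_ad_hoc_field("Int32") -> Int32Field()
    # so the branch returns MapField(StringField(), Int32Field())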
@@ -1,4 +1,5 @@
 import codecs
+import json
 import re
 from datetime import date, datetime, tzinfo, timedelta

@@ -16,6 +17,7 @@ SPECIAL_CHARS = {
 SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
 POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
 RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
+MAP_REGEX = re.compile(r"(.+?)=(.+?),?")


 def escape(value, quote=True):
@@ -120,6 +122,22 @@ def parse_array(array_string):
         array_string = array_string[match.end() - 1 :]


+def parse_map(map_string: str) -> dict:
+    """
+    Parse a map string as returned by ClickHouse. For example:
+    "{key1=1, key2=2, key3=3}" ==> {"key1": "1", "key2": "2", "key3": "3"}
+    """
+    if not (map_string.startswith("{") and map_string.endswith("}")):
+        raise ValueError('Invalid map string: "%s"' % map_string)
+    ret = {}
+    try:
+        ret = json.loads(map_string.replace("'", "\""))  # fast path: swap quotes, parse as JSON
+    except json.JSONDecodeError:
+        for (key, value) in MAP_REGEX.findall(map_string[1:-1]):
+            ret[key.strip()] = value.strip()  # fallback yields strings; strip stray whitespace
+    return ret
+
+
 def import_submodules(package_name):
     """
     Import all submodules of a module.
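An illustrative sketch of the two code paths (not part of the commit): the JSON path handles output whose keys and values are quoted, while the regex fallback handles the bare key=value form and yields string values, which `MapField.to_python` later coerces to the field's types:

    from clickhouse_orm.utils import parse_map

    print(parse_map("{'v1': 1, 'v2': 2}"))  # JSON path  -> {'v1': 1, 'v2': 2}
    print(parse_map("{key1=1, key2=2}"))    # regex path -> {'key1': '1', 'key2': '2'}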
tests/test_map_fields.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import unittest
+from datetime import date
+
+from clickhouse_orm.database import Database
+from clickhouse_orm.models import Model
+from clickhouse_orm.fields import MapField, DateField, StringField, Int32Field, Float64Field
+from clickhouse_orm.engines import MergeTree
+
+
+class MapFieldsTest(unittest.TestCase):
+
+    def setUp(self):
+        self.database = Database("test-db", log_statements=True)
+        self.database.create_table(ModelWithMap)
+
+    def tearDown(self):
+        self.database.drop_database()
+
+    def test_insert_and_select(self):
+        instance = ModelWithMap(
+            date_field="2022-06-26",
+            map1={"k1": "v1", "k2": "v2"},
+            map2={"v1": 1, "v2": 2},
+            map3={"f1": 1.1, "f2": 2.0},
+            map4={"2022-06-25": "ok", "2022-06-26": "today"}
+        )
+        self.database.insert([instance])
+        query = "SELECT * from $db.modelwithmap ORDER BY date_field"
+        for model_cls in (ModelWithMap, None):
+            results = list(self.database.select(query, model_cls))
+            self.assertEqual(len(results), 1)
+            self.assertIn("k1", results[0].map1)
+            self.assertEqual(results[0].map1["k2"], "v2")
+            self.assertEqual(results[0].map2["v1"], 1)
+            self.assertEqual(results[0].map3["f2"], 2.0)
+            self.assertEqual(results[0].map4[date(2022, 6, 26)], "today")
+
+    def test_conversion(self):
+        instance = ModelWithMap(
+            map2="{'1': '2'}"
+        )
+        self.assertEqual(instance.map2["1"], 2)
+
+    def test_assignment_error(self):
+        instance = ModelWithMap()
+        for value in (7, "x", [date.today()], ["aaa"], [None]):
+            with self.assertRaises(ValueError):
+                instance.map1 = value
+
+
+class ModelWithMap(Model):
+
+    date_field = DateField()
+    map1 = MapField(StringField(), StringField())
+    map2 = MapField(StringField(), Int32Field())
+    map3 = MapField(StringField(), Float64Field())
+    map4 = MapField(DateField(), StringField())
+
+    engine = MergeTree("date_field", ("date_field",))
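For reference, a hedged round-trip through the new field shows the wire format these tests depend on: `to_db_string` emits a single-quoted JSON-style literal, and `to_python` restores a typed dict via `parse_map`:

    from clickhouse_orm.fields import Int32Field, MapField, StringField

    f = MapField(StringField(), Int32Field())
    s = f.to_db_string({"v1": 1, "v2": 2})
    print(s)                     # {'v1': 1, 'v2': 2} (json.dumps with quotes swapped)
    print(f.to_python(s, None))  # {'v1': 1, 'v2': 2} with the Int32 values re-coerced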