mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2025-08-02 11:10:11 +03:00
feat: Support MapField
This commit is contained in:
parent
2a08fdcf94
commit
0ea078dd56
|
@ -24,7 +24,7 @@ dependencies = [
|
||||||
"iso8601 >= 0.1.12",
|
"iso8601 >= 0.1.12",
|
||||||
"setuptools"
|
"setuptools"
|
||||||
]
|
]
|
||||||
version = "0.2.0"
|
version = "0.2.1"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
|
DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
|
||||||
|
|
||||||
|
|
||||||
def _get_default_arg(args, defaults, arg_index):
|
def _get_default_arg(args, defaults, arg_index):
|
||||||
""" Method that determines if an argument has default value or not,
|
"""Method that determines if an argument has default value or not,
|
||||||
and if yes what is the default value for the argument
|
and if yes what is the default value for the argument
|
||||||
|
|
||||||
:param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
|
:param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
|
||||||
|
@ -26,13 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
|
||||||
return DefaultArgSpec(False, None)
|
return DefaultArgSpec(False, None)
|
||||||
else:
|
else:
|
||||||
value = defaults[arg_index - args_with_no_defaults]
|
value = defaults[arg_index - args_with_no_defaults]
|
||||||
if (type(value) is str):
|
if type(value) is str:
|
||||||
value = '"%s"' % value
|
value = '"%s"' % value
|
||||||
return DefaultArgSpec(True, value)
|
return DefaultArgSpec(True, value)
|
||||||
|
|
||||||
|
|
||||||
def get_method_sig(method):
|
def get_method_sig(method):
|
||||||
""" Given a function, it returns a string that pretty much looks how the
|
"""Given a function, it returns a string that pretty much looks how the
|
||||||
function signature would be written in python.
|
function signature would be written in python.
|
||||||
|
|
||||||
:param method: a python method
|
:param method: a python method
|
||||||
|
@ -59,16 +58,16 @@ def get_method_sig(method):
|
||||||
args.append(arg)
|
args.append(arg)
|
||||||
arg_index += 1
|
arg_index += 1
|
||||||
if argspec.varargs:
|
if argspec.varargs:
|
||||||
args.append('*' + argspec.varargs)
|
args.append("*" + argspec.varargs)
|
||||||
if argspec.varkw:
|
if argspec.varkw:
|
||||||
args.append('**' + argspec.varkw)
|
args.append("**" + argspec.varkw)
|
||||||
return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
|
return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
|
||||||
|
|
||||||
|
|
||||||
def docstring(obj):
|
def docstring(obj):
|
||||||
doc = (obj.__doc__ or '').rstrip()
|
doc = (obj.__doc__ or "").rstrip()
|
||||||
if doc:
|
if doc:
|
||||||
lines = doc.split('\n')
|
lines = doc.split("\n")
|
||||||
# Find the length of the whitespace prefix common to all non-empty lines
|
# Find the length of the whitespace prefix common to all non-empty lines
|
||||||
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
|
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
|
||||||
# Output the lines without the indentation
|
# Output the lines without the indentation
|
||||||
|
@ -78,30 +77,32 @@ def docstring(obj):
|
||||||
|
|
||||||
|
|
||||||
def class_doc(cls, list_methods=True):
|
def class_doc(cls, list_methods=True):
|
||||||
bases = ', '.join([b.__name__ for b in cls.__bases__])
|
bases = ", ".join([b.__name__ for b in cls.__bases__])
|
||||||
print('###', cls.__name__)
|
print("###", cls.__name__)
|
||||||
print()
|
print()
|
||||||
if bases != 'object':
|
if bases != "object":
|
||||||
print('Extends', bases)
|
print("Extends", bases)
|
||||||
print()
|
print()
|
||||||
docstring(cls)
|
docstring(cls)
|
||||||
for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
|
for name, method in inspect.getmembers(
|
||||||
if name == '__init__':
|
cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)
|
||||||
|
):
|
||||||
|
if name == "__init__":
|
||||||
# Initializer
|
# Initializer
|
||||||
print('####', get_method_sig(method).replace(name, cls.__name__))
|
print("####", get_method_sig(method).replace(name, cls.__name__))
|
||||||
elif name[0] == '_':
|
elif name[0] == "_":
|
||||||
# Private method
|
# Private method
|
||||||
continue
|
continue
|
||||||
elif hasattr(method, '__self__') and method.__self__ == cls:
|
elif hasattr(method, "__self__") and method.__self__ == cls:
|
||||||
# Class method
|
# Class method
|
||||||
if not list_methods:
|
if not list_methods:
|
||||||
continue
|
continue
|
||||||
print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
|
print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
|
||||||
else:
|
else:
|
||||||
# Regular method
|
# Regular method
|
||||||
if not list_methods:
|
if not list_methods:
|
||||||
continue
|
continue
|
||||||
print('####', get_method_sig(method))
|
print("####", get_method_sig(method))
|
||||||
print()
|
print()
|
||||||
docstring(method)
|
docstring(method)
|
||||||
print()
|
print()
|
||||||
|
@ -110,7 +111,7 @@ def class_doc(cls, list_methods=True):
|
||||||
def module_doc(classes, list_methods=True):
|
def module_doc(classes, list_methods=True):
|
||||||
mdl = classes[0].__module__
|
mdl = classes[0].__module__
|
||||||
print(mdl)
|
print(mdl)
|
||||||
print('-' * len(mdl))
|
print("-" * len(mdl))
|
||||||
print()
|
print()
|
||||||
for cls in classes:
|
for cls in classes:
|
||||||
class_doc(cls, list_methods)
|
class_doc(cls, list_methods)
|
||||||
|
@ -120,7 +121,7 @@ def all_subclasses(cls):
|
||||||
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
|
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
from clickhouse_orm import database
|
from clickhouse_orm import database
|
||||||
from clickhouse_orm import fields
|
from clickhouse_orm import fields
|
||||||
|
@ -131,13 +132,24 @@ if __name__ == '__main__':
|
||||||
from clickhouse_orm import system_models
|
from clickhouse_orm import system_models
|
||||||
from clickhouse_orm.aio import database as aio_database
|
from clickhouse_orm.aio import database as aio_database
|
||||||
|
|
||||||
print('Class Reference')
|
print("Class Reference")
|
||||||
print('===============')
|
print("===============")
|
||||||
print()
|
print()
|
||||||
module_doc([database.Database, database.DatabaseException])
|
module_doc([database.Database, database.DatabaseException])
|
||||||
module_doc([aio_database.AioDatabase])
|
module_doc([aio_database.AioDatabase])
|
||||||
module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
|
module_doc(
|
||||||
module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
|
[
|
||||||
|
models.Model,
|
||||||
|
models.BufferModel,
|
||||||
|
models.MergeModel,
|
||||||
|
models.DistributedModel,
|
||||||
|
models.Constraint,
|
||||||
|
models.Index,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
module_doc(
|
||||||
|
sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False
|
||||||
|
)
|
||||||
module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
|
module_doc([engines.DatabaseEngine] + all_subclasses(engines.DatabaseEngine), False)
|
||||||
module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
|
module_doc([engines.TableEngine] + all_subclasses(engines.TableEngine), False)
|
||||||
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])
|
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])
|
||||||
|
|
|
@ -2,13 +2,13 @@ from html.parser import HTMLParser
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
HEADER_TAGS = ('h1', 'h2', 'h3')
|
HEADER_TAGS = ("h1", "h2", "h3")
|
||||||
|
|
||||||
|
|
||||||
class HeadersToMarkdownParser(HTMLParser):
|
class HeadersToMarkdownParser(HTMLParser):
|
||||||
|
|
||||||
inside = None
|
inside = None
|
||||||
text = ''
|
text = ""
|
||||||
|
|
||||||
def handle_starttag(self, tag, attrs):
|
def handle_starttag(self, tag, attrs):
|
||||||
if tag.lower() in HEADER_TAGS:
|
if tag.lower() in HEADER_TAGS:
|
||||||
|
@ -16,11 +16,11 @@ class HeadersToMarkdownParser(HTMLParser):
|
||||||
|
|
||||||
def handle_endtag(self, tag):
|
def handle_endtag(self, tag):
|
||||||
if tag.lower() in HEADER_TAGS:
|
if tag.lower() in HEADER_TAGS:
|
||||||
indent = ' ' * int(self.inside[1])
|
indent = " " * int(self.inside[1])
|
||||||
fragment = self.text.lower().replace(' ', '-').replace('.', '')
|
fragment = self.text.lower().replace(" ", "-").replace(".", "")
|
||||||
print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
|
print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
|
||||||
self.inside = None
|
self.inside = None
|
||||||
self.text = ''
|
self.text = ""
|
||||||
|
|
||||||
def handle_data(self, data):
|
def handle_data(self, data):
|
||||||
if self.inside:
|
if self.inside:
|
||||||
|
@ -28,4 +28,4 @@ class HeadersToMarkdownParser(HTMLParser):
|
||||||
|
|
||||||
|
|
||||||
HeadersToMarkdownParser().feed(sys.stdin.read())
|
HeadersToMarkdownParser().feed(sys.stdin.read())
|
||||||
print('')
|
print("")
|
||||||
|
|
|
@ -43,7 +43,7 @@ class Lazy(DatabaseEngine):
|
||||||
for which there is a long time interval between accesses.
|
for which there is a long time interval between accesses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __int__(self, expiration_time_in_seconds: int):
|
def __init__(self, expiration_time_in_seconds: int):
|
||||||
self.expiration_time_in_seconds = expiration_time_in_seconds
|
self.expiration_time_in_seconds = expiration_time_in_seconds
|
||||||
|
|
||||||
def create_database_sql(self) -> str:
|
def create_database_sql(self) -> str:
|
||||||
|
@ -59,7 +59,7 @@ class MySQL(DatabaseEngine):
|
||||||
such as SHOW TABLES or SHOW CREATE TABLE.
|
such as SHOW TABLES or SHOW CREATE TABLE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __int__(self, host: str, port: int, database: str, user: str, password: str):
|
def __init__(self, host: str, port: int, database: str, user: str, password: str):
|
||||||
"""
|
"""
|
||||||
- `host`: MySQL server address.
|
- `host`: MySQL server address.
|
||||||
- `port`: MySQL server port.
|
- `port`: MySQL server port.
|
||||||
|
@ -90,7 +90,7 @@ class PostgreSQL(DatabaseEngine):
|
||||||
but can be updated with DETACH and ATTACH queries.
|
but can be updated with DETACH and ATTACH queries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __int__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import unicode_literals, annotations
|
from __future__ import unicode_literals, annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
@ -14,7 +15,7 @@ import iso8601
|
||||||
import pytz
|
import pytz
|
||||||
from pytz import BaseTzInfo
|
from pytz import BaseTzInfo
|
||||||
|
|
||||||
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
|
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names, parse_map
|
||||||
from .funcs import F, FunctionOperatorsMixin
|
from .funcs import F, FunctionOperatorsMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -870,5 +871,75 @@ class LowCardinalityField(Field):
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
|
|
||||||
|
MapKey = Union[
|
||||||
|
StringField,
|
||||||
|
BaseIntField,
|
||||||
|
LowCardinalityField,
|
||||||
|
UUIDField,
|
||||||
|
DateField,
|
||||||
|
DateTimeField,
|
||||||
|
BaseEnumField,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MapField(Field):
|
||||||
|
"""MapField is only experimental and is expected to be available in the next few releases"""
|
||||||
|
class_default = None # incorrect value
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
key: MapKey,
|
||||||
|
value: Field,
|
||||||
|
default: dict = None,
|
||||||
|
alias: Optional[Union[F, str]] = None,
|
||||||
|
materialized: Optional[Union[F, str]] = None,
|
||||||
|
readonly: bool = None,
|
||||||
|
codec: Optional[str] = None,
|
||||||
|
db_column: Optional[str] = None,
|
||||||
|
):
|
||||||
|
assert isinstance(key, Field)
|
||||||
|
assert isinstance(value, Field)
|
||||||
|
self.key = key
|
||||||
|
self.value = value
|
||||||
|
if not default:
|
||||||
|
default = {}
|
||||||
|
super().__init__(default, alias, materialized, readonly, codec, db_column)
|
||||||
|
|
||||||
|
def to_python(self, value, timezone_in_use) -> dict:
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode('utf-8')
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = parse_map(value)
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError("MapField expects dict, not %s" % type(value))
|
||||||
|
value = {
|
||||||
|
self.key.to_python(k, timezone_in_use): self.value.to_python(v, timezone_in_use)
|
||||||
|
for k, v in value.items()
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
|
||||||
|
def validate(self, value):
|
||||||
|
for k, v in value.items():
|
||||||
|
self.key.validate(k)
|
||||||
|
self.value.validate(v)
|
||||||
|
|
||||||
|
def to_db_string(self, value, quote=True) -> str:
|
||||||
|
value2 = {}
|
||||||
|
int_key = isinstance(self.key, (BaseIntField, BaseFloatField))
|
||||||
|
int_value = isinstance(self.value, (BaseIntField, BaseFloatField))
|
||||||
|
for k, v in value.items():
|
||||||
|
if not int_key:
|
||||||
|
k = self.key.to_db_string(k, quote=False)
|
||||||
|
if not int_value:
|
||||||
|
v = self.value.to_db_string(v, quote=False)
|
||||||
|
value2[k] = v
|
||||||
|
return json.dumps(value2).replace("\"", "'")
|
||||||
|
|
||||||
|
def get_sql(self, with_default_expression=True, db=None) -> str:
|
||||||
|
sql = "Map(%s, %s)" % (self.key.get_sql(False), self.value.get_sql(False))
|
||||||
|
if with_default_expression and self.codec and db and db.has_codec_support:
|
||||||
|
sql += " CODEC(%s)" % self.codec
|
||||||
|
return sql
|
||||||
|
|
||||||
# Expose only relevant classes in import *
|
# Expose only relevant classes in import *
|
||||||
__all__ = get_subclass_names(locals(), Field)
|
__all__ = get_subclass_names(locals(), Field)
|
||||||
|
|
|
@ -186,6 +186,12 @@ class FunctionOperatorsMixin:
|
||||||
def isNotIn(self, others):
|
def isNotIn(self, others):
|
||||||
return F._notIn(self, others) # pylint: disable=W0212
|
return F._notIn(self, others) # pylint: disable=W0212
|
||||||
|
|
||||||
|
def isGlobalIn(self, others):
|
||||||
|
return F._gin(self, others) # pylint: disable=W0212
|
||||||
|
|
||||||
|
def isNotGlobalIn(self, others):
|
||||||
|
return F._notGIn(self, others) # pylint: disable=W0212
|
||||||
|
|
||||||
|
|
||||||
class FMeta(type):
|
class FMeta(type):
|
||||||
|
|
||||||
|
@ -404,6 +410,20 @@ class F(Cond, FunctionOperatorsMixin, metaclass=FMeta): # pylint: disable=R0904
|
||||||
b = tuple(b)
|
b = tuple(b)
|
||||||
return F("NOT IN", a, b)
|
return F("NOT IN", a, b)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@binary_operator
|
||||||
|
def _gin(a, b):
|
||||||
|
if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
|
||||||
|
b = tuple(b)
|
||||||
|
return F("GLOBAL IN", a, b)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@binary_operator
|
||||||
|
def _notGIn(a, b):
|
||||||
|
if is_iterable(b) and not isinstance(b, (tuple, QuerySet)):
|
||||||
|
b = tuple(b)
|
||||||
|
return F("NOT GLOBAL IN", a, b)
|
||||||
|
|
||||||
# Functions for working with dates and times
|
# Functions for working with dates and times
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -253,6 +253,12 @@ class ModelBase(type):
|
||||||
else:
|
else:
|
||||||
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
|
name_fields.append((str(i), cls.create_ad_hoc_field(tp[0])))
|
||||||
return orm_fields.TupleField(name_fields=name_fields)
|
return orm_fields.TupleField(name_fields=name_fields)
|
||||||
|
# Map
|
||||||
|
if db_type.startswith("Map"):
|
||||||
|
types = [s.strip() for s in db_type[4:-1].split(",")]
|
||||||
|
key_filed = cls.create_ad_hoc_field(types[0])
|
||||||
|
value_filed = cls.create_ad_hoc_field(types[1])
|
||||||
|
return orm_fields.MapField(key_filed, value_filed)
|
||||||
# FixedString
|
# FixedString
|
||||||
if db_type.startswith("FixedString"):
|
if db_type.startswith("FixedString"):
|
||||||
length = int(db_type[12:-1])
|
length = int(db_type[12:-1])
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import codecs
|
import codecs
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import date, datetime, tzinfo, timedelta
|
from datetime import date, datetime, tzinfo, timedelta
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ SPECIAL_CHARS = {
|
||||||
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
|
SPECIAL_CHARS_REGEX = re.compile("[" + "".join(SPECIAL_CHARS.values()) + "]")
|
||||||
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
|
POINT_REGEX = re.compile(r"\((?P<x>\d+(\.\d+)?),(?P<y>\d+(\.\d+)?)\)")
|
||||||
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
|
RING_VALID_REGEX = re.compile(r"\[((\(\d+(\.\d+)?,\d+(\.\d+)?\)),)*\(\d+(\.\d+)?,\d+(\.\d+)?\)\]")
|
||||||
|
MAP_REGEX = re.compile(r"(.+?)=(.+?),?")
|
||||||
|
|
||||||
|
|
||||||
def escape(value, quote=True):
|
def escape(value, quote=True):
|
||||||
|
@ -120,6 +122,22 @@ def parse_array(array_string):
|
||||||
array_string = array_string[match.end() - 1 :]
|
array_string = array_string[match.end() - 1 :]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_map(map_string: str) -> dict:
|
||||||
|
"""
|
||||||
|
Parse an map string as returned by clickhouse. For example:
|
||||||
|
"{key1=1, key2=2, key3=3}" ==> {"key1": 1, "key2": 2, "key3: 3}
|
||||||
|
"""
|
||||||
|
if any([map_string[0] != "{", map_string[-1] != "}"]):
|
||||||
|
raise ValueError('Invalid map string: "%s"' % map_string)
|
||||||
|
ret = {}
|
||||||
|
try:
|
||||||
|
ret = json.loads(map_string.replace("'", "\""))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
for (key, value) in MAP_REGEX.findall(map_string[1:-1]):
|
||||||
|
ret[key] = value
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def import_submodules(package_name):
|
def import_submodules(package_name):
|
||||||
"""
|
"""
|
||||||
Import all submodules of a module.
|
Import all submodules of a module.
|
||||||
|
|
59
tests/test_map_fields.py
Normal file
59
tests/test_map_fields.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import unittest
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from clickhouse_orm.database import Database
|
||||||
|
from clickhouse_orm.models import Model
|
||||||
|
from clickhouse_orm.fields import MapField, DateField, StringField, Int32Field, Float64Field
|
||||||
|
from clickhouse_orm.engines import MergeTree
|
||||||
|
|
||||||
|
|
||||||
|
class TupleFieldsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.database = Database('test-db', log_statements=True)
|
||||||
|
self.database.create_table(ModelWithTuple)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.database.drop_database()
|
||||||
|
|
||||||
|
def test_insert_and_select(self):
|
||||||
|
instance = ModelWithTuple(
|
||||||
|
date_field='2022-06-26',
|
||||||
|
map1={"k1": "v1", "k2": "v2"},
|
||||||
|
map2={"v1": 1, "v2": 2},
|
||||||
|
map3={"f1": 1.1, "f2": 2.0},
|
||||||
|
map4={"2022-06-25": "ok", "2022-06-26": "today"}
|
||||||
|
)
|
||||||
|
self.database.insert([instance])
|
||||||
|
query = 'SELECT * from $db.modelwithtuple ORDER BY date_field'
|
||||||
|
for model_cls in (ModelWithTuple, None):
|
||||||
|
results = list(self.database.select(query, model_cls))
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertIn("k1", results[0].map1)
|
||||||
|
self.assertEqual(results[0].map1["k2"], "v2")
|
||||||
|
self.assertEqual(results[0].map2["v1"], 1)
|
||||||
|
self.assertEqual(results[0].map3["f2"], 2.0)
|
||||||
|
self.assertEqual(results[0].map4[date(2022, 6, 26)], "today")
|
||||||
|
|
||||||
|
def test_conversion(self):
|
||||||
|
instance = ModelWithTuple(
|
||||||
|
map2="{'1': '2'}"
|
||||||
|
)
|
||||||
|
self.assertEqual(instance.map2['1'], 2)
|
||||||
|
|
||||||
|
def test_assignment_error(self):
|
||||||
|
instance = ModelWithTuple()
|
||||||
|
for value in (7, 'x', [date.today()], ['aaa'], [None]):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
instance.map1 = value
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithTuple(Model):
|
||||||
|
|
||||||
|
date_field = DateField()
|
||||||
|
map1 = MapField(StringField(), StringField())
|
||||||
|
map2 = MapField(StringField(), Int32Field())
|
||||||
|
map3 = MapField(StringField(), Float64Field())
|
||||||
|
map4 = MapField(DateField(), StringField())
|
||||||
|
|
||||||
|
engine = MergeTree('date_field', ('date_field',))
|
Loading…
Reference in New Issue
Block a user