This commit is contained in:
Beda Kosata 2021-10-04 08:08:24 +00:00 committed by GitHub
commit 4661f5ba55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
111 changed files with 4739 additions and 4495 deletions

40
.github/workflows/python-publish.yml vendored Normal file
View File

@ -0,0 +1,40 @@
name: Upload to PIP
# Controls when the action will run.
on:
# Triggers the workflow when a release is created
release:
types: [created]
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "upload"
upload:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Sets up python
- uses: actions/setup-python@v2
with:
python-version: 3.9
# Install dependencies
- name: Install dependencies
run: |
python -m pip install poetry
poetry install --no-dev
# Build and upload to PyPI
- name: Builds and upload to PyPI
run: |
poetry config pypi-token.pypi "$TWINE_TOKEN"
poetry publish --build
env:
TWINE_TOKEN: ${{ secrets.TWINE_TOKEN }}

89
.github/workflows/python-test.yml vendored Normal file
View File

@ -0,0 +1,89 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Tests
on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
jobs:
lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
# Lint on smallest and largest active versions
python-version: [3.6, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Lint with flake8
run: |
poetry run flake8
- name: Check formatting with black
run: |
poetry run black --check .
test:
runs-on: ubuntu-latest
services:
clickhouse:
image: yandex/clickhouse-server:21.3
ports:
- 8123:8123
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9, pypy3]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Test with pytest
run: |
poetry run pytest
test_compat:
# Tests compatibility with an older LTS release of clickhouse
runs-on: ubuntu-latest
services:
clickhouse:
image: yandex/clickhouse-server:20.8
ports:
- 8123:8123
strategy:
fail-fast: false
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Test with pytest
run: |
poetry run pytest

59
.gitignore vendored
View File

@ -63,3 +63,62 @@ cover/
# tox
.tox/
# misc
*.7z
*.apk
*.backup
*.bak
*.bk
*.bz2
*.deb
*.doc
*.docx
*.gz
*.gzip
*.img
*.iso
*.jar
*.jpeg
*.jpg
*.log
*.ods
*.part
*.pdf
*.pkg
*.png
*.pps
*.ppsx
*.ppt
*.pptx
*.ps
*.pyc
*.rar
*.swp
*.sys
*.tar
*.tgz
*.tmp
*.xls
*.xlsx
*.xz
*.zip
**/*venv/
.cache
.coverage*
.idea/
.isort.cfg
**.directory
venv
.pytest_cache/
.vscode/
*.egg-info/
.tox/
.cargo/
.expected
.hypothesis/
.mypy_cache/
**/__pycache__/
# poetry
poetry.lock

1073
.noseids

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,26 @@
Change Log
==========
v2.2.2
------
- Unpined requirements to enhance compatability
v2.2.1
------
- Minor tooling changes for PyPI
- Better project description for PyPI
v2.2.0
------
- Support up to clickhouse 20.12, including LTS 20.8 release
- Fixed boolean logic for Q objects (https://github.com/Infinidat/infi.clickhouse_orm/issues/158)
- Remove implicit use of '\N' character (which was causing deprecation warnings in queries)
- Tooling updates: use poetry, pytest, isort, black
**Backwards incompatible changes**
You can no longer supply a codec for an `alias` field. Previously this had no effect in clickhouse, but now it explicitly returns an error.
v2.1.0
------
- Support for model constraints

View File

@ -1,3 +1,4 @@
Copyright (c) 2021 Suade Labs
Copyright (c) 2017 INFINIDAT
Redistribution and use in source and binary forms, with or without

View File

@ -1,3 +1,8 @@
A fork of [infi.clikchouse_orm](https://github.com/Infinidat/infi.clickhouse_orm) aimed at more frequent maintenance and bugfixes.
[![Tests](https://github.com/SuadeLabs/clickhouse_orm/actions/workflows/python-test.yml/badge.svg)](https://github.com/SuadeLabs/clickhouse_orm/actions/workflows/python-test.yml)
![PyPI](https://img.shields.io/pypi/v/clickhouse_orm)
Introduction
============
@ -8,7 +13,7 @@ Let's jump right in with a simple example of monitoring CPU usage. First we need
connect to the database and create a table for the model:
```python
from infi.clickhouse_orm import Database, Model, DateTimeField, UInt16Field, Float32Field, Memory, F
from clickhouse_orm import Database, Model, DateTimeField, UInt16Field, Float32Field, Memory, F
class CPUStats(Model):

View File

@ -1,68 +0,0 @@
[buildout]
prefer-final = false
newest = false
download-cache = .cache
develop = .
parts =
relative-paths = true
[project]
name = infi.clickhouse_orm
company = Infinidat
namespace_packages = ['infi']
install_requires = [
'iso8601 >= 0.1.12',
'pytz',
'requests',
'setuptools'
]
version_file = src/infi/clickhouse_orm/__version__.py
description = A Python library for working with the ClickHouse database
long_description = A Python library for working with the ClickHouse database
console_scripts = []
gui_scripts = []
package_data = []
upgrade_code = {58530fba-3932-11e6-a20e-7071bc32067f}
product_name = infi.clickhouse_orm
post_install_script_name = None
pre_uninstall_script_name = None
homepage = https://github.com/Infinidat/infi.clickhouse_orm
[isolated-python]
recipe = infi.recipe.python
version = v3.8.0.2
[setup.py]
recipe = infi.recipe.template.version
input = setup.in
output = setup.py
[__version__.py]
recipe = infi.recipe.template.version
output = ${project:version_file}
[development-scripts]
dependent-scripts = true
recipe = infi.recipe.console_scripts
eggs = ${project:name}
ipython<6
nose
coverage
enum-compat
infi.unittest
infi.traceback
memory_profiler
profilehooks
psutil
zc.buildout
scripts = ipython
nosetests
interpreter = python
[pack]
recipe = infi.recipe.application_packager
[sublime]
recipe = corneti.recipes.codeintel
eggs = ${development-scripts:eggs}

View File

@ -0,0 +1,12 @@
from inspect import isclass
from .database import * # noqa: F401, F403
from .engines import * # noqa: F401, F403
from .fields import * # noqa: F401, F403
from .funcs import * # noqa: F401, F403
from .migrations import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .query import * # noqa: F401, F403
from .system_models import * # noqa: F401, F403
__all__ = [c.__name__ for c in locals().values() if isclass(c)]

View File

@ -1,26 +1,23 @@
from __future__ import unicode_literals
import re
import requests
from collections import namedtuple
from .models import ModelBase
from .utils import escape, parse_tsv, import_submodules
from math import ceil
import datetime
from string import Template
import pytz
import logging
logger = logging.getLogger('clickhouse_orm')
import re
from math import ceil
from string import Template
import pytz
import requests
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
from .models import ModelBase
from .utils import Page, import_submodules, parse_tsv
logger = logging.getLogger("clickhouse_orm")
class DatabaseException(Exception):
'''
"""
Raised when a database operation fails.
'''
"""
pass
@ -28,6 +25,7 @@ class ServerError(DatabaseException):
"""
Raised when a server returns an error.
"""
def __init__(self, message):
self.code = None
processed = self.get_error_code_msg(message)
@ -41,16 +39,22 @@ class ServerError(DatabaseException):
ERROR_PATTERNS = (
# ClickHouse prior to v19.3.3
re.compile(r'''
re.compile(
r"""
Code:\ (?P<code>\d+),
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
\ e.what\(\)\ =\ (?P<type2>[^ \n]+)
''', re.VERBOSE | re.DOTALL),
""",
re.VERBOSE | re.DOTALL,
),
# ClickHouse v19.3.3+
re.compile(r'''
re.compile(
r"""
Code:\ (?P<code>\d+),
\ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
''', re.VERBOSE | re.DOTALL),
""",
re.VERBOSE | re.DOTALL,
),
)
@classmethod
@ -65,7 +69,7 @@ class ServerError(DatabaseException):
match = pattern.match(full_error_message)
if match:
# assert match.group('type1') == match.group('type2')
return int(match.group('code')), match.group('msg').strip()
return int(match.group("code")), match.group("msg").strip()
return 0, full_error_message
@ -75,15 +79,24 @@ class ServerError(DatabaseException):
class Database(object):
'''
"""
Database instances connect to a specific ClickHouse database for running queries,
inserting data and other operations.
'''
"""
def __init__(self, db_name, db_url='http://localhost:8123/',
username=None, password=None, readonly=False, autocreate=True,
timeout=60, verify_ssl_cert=True, log_statements=False):
'''
def __init__(
self,
db_name,
db_url="http://localhost:8123/",
username=None,
password=None,
readonly=False,
autocreate=True,
timeout=60,
verify_ssl_cert=True,
log_statements=False,
):
"""
Initializes a database instance. Unless it's readonly, the database will be
created on the ClickHouse server if it does not already exist.
@ -96,7 +109,7 @@ class Database(object):
- `timeout`: the connection timeout in seconds.
- `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
- `log_statements`: when True, all database statements are logged.
'''
"""
self.db_name = db_name
self.db_url = db_url
self.readonly = False
@ -104,14 +117,14 @@ class Database(object):
self.request_session = requests.Session()
self.request_session.verify = verify_ssl_cert
if username:
self.request_session.auth = (username, password or '')
self.request_session.auth = (username, password or "")
self.log_statements = log_statements
self.settings = {}
self.db_exists = False # this is required before running _is_existing_database
self.db_exists = self._is_existing_database()
if readonly:
if not self.db_exists:
raise DatabaseException('Database does not exist, and cannot be created under readonly connection')
raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
self.connection_readonly = self._is_connection_readonly()
self.readonly = True
elif autocreate and not self.db_exists:
@ -125,56 +138,56 @@ class Database(object):
self.has_low_cardinality_support = self.server_version >= (19, 0)
def create_database(self):
'''
"""
Creates the database on the ClickHouse server if it does not already exist.
'''
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
"""
self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
self.db_exists = True
def drop_database(self):
'''
"""
Deletes the database on the ClickHouse server.
'''
self._send('DROP DATABASE `%s`' % self.db_name)
"""
self._send("DROP DATABASE `%s`" % self.db_name)
self.db_exists = False
def create_table(self, model_class):
'''
"""
Creates a table for the given model class, if it does not exist already.
'''
"""
if model_class.is_system_model():
raise DatabaseException("You can't create system table")
if getattr(model_class, 'engine') is None:
if model_class.engine is None:
raise DatabaseException("%s class must define an engine" % model_class.__name__)
self._send(model_class.create_table_sql(self))
def drop_table(self, model_class):
'''
"""
Drops the database table of the given model class, if it exists.
'''
"""
if model_class.is_system_model():
raise DatabaseException("You can't drop system table")
self._send(model_class.drop_table_sql(self))
def does_table_exist(self, model_class):
'''
"""
Checks whether a table for the given model class already exists.
Note that this only checks for existence of a table with the expected name.
'''
"""
sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
r = self._send(sql % (self.db_name, model_class.table_name()))
return r.text.strip() == '1'
return r.text.strip() == "1"
def get_model_for_table(self, table_name, system_table=False):
'''
"""
Generates a model class from an existing table in the database.
This can be used for querying tables which don't have a corresponding model class,
for example system tables.
- `table_name`: the table to create a model for
- `system_table`: whether the table is a system table, or belongs to the current database
'''
db_name = 'system' if system_table else self.db_name
"""
db_name = "system" if system_table else self.db_name
sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
lines = self._send(sql).iter_lines()
fields = [parse_tsv(line)[:2] for line in lines]
@ -184,27 +197,28 @@ class Database(object):
return model
def add_setting(self, name, value):
'''
"""
Adds a database setting that will be sent with every request.
For example, `db.add_setting("max_execution_time", 10)` will
limit query execution time to 10 seconds.
The name must be string, and the value is converted to string in case
it isn't. To remove a setting, pass `None` as the value.
'''
assert isinstance(name, str), 'Setting name must be a string'
"""
assert isinstance(name, str), "Setting name must be a string"
if value is None:
self.settings.pop(name, None)
else:
self.settings[name] = str(value)
def insert(self, model_instances, batch_size=1000):
'''
"""
Insert records into the database.
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
'''
"""
from io import BytesIO
i = iter(model_instances)
try:
first_instance = next(i)
@ -215,14 +229,13 @@ class Database(object):
if first_instance.is_read_only() or first_instance.is_system_model():
raise DatabaseException("You can't insert into read only and system tables")
fields_list = ','.join(
['`%s`' % name for name in first_instance.fields(writable=True)])
fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
def gen():
buf = BytesIO()
buf.write(self._substitute(query, model_class).encode('utf-8'))
buf.write(self._substitute(query, model_class).encode("utf-8"))
first_instance.set_database(self)
buf.write(first_instance.to_db_string())
# Collect lines in batches of batch_size
@ -240,35 +253,37 @@ class Database(object):
# Return any remaining lines in partial batch
if lines:
yield buf.getvalue()
self._send(gen())
def count(self, model_class, conditions=None):
'''
"""
Counts the number of records in the model's table.
- `model_class`: the model to count.
- `conditions`: optional SQL conditions (contents of the WHERE clause).
'''
from infi.clickhouse_orm.query import Q
query = 'SELECT count() FROM $table'
"""
from clickhouse_orm.query import Q
query = "SELECT count() FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += " WHERE " + str(conditions)
query = self._substitute(query, model_class)
r = self._send(query)
return int(r.text) if r.text else 0
def select(self, query, model_class=None, settings=None):
'''
"""
Performs a query and returns a generator of model instances.
- `query`: the SQL query to execute.
- `model_class`: the model class matching the query's table,
or `None` for getting back instances of an ad-hoc model.
- `settings`: query settings to send as HTTP GET parameters
'''
query += ' FORMAT TabSeparatedWithNamesAndTypes'
"""
query += " FORMAT TabSeparatedWithNamesAndTypes"
query = self._substitute(query, model_class)
r = self._send(query, settings, True)
lines = r.iter_lines()
@ -281,18 +296,18 @@ class Database(object):
yield model_class.from_tsv(line, field_names, self.server_timezone, self)
def raw(self, query, settings=None, stream=False):
'''
"""
Performs a query and returns its output as text.
- `query`: the SQL query to execute.
- `settings`: query settings to send as HTTP GET parameters
- `stream`: if true, the HTTP response from ClickHouse will be streamed.
'''
"""
query = self._substitute(query, None)
return self._send(query, settings=settings, stream=stream).text
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
'''
"""
Selects records and returns a single page of model instances.
- `model_class`: the model class matching the query's table,
@ -305,54 +320,63 @@ class Database(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
'''
from infi.clickhouse_orm.query import Q
"""
from clickhouse_orm.query import Q
count = self.count(model_class, conditions)
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = max(pages_total, 1)
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
query = 'SELECT * FROM $table'
query = "SELECT * FROM $table"
if conditions:
if isinstance(conditions, Q):
conditions = conditions.to_sql(model_class)
query += ' WHERE ' + str(conditions)
query += ' ORDER BY %s' % order_by
query += ' LIMIT %d, %d' % (offset, page_size)
query += " WHERE " + str(conditions)
query += " ORDER BY %s" % order_by
query += " LIMIT %d, %d" % (offset, page_size)
query = self._substitute(query, model_class)
return Page(
objects=list(self.select(query, model_class, settings)) if count else [],
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
def migrate(self, migrations_package_name, up_to=9999):
'''
"""
Executes schema migrations.
- `migrations_package_name` - fully qualified name of the Python package
containing the migrations.
- `up_to` - number of the last migration to apply.
'''
"""
from .migrations import MigrationHistory
logger = logging.getLogger('migrations')
logger = logging.getLogger("migrations")
applied_migrations = self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name)
unapplied_migrations = set(modules.keys()) - applied_migrations
for name in sorted(unapplied_migrations):
logger.info('Applying migration %s...', name)
logger.info("Applying migration %s...", name)
for operation in modules[name].operations:
operation.apply(self)
self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
self.insert(
[
MigrationHistory(
package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
)
]
)
if int(name[:4]) >= up_to:
break
def _get_applied_migrations(self, migrations_package_name):
from .migrations import MigrationHistory
self.create_table(MigrationHistory)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory)
@ -360,7 +384,7 @@ class Database(object):
def _send(self, data, settings=None, stream=False):
if isinstance(data, str):
data = data.encode('utf-8')
data = data.encode("utf-8")
if self.log_statements:
logger.info(data)
params = self._build_params(settings)
@ -373,50 +397,50 @@ class Database(object):
params = dict(settings or {})
params.update(self.settings)
if self.db_exists:
params['database'] = self.db_name
params["database"] = self.db_name
# Send the readonly flag, unless the connection is already readonly (to prevent db error)
if self.readonly and not self.connection_readonly:
params['readonly'] = '1'
params["readonly"] = "1"
return params
def _substitute(self, query, model_class=None):
'''
"""
Replaces $db and $table placeholders in the query.
'''
if '$' in query:
"""
if "$" in query:
mapping = dict(db="`%s`" % self.db_name)
if model_class:
if model_class.is_system_model():
mapping['table'] = "`system`.`%s`" % model_class.table_name()
mapping["table"] = "`system`.`%s`" % model_class.table_name()
else:
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
query = Template(query).safe_substitute(mapping)
return query
def _get_server_timezone(self):
try:
r = self._send('SELECT timezone()')
r = self._send("SELECT timezone()")
return pytz.timezone(r.text.strip())
except ServerError as e:
logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
logger.exception("Cannot determine server timezone (%s), assuming UTC", e)
return pytz.utc
def _get_server_version(self, as_tuple=True):
try:
r = self._send('SELECT version();')
r = self._send("SELECT version();")
ver = r.text
except ServerError as e:
logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
ver = '1.1.0'
return tuple(int(n) for n in ver.split('.')) if as_tuple else ver
logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
ver = "1.1.0"
return tuple(int(n) for n in ver.split(".")) if as_tuple else ver
def _is_existing_database(self):
r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
return r.text.strip() == '1'
return r.text.strip() == "1"
def _is_connection_readonly(self):
r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
return r.text.strip() != '0'
return r.text.strip() != "0"
# Expose only relevant classes in import *

View File

@ -1,54 +1,60 @@
from __future__ import unicode_literals
import logging
from .utils import comma_join, get_subclass_names
logger = logging.getLogger('clickhouse_orm')
logger = logging.getLogger("clickhouse_orm")
class Engine(object):
def create_table_sql(self, db):
raise NotImplementedError() # pragma: no cover
class TinyLog(Engine):
def create_table_sql(self, db):
return 'TinyLog'
return "TinyLog"
class Log(Engine):
def create_table_sql(self, db):
return 'Log'
return "Log"
class Memory(Engine):
def create_table_sql(self, db):
return 'Memory'
return "Memory"
class MergeTree(Engine):
def __init__(self, date_col=None, order_by=(), sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
primary_key=None):
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
assert primary_key is None or type(primary_key) in (list, tuple), 'primary_key must be a list or tuple'
assert partition_key is None or type(partition_key) in (list, tuple),\
'partition_key must be tuple or list if present'
assert (replica_table_path is None) == (replica_name is None), \
'both replica_table_path and replica_name must be specified'
def __init__(
self,
date_col=None,
order_by=(),
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
settings=None,
):
assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
assert primary_key is None or type(primary_key) in (list, tuple), "primary_key must be a list or tuple"
assert partition_key is None or type(partition_key) in (
list,
tuple,
), "partition_key must be tuple or list if present"
assert (replica_table_path is None) == (
replica_name is None
), "both replica_table_path and replica_name must be specified"
assert settings is None or type(settings) is dict, 'settings must be dict'
# These values conflict with each other (old and new syntax of table engines.
# So let's control only one of them is given.
assert date_col or partition_key, "You must set either date_col or partition_key"
self.date_col = date_col
self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,)
self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
self.primary_key = primary_key
self.order_by = order_by
@ -56,50 +62,62 @@ class MergeTree(Engine):
self.index_granularity = index_granularity
self.replica_table_path = replica_table_path
self.replica_name = replica_name
self.settings = settings
# I changed field name for new reality and syntax
@property
def key_cols(self):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
logger.warning(
"`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
)
return self.order_by
@key_cols.setter
def key_cols(self, value):
logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
logger.warning(
"`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
)
self.order_by = value
def create_table_sql(self, db):
name = self.__class__.__name__
if self.replica_name:
name = 'Replicated' + name
name = "Replicated" + name
# In ClickHouse 1.1.54310 custom partitioning key was introduced
# https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
# Let's check version and use new syntax if available
if db.server_version >= (1, 1, 54310):
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \
% (comma_join(self.partition_key, stringify=True),
comma_join(self.order_by, stringify=True))
partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
comma_join(map(str, self.partition_key)),
comma_join(map(str, self.order_by)),
)
if self.primary_key:
partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
if self.sampling_expr:
partition_sql += " SAMPLE BY %s" % self.sampling_expr
partition_sql += " SETTINGS index_granularity=%d" % self.index_granularity
if self.settings:
settings_sql = ", ".join('%s=%s' % (key, value) for key, value in self.settings.items())
partition_sql += ", " + settings_sql
elif not self.date_col:
# Can't import it globally due to circular import
from infi.clickhouse_orm.database import DatabaseException
raise DatabaseException("Custom partitioning is not supported before ClickHouse 1.1.54310. "
from clickhouse_orm.database import DatabaseException
raise DatabaseException(
"Custom partitioning is not supported before ClickHouse 1.1.54310. "
"Please update your server or use date_col syntax."
"https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/")
"https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
)
else:
partition_sql = ''
partition_sql = ""
params = self._build_sql_params(db)
return '%s(%s) %s' % (name, comma_join(params), partition_sql)
return "%s(%s) %s" % (name, comma_join(params), partition_sql)
def _build_sql_params(self, db):
params = []
@ -114,19 +132,37 @@ class MergeTree(Engine):
params.append(self.date_col)
if self.sampling_expr:
params.append(self.sampling_expr)
params.append('(%s)' % comma_join(self.order_by, stringify=True))
params.append("(%s)" % comma_join(map(str(self.order_by))))
params.append(str(self.index_granularity))
return params
class CollapsingMergeTree(MergeTree):
def __init__(self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
primary_key=None):
super(CollapsingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
replica_table_path, replica_name, partition_key, primary_key)
def __init__(
self,
date_col=None,
order_by=(),
sign_col="sign",
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
settings=None,
):
super(CollapsingMergeTree, self).__init__(
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
settings=settings,
)
self.sign_col = sign_col
def _build_sql_params(self, db):
@ -136,29 +172,61 @@ class CollapsingMergeTree(MergeTree):
class SummingMergeTree(MergeTree):
def __init__(self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
primary_key=None):
super(SummingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity, replica_table_path,
replica_name, partition_key, primary_key)
assert type is None or type(summing_cols) in (list, tuple), 'summing_cols must be a list or tuple'
def __init__(
self,
date_col=None,
order_by=(),
summing_cols=None,
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
):
super(SummingMergeTree, self).__init__(
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
)
assert type is None or type(summing_cols) in (list, tuple), "summing_cols must be a list or tuple"
self.summing_cols = summing_cols
def _build_sql_params(self, db):
params = super(SummingMergeTree, self)._build_sql_params(db)
if self.summing_cols:
params.append('(%s)' % comma_join(self.summing_cols))
params.append("(%s)" % comma_join(self.summing_cols))
return params
class ReplacingMergeTree(MergeTree):
def __init__(self, date_col=None, order_by=(), ver_col=None, sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
primary_key=None):
super(ReplacingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
replica_table_path, replica_name, partition_key, primary_key)
def __init__(
self,
date_col=None,
order_by=(),
ver_col=None,
sampling_expr=None,
index_granularity=8192,
replica_table_path=None,
replica_name=None,
partition_key=None,
primary_key=None,
):
super(ReplacingMergeTree, self).__init__(
date_col,
order_by,
sampling_expr,
index_granularity,
replica_table_path,
replica_name,
partition_key,
primary_key,
)
self.ver_col = ver_col
def _build_sql_params(self, db):
@ -175,9 +243,18 @@ class Buffer(Engine):
Read more [here](https://clickhouse.tech/docs/en/engines/table-engines/special/buffer/).
"""
#Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000,
min_bytes=10000000, max_bytes=100000000):
# Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
def __init__(
self,
main_model,
num_layers=16,
min_time=10,
max_time=100,
min_rows=10000,
max_rows=1000000,
min_bytes=10000000,
max_bytes=100000000,
):
self.main_model = main_model
self.num_layers = num_layers
self.min_time = min_time
@ -190,10 +267,16 @@ class Buffer(Engine):
def create_table_sql(self, db):
# Overriden create_table_sql example:
# sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
db.db_name, self.main_model.table_name(), self.num_layers,
self.min_time, self.max_time, self.min_rows,
self.max_rows, self.min_bytes, self.max_bytes
sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
db.db_name,
self.main_model.table_name(),
self.num_layers,
self.min_time,
self.max_time,
self.min_rows,
self.max_rows,
self.min_bytes,
self.max_bytes,
)
return sql
@ -224,6 +307,7 @@ class Distributed(Engine):
See full documentation here
https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
"""
def __init__(self, cluster, table=None, sharding_key=None):
"""
- `cluster`: what cluster to access data from
@ -252,12 +336,11 @@ class Distributed(Engine):
def create_table_sql(self, db):
name = self.__class__.__name__
params = self._build_sql_params(db)
return '%s(%s)' % (name, ', '.join(params))
return "%s(%s)" % (name, ", ".join(params))
def _build_sql_params(self, db):
if self.table_name is None:
raise ValueError("Cannot create {} engine: specify an underlying table".format(
self.__class__.__name__))
raise ValueError("Cannot create {} engine: specify an underlying table".format(self.__class__.__name__))
params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
if self.sharding_key:

View File

@ -1,23 +1,25 @@
from __future__ import unicode_literals
import datetime
import iso8601
import pytz
from calendar import timegm
from decimal import Decimal, localcontext
from uuid import UUID
from logging import getLogger
from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .funcs import F, FunctionOperatorsMixin
from ipaddress import IPv4Address, IPv6Address
from logging import getLogger
from uuid import UUID
logger = getLogger('clickhouse_orm')
import iso8601
import pytz
from pytz import BaseTzInfo
from .funcs import F, FunctionOperatorsMixin
from .utils import comma_join, escape, get_subclass_names, parse_array, string_or_func
logger = getLogger("clickhouse_orm")
class Field(FunctionOperatorsMixin):
'''
"""
Abstract base class for all field types.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered
@ -25,15 +27,19 @@ class Field(FunctionOperatorsMixin):
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\
"Alias parameter must be a string or function object, if given"
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\
"Materialized parameter must be a string or function object, if given"
assert [default, alias, materialized].count(
None
) >= 2, "Only one of default, alias and materialized parameters can be given"
assert (
alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
), "Alias parameter must be a string or function object, if given"
assert (
materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
), "Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given"
assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
if alias:
assert codec is None, "Codec cannot be used for alias fields"
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
@ -47,49 +53,51 @@ class Field(FunctionOperatorsMixin):
return self.name
def __repr__(self):
return '<%s>' % self.__class__.__name__
return "<%s>" % self.__class__.__name__
def to_python(self, value, timezone_in_use):
'''
"""
Converts the input value into the expected Python data type, raising ValueError if the
data can't be converted. Returns the converted value. Subclasses should override this.
The timezone_in_use parameter should be consulted when parsing datetime fields.
'''
"""
return value # pragma: no cover
def validate(self, value):
'''
"""
Called after to_python to validate that the value is suitable for the field's database type.
Subclasses should override this.
'''
"""
pass
def _range_check(self, value, min_value, max_value):
'''
"""
Utility method to check that the given value is between min_value and max_value.
'''
"""
if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
raise ValueError(
"%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
)
def to_db_string(self, value, quote=True):
'''
"""
Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes.
'''
"""
return escape(value, quote)
def get_sql(self, with_default_expression=True, db=None):
'''
"""
Returns an SQL expression describing the field (e.g. for CREATE TABLE).
- `with_default_expression`: If True, adds default value to sql.
It doesn't affect fields with alias and materialized values.
- `db`: Database, used for checking supported features.
'''
"""
sql = self.db_type
args = self.get_db_type_args()
if args:
sql += '(%s)' % comma_join(args)
sql += "(%s)" % comma_join(args)
if with_default_expression:
sql += self._extra_params(db)
return sql
@ -99,18 +107,18 @@ class Field(FunctionOperatorsMixin):
return []
def _extra_params(self, db):
sql = ''
sql = ""
if self.alias:
sql += ' ALIAS %s' % string_or_func(self.alias)
sql += " ALIAS %s" % string_or_func(self.alias)
elif self.materialized:
sql += ' MATERIALIZED %s' % string_or_func(self.materialized)
sql += " MATERIALIZED %s" % string_or_func(self.materialized)
elif isinstance(self.default, F):
sql += ' DEFAULT %s' % self.default.to_sql()
sql += " DEFAULT %s" % self.default.to_sql()
elif self.default:
default = self.to_db_string(self.default)
sql += ' DEFAULT %s' % default
sql += " DEFAULT %s" % default
if self.codec and db and db.has_codec_support:
sql += ' CODEC(%s)' % self.codec
sql += " CODEC(%s)" % self.codec
return sql
def isinstance(self, types):
@ -124,43 +132,42 @@ class Field(FunctionOperatorsMixin):
"""
if isinstance(self, types):
return True
inner_field = getattr(self, 'inner_field', None)
inner_field = getattr(self, "inner_field", None)
while inner_field:
if isinstance(inner_field, types):
return True
inner_field = getattr(inner_field, 'inner_field', None)
inner_field = getattr(inner_field, "inner_field", None)
return False
class StringField(Field):
class_default = ''
db_type = 'String'
class_default = ""
db_type = "String"
def to_python(self, value, timezone_in_use):
if isinstance(value, str):
return value
if isinstance(value, bytes):
return value.decode('UTF-8')
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
return value.decode("utf-8")
raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
self._length = length
self.db_type = 'FixedString(%d)' % length
self.db_type = "FixedString(%d)" % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly)
def to_python(self, value, timezone_in_use):
value = super(FixedStringField, self).to_python(value, timezone_in_use)
return value.rstrip('\0')
return value.rstrip("\0")
def validate(self, value):
if isinstance(value, str):
value = value.encode('UTF-8')
value = value.encode("utf-8")
if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length))
raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
class DateField(Field):
@ -168,7 +175,7 @@ class DateField(Field):
min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31)
class_default = min_value
db_type = 'Date'
db_type = "Date"
def to_python(self, value, timezone_in_use):
if isinstance(value, datetime.datetime):
@ -178,10 +185,10 @@ class DateField(Field):
if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, str):
if value == '0000-00-00':
if value == "0000-00-00":
return DateField.min_value
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value)
@ -193,10 +200,9 @@ class DateField(Field):
class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime'
db_type = "DateTime"
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None):
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
super().__init__(default, alias, materialized, readonly, codec)
# assert not timezone, 'Temporarily field timezone is not supported'
if timezone:
@ -217,7 +223,7 @@ class DateTimeField(Field):
if isinstance(value, int):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, str):
if value == '0000-00-00 00:00:00':
if value == "0000-00-00 00:00:00":
return self.class_default
if len(value) == 10:
try:
@ -235,19 +241,20 @@ class DateTimeField(Field):
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
dt = timezone_in_use.localize(dt)
return dt
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
return escape('%010d' % timegm(value.utctimetuple()), quote)
return escape("%010d" % timegm(value.utctimetuple()), quote)
class DateTime64Field(DateTimeField):
db_type = 'DateTime64'
db_type = "DateTime64"
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None,
timezone=None, precision=6):
def __init__(
self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
):
super().__init__(default, alias, materialized, readonly, codec, timezone)
assert precision is None or isinstance(precision, int), 'Precision must be int type'
assert precision is None or isinstance(precision, int), "Precision must be int type"
self.precision = precision
def get_db_type_args(self):
@ -263,11 +270,10 @@ class DateTime64Field(DateTimeField):
Returns string in 0000000000.000000 format, where remainder digits count is equal to precision
"""
return escape(
'{timestamp:0{width}.{precision}f}'.format(
timestamp=value.timestamp(),
width=11 + self.precision,
precision=self.precision),
quote
"{timestamp:0{width}.{precision}f}".format(
timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
),
quote,
)
def to_python(self, value, timezone_in_use):
@ -277,8 +283,8 @@ class DateTime64Field(DateTimeField):
if isinstance(value, (int, float)):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, str):
left_part = value.split('.')[0]
if left_part == '0000-00-00 00:00:00':
left_part = value.split(".")[0]
if left_part == "0000-00-00 00:00:00":
return self.class_default
if len(left_part) == 10:
try:
@ -290,14 +296,15 @@ class DateTime64Field(DateTimeField):
class BaseIntField(Field):
'''
"""
Abstract base class for all integer-type fields.
'''
"""
def to_python(self, value, timezone_in_use):
try:
return int(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
@ -311,69 +318,69 @@ class BaseIntField(Field):
class UInt8Field(BaseIntField):
min_value = 0
max_value = 2**8 - 1
db_type = 'UInt8'
max_value = 2 ** 8 - 1
db_type = "UInt8"
class UInt16Field(BaseIntField):
min_value = 0
max_value = 2**16 - 1
db_type = 'UInt16'
max_value = 2 ** 16 - 1
db_type = "UInt16"
class UInt32Field(BaseIntField):
min_value = 0
max_value = 2**32 - 1
db_type = 'UInt32'
max_value = 2 ** 32 - 1
db_type = "UInt32"
class UInt64Field(BaseIntField):
min_value = 0
max_value = 2**64 - 1
db_type = 'UInt64'
max_value = 2 ** 64 - 1
db_type = "UInt64"
class Int8Field(BaseIntField):
min_value = -2**7
max_value = 2**7 - 1
db_type = 'Int8'
min_value = -(2 ** 7)
max_value = 2 ** 7 - 1
db_type = "Int8"
class Int16Field(BaseIntField):
min_value = -2**15
max_value = 2**15 - 1
db_type = 'Int16'
min_value = -(2 ** 15)
max_value = 2 ** 15 - 1
db_type = "Int16"
class Int32Field(BaseIntField):
min_value = -2**31
max_value = 2**31 - 1
db_type = 'Int32'
min_value = -(2 ** 31)
max_value = 2 ** 31 - 1
db_type = "Int32"
class Int64Field(BaseIntField):
min_value = -2**63
max_value = 2**63 - 1
db_type = 'Int64'
min_value = -(2 ** 63)
max_value = 2 ** 63 - 1
db_type = "Int64"
class BaseFloatField(Field):
'''
"""
Abstract base class for all float-type fields.
'''
"""
def to_python(self, value, timezone_in_use):
try:
return float(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
@ -383,25 +390,25 @@ class BaseFloatField(Field):
class Float32Field(BaseFloatField):
db_type = 'Float32'
db_type = "Float32"
class Float64Field(BaseFloatField):
db_type = 'Float64'
db_type = "Float64"
class DecimalField(Field):
'''
"""
Base class for all decimal fields. Can also be used directly.
'''
"""
def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
assert 1 <= precision <= 38, 'Precision must be between 1 and 38'
assert 0 <= scale <= precision, 'Scale must be between 0 and the given precision'
assert 1 <= precision <= 38, "Precision must be between 1 and 38"
assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
self.precision = precision
self.scale = scale
self.db_type = 'Decimal(%d,%d)' % (self.precision, self.scale)
self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
with localcontext() as ctx:
ctx.prec = 38
self.exp = Decimal(10) ** -self.scale # for rounding to the required scale
@ -413,10 +420,10 @@ class DecimalField(Field):
if not isinstance(value, Decimal):
try:
value = Decimal(value)
except:
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
except Exception:
raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
if not value.is_finite():
raise ValueError('Non-finite value for %s - %r' % (self.__class__.__name__, value))
raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
return self._round(value)
def to_db_string(self, value, quote=True):
@ -432,30 +439,27 @@ class DecimalField(Field):
class Decimal32Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
self.db_type = 'Decimal32(%d)' % scale
self.db_type = "Decimal32(%d)" % scale
class Decimal64Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
self.db_type = 'Decimal64(%d)' % scale
self.db_type = "Decimal64(%d)" % scale
class Decimal128Field(DecimalField):
def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
self.db_type = 'Decimal128(%d)' % scale
self.db_type = "Decimal128(%d)" % scale
class BaseEnumField(Field):
'''
"""
Abstract base class for all enum-type fields.
'''
"""
def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
self.enum_cls = enum_cls
@ -473,7 +477,7 @@ class BaseEnumField(Field):
except Exception:
return self.enum_cls(value)
if isinstance(value, bytes):
decoded = value.decode('UTF-8')
decoded = value.decode("utf-8")
try:
return self.enum_cls[decoded]
except Exception:
@ -482,38 +486,39 @@ class BaseEnumField(Field):
return self.enum_cls(value)
except (KeyError, ValueError):
pass
raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))
def to_db_string(self, value, quote=True):
return escape(value.name, quote)
def get_db_type_args(self):
return ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]
@classmethod
def create_ad_hoc_field(cls, db_type):
'''
"""
Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
this method returns a matching enum field.
'''
"""
import re
from enum import Enum
members = {}
for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
members[match.group(1)] = int(match.group(2))
enum_cls = Enum('AdHocEnum', members)
field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
enum_cls = Enum("AdHocEnum", members)
field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
return field_class(enum_cls)
class Enum8Field(BaseEnumField):
db_type = 'Enum8'
db_type = "Enum8"
class Enum16Field(BaseEnumField):
db_type = 'Enum16'
db_type = "Enum16"
class ArrayField(Field):
@ -530,9 +535,9 @@ class ArrayField(Field):
if isinstance(value, str):
value = parse_array(value)
elif isinstance(value, bytes):
value = parse_array(value.decode('UTF-8'))
value = parse_array(value.decode("utf-8"))
elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
def validate(self, value):
@ -541,19 +546,19 @@ class ArrayField(Field):
def to_db_string(self, value, quote=True):
array = [self.inner_field.to_db_string(v, quote=True) for v in value]
return '[' + comma_join(array) + ']'
return "[" + comma_join(array) + "]"
def get_sql(self, with_default_expression=True, db=None):
sql = 'Array(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression and self.codec and db and db.has_codec_support:
sql+= ' CODEC(%s)' % self.codec
sql += " CODEC(%s)" % self.codec
return sql
class UUIDField(Field):
class_default = UUID(int=0)
db_type = 'UUID'
db_type = "UUID"
def to_python(self, value, timezone_in_use):
if isinstance(value, UUID):
@ -567,7 +572,7 @@ class UUIDField(Field):
elif isinstance(value, tuple):
return UUID(fields=value)
else:
raise ValueError('Invalid value for UUIDField: %r' % value)
raise ValueError("Invalid value for UUIDField: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
@ -576,7 +581,7 @@ class UUIDField(Field):
class IPv4Field(Field):
class_default = 0
db_type = 'IPv4'
db_type = "IPv4"
def to_python(self, value, timezone_in_use):
if isinstance(value, IPv4Address):
@ -584,7 +589,7 @@ class IPv4Field(Field):
elif isinstance(value, (bytes, str, int)):
return IPv4Address(value)
else:
raise ValueError('Invalid value for IPv4Address: %r' % value)
raise ValueError("Invalid value for IPv4Address: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
@ -593,7 +598,7 @@ class IPv4Field(Field):
class IPv6Field(Field):
class_default = 0
db_type = 'IPv6'
db_type = "IPv6"
def to_python(self, value, timezone_in_use):
if isinstance(value, IPv6Address):
@ -601,7 +606,7 @@ class IPv6Field(Field):
elif isinstance(value, (bytes, str, int)):
return IPv6Address(value)
else:
raise ValueError('Invalid value for IPv6Address: %r' % value)
raise ValueError("Invalid value for IPv6Address: %r" % value)
def to_db_string(self, value, quote=True):
return escape(str(value), quote)
@ -611,9 +616,10 @@ class NullableField(Field):
class_default = None
def __init__(self, inner_field, default=None, alias=None, materialized=None,
extra_null_values=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
assert isinstance(
inner_field, Field
), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
self.inner_field = inner_field
self._null_values = [None]
if extra_null_values:
@ -621,7 +627,7 @@ class NullableField(Field):
super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)
def to_python(self, value, timezone_in_use):
if value == '\\N' or value in self._null_values:
if value == "\\N" or value in self._null_values:
return None
return self.inner_field.to_python(value, timezone_in_use)
@ -630,22 +636,27 @@ class NullableField(Field):
def to_db_string(self, value, quote=True):
if value in self._null_values:
return '\\N'
return "\\N"
return self.inner_field.to_db_string(value, quote=quote)
def get_sql(self, with_default_expression=True, db=None):
sql = 'Nullable(%s)' % self.inner_field.get_sql(with_default_expression=False, db=db)
sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
if with_default_expression:
sql += self._extra_params(db)
return sql
class LowCardinalityField(Field):
def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert isinstance(inner_field, Field), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
assert not isinstance(inner_field, LowCardinalityField), "LowCardinality inner fields are not supported by the ORM"
assert not isinstance(inner_field, ArrayField), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
assert isinstance(
inner_field, Field
), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
assert not isinstance(
inner_field, LowCardinalityField
), "LowCardinality inner fields are not supported by the ORM"
assert not isinstance(
inner_field, ArrayField
), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
self.inner_field = inner_field
self.class_default = self.inner_field.class_default
super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)
@ -661,10 +672,14 @@ class LowCardinalityField(Field):
def get_sql(self, with_default_expression=True, db=None):
if db and db.has_low_cardinality_support:
sql = 'LowCardinality(%s)' % self.inner_field.get_sql(with_default_expression=False)
sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
else:
sql = self.inner_field.get_sql(with_default_expression=False)
logger.warning('LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback'.format(self.inner_field.__class__.__name__))
logger.warning(
"LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback".format(
self.inner_field.__class__.__name__
)
)
if with_default_expression:
sql += self._extra_params(db)
return sql

File diff suppressed because it is too large Load Diff

View File

@ -1,70 +1,71 @@
from .models import Model, BufferModel
from .fields import DateField, StringField
from .engines import MergeTree
from .utils import escape, get_subclass_names
import logging
logger = logging.getLogger('migrations')
from .engines import MergeTree
from .fields import DateField, StringField
from .models import BufferModel, Model
from .utils import get_subclass_names
logger = logging.getLogger("migrations")
class Operation():
'''
class Operation:
"""
Base class for migration operations.
'''
"""
def apply(self, database):
raise NotImplementedError() # pragma: no cover
class ModelOperation(Operation):
'''
"""
Base class for migration operations that work on a specific model.
'''
"""
def __init__(self, model_class):
'''
"""
Initializer.
'''
"""
self.model_class = model_class
self.table_name = model_class.table_name()
def _alter_table(self, database, cmd):
'''
"""
Utility for running ALTER TABLE commands.
'''
"""
cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
logger.debug(cmd)
database.raw(cmd)
class CreateTable(ModelOperation):
'''
"""
A migration operation that creates a table for a given model class.
'''
"""
def apply(self, database):
logger.info(' Create table %s', self.table_name)
logger.info(" Create table %s", self.table_name)
if issubclass(self.model_class, BufferModel):
database.create_table(self.model_class.engine.main_model)
database.create_table(self.model_class)
class AlterTable(ModelOperation):
'''
"""
A migration operation that compares the table of a given model class to
the model's fields, and alters the table to match the model. The operation can:
- add new columns
- drop obsolete columns
- modify column types
Default values are not altered by this operation.
'''
"""
def _get_table_fields(self, database):
query = "DESC `%s`.`%s`" % (database.db_name, self.table_name)
return [(row.name, row.type) for row in database.select(query)]
def apply(self, database):
logger.info(' Alter table %s', self.table_name)
logger.info(" Alter table %s", self.table_name)
# Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
# ADD COLUMN ... AFTER doesn't affect it
@ -73,8 +74,8 @@ class AlterTable(ModelOperation):
# Identify fields that were deleted from the model
deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
for name in deleted_fields:
logger.info(' Drop column %s', name)
self._alter_table(database, 'DROP COLUMN %s' % name)
logger.info(" Drop column %s", name)
self._alter_table(database, "DROP COLUMN %s" % name)
del table_fields[name]
# Identify fields that were added to the model
@ -82,11 +83,11 @@ class AlterTable(ModelOperation):
for name, field in self.model_class.fields().items():
is_regular_field = not (field.materialized or field.alias)
if name not in table_fields:
logger.info(' Add column %s', name)
assert prev_name, 'Cannot add a column to the beginning of the table'
cmd = 'ADD COLUMN %s %s' % (name, field.get_sql(db=database))
logger.info(" Add column %s", name)
assert prev_name, "Cannot add a column to the beginning of the table"
cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
if is_regular_field:
cmd += ' AFTER %s' % prev_name
cmd += " AFTER %s" % prev_name
self._alter_table(database, cmd)
if is_regular_field:
@ -97,25 +98,28 @@ class AlterTable(ModelOperation):
# Identify fields whose type was changed
# The order of class attributes can be changed any time, so we can't count on it
# Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save
# attribute position. Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47
model_fields = {name: field.get_sql(with_default_expression=False, db=database)
for name, field in self.model_class.fields().items()}
# attribute position. Watch https://github.com/Infinidat/clickhouse_orm/issues/47
model_fields = {
name: field.get_sql(with_default_expression=False, db=database)
for name, field in self.model_class.fields().items()
}
for field_name, field_sql in self._get_table_fields(database):
# All fields must have been created and dropped by this moment
assert field_name in model_fields, 'Model fields and table columns in disagreement'
assert field_name in model_fields, "Model fields and table columns in disagreement"
if field_sql != model_fields[field_name]:
logger.info(' Change type of column %s from %s to %s', field_name, field_sql,
model_fields[field_name])
self._alter_table(database, 'MODIFY COLUMN %s %s' % (field_name, model_fields[field_name]))
logger.info(
" Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
)
self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))
class AlterTableWithBuffer(ModelOperation):
'''
"""
A migration operation for altering a buffer table and its underlying on-disk table.
The buffer table is dropped, the on-disk table is altered, and then the buffer table
is re-created.
'''
"""
def apply(self, database):
if issubclass(self.model_class, BufferModel):
@ -127,149 +131,152 @@ class AlterTableWithBuffer(ModelOperation):
class DropTable(ModelOperation):
'''
"""
A migration operation that drops the table of a given model class.
'''
"""
def apply(self, database):
logger.info(' Drop table %s', self.table_name)
logger.info(" Drop table %s", self.table_name)
database.drop_table(self.model_class)
class AlterConstraints(ModelOperation):
'''
"""
A migration operation that adds new constraints from the model to the database
table, and drops obsolete ones. Constraints are identified by their names, so
a change in an existing constraint will not be detected unless its name was changed too.
ClickHouse does not check that the constraints hold for existing data in the table.
'''
"""
def apply(self, database):
logger.info(' Alter constraints for %s', self.table_name)
logger.info(" Alter constraints for %s", self.table_name)
existing = self._get_constraint_names(database)
# Go over constraints in the model
for constraint in self.model_class._constraints.values():
# Check if it's a new constraint
if constraint.name not in existing:
logger.info(' Add constraint %s', constraint.name)
self._alter_table(database, 'ADD %s' % constraint.create_table_sql())
logger.info(" Add constraint %s", constraint.name)
self._alter_table(database, "ADD %s" % constraint.create_table_sql())
else:
existing.remove(constraint.name)
# Remaining constraints in `existing` are obsolete
for name in existing:
logger.info(' Drop constraint %s', name)
self._alter_table(database, 'DROP CONSTRAINT `%s`' % name)
logger.info(" Drop constraint %s", name)
self._alter_table(database, "DROP CONSTRAINT `%s`" % name)
def _get_constraint_names(self, database):
'''
"""
Returns a set containing the names of existing constraints in the table.
'''
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s', table_def)
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
return set(matches)
class AlterIndexes(ModelOperation):
'''
"""
A migration operation that adds new indexes from the model to the database
table, and drops obsolete ones. Indexes are identified by their names, so
a change in an existing index will not be detected unless its name was changed too.
'''
"""
def __init__(self, model_class, reindex=False):
'''
"""
Initializer.
By default ClickHouse does not build indexes over existing data, only for
new data. Passing `reindex=True` will run `OPTIMIZE TABLE` in order to build
the indexes over the existing data.
'''
"""
super().__init__(model_class)
self.reindex = reindex
def apply(self, database):
logger.info(' Alter indexes for %s', self.table_name)
logger.info(" Alter indexes for %s", self.table_name)
existing = self._get_index_names(database)
logger.info(existing)
# Go over indexes in the model
for index in self.model_class._indexes.values():
# Check if it's a new index
if index.name not in existing:
logger.info(' Add index %s', index.name)
self._alter_table(database, 'ADD %s' % index.create_table_sql())
logger.info(" Add index %s", index.name)
self._alter_table(database, "ADD %s" % index.create_table_sql())
else:
existing.remove(index.name)
# Remaining indexes in `existing` are obsolete
for name in existing:
logger.info(' Drop index %s', name)
self._alter_table(database, 'DROP INDEX `%s`' % name)
logger.info(" Drop index %s", name)
self._alter_table(database, "DROP INDEX `%s`" % name)
# Reindex
if self.reindex:
logger.info(' Build indexes on table')
database.raw('OPTIMIZE TABLE $db.`%s` FINAL' % self.table_name)
logger.info(" Build indexes on table")
database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)
def _get_index_names(self, database):
'''
"""
Returns a set containing the names of existing indexes in the table.
'''
"""
import re
table_def = database.raw('SHOW CREATE TABLE $db.`%s`' % self.table_name)
matches = re.findall(r'\sINDEX\s+`?(.+?)`?\s+', table_def)
table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
return set(matches)
class RunPython(Operation):
'''
"""
A migration operation that executes a Python function.
'''
"""
def __init__(self, func):
'''
"""
Initializer. The given Python function will be called with a single
argument - the Database instance to apply the migration to.
'''
"""
assert callable(func), "'func' argument must be function"
self._func = func
def apply(self, database):
logger.info(' Executing python operation %s', self._func.__name__)
logger.info(" Executing python operation %s", self._func.__name__)
self._func(database)
class RunSQL(Operation):
'''
"""
A migration operation that executes arbitrary SQL statements.
'''
"""
def __init__(self, sql):
'''
"""
Initializer. The given sql argument must be a valid SQL statement or
list of statements.
'''
"""
if isinstance(sql, str):
sql = [sql]
assert isinstance(sql, list), "'sql' argument must be string or list of strings"
self._sql = sql
def apply(self, database):
logger.info(' Executing raw SQL operations')
logger.info(" Executing raw SQL operations")
for item in self._sql:
database.raw(item)
class MigrationHistory(Model):
'''
"""
A model for storing which migrations were already applied to the containing database.
'''
"""
package_name = StringField()
module_name = StringField()
applied = DateField()
engine = MergeTree('applied', ('package_name', 'module_name'))
engine = MergeTree("applied", ("package_name", "module_name"))
@classmethod
def table_name(cls):
return 'infi_clickhouse_orm_migrations'
return "infi_clickhouse_orm_migrations"
# Expose only relevant classes in import *

View File

@ -1,4 +1,3 @@
from __future__ import unicode_literals
import sys
from collections import OrderedDict
from itertools import chain
@ -6,84 +5,83 @@ from logging import getLogger
import pytz
from .engines import Distributed, Merge
from .fields import Field, StringField
from .utils import parse_tsv, NO_VALUE, get_subclass_names, arg_to_sql, unescape
from .query import QuerySet
from .funcs import F
from .engines import Merge, Distributed
logger = getLogger('clickhouse_orm')
from .query import QuerySet
from .utils import NO_VALUE, arg_to_sql, get_subclass_names, parse_tsv
logger = getLogger("clickhouse_orm")
class Constraint:
'''
"""
Defines a model constraint.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
def __init__(self, expr):
'''
"""
Initializer. Expects an expression that ClickHouse will verify when inserting data.
'''
"""
self.expr = expr
def create_table_sql(self):
'''
"""
Returns the SQL statement for defining this constraint during table creation.
'''
return 'CONSTRAINT `%s` CHECK %s' % (self.name, arg_to_sql(self.expr))
"""
return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))
class Index:
'''
"""
Defines a data-skipping index.
'''
"""
name = None # this is set by the parent model
parent = None # this is set by the parent model
def __init__(self, expr, type, granularity):
'''
"""
Initializer.
- `expr` - a column, expression, or tuple of columns and expressions to index.
- `type` - the index type. Use one of the following methods to specify the type:
`Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`.
- `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine).
'''
"""
self.expr = expr
self.type = type
self.granularity = granularity
def create_table_sql(self):
'''
"""
Returns the SQL statement for defining this index during table creation.
'''
return 'INDEX `%s` %s TYPE %s GRANULARITY %d' % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
"""
return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (self.name, arg_to_sql(self.expr), self.type, self.granularity)
@staticmethod
def minmax():
'''
"""
An index that stores extremes of the specified expression (if the expression is tuple, then it stores
extremes for each element of tuple). The stored info is used for skipping blocks of data like the primary key.
'''
return 'minmax'
"""
return "minmax"
@staticmethod
def set(max_rows):
'''
"""
An index that stores unique values of the specified expression (no more than max_rows rows,
or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
on a block of data.
'''
return 'set(%d)' % max_rows
"""
return "set(%d)" % max_rows
@staticmethod
def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
'''
"""
An index that stores a Bloom filter containing all ngrams from a block of data.
Works only with strings. Can be used for optimization of equals, like and in expressions.
@ -92,12 +90,12 @@ class Index:
for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
'''
return 'ngrambf_v1(%d, %d, %d, %d)' % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
"""
return "ngrambf_v1(%d, %d, %d, %d)" % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod
def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
'''
"""
An index that stores a Bloom filter containing string tokens. Tokens are sequences
separated by non-alphanumeric characters.
@ -105,28 +103,28 @@ class Index:
for example 256 or 512, because it can be compressed well).
- `number_of_hash_functions` The number of hash functions used in the Bloom filter.
- `random_seed` The seed for Bloom filter hash functions.
'''
return 'tokenbf_v1(%d, %d, %d)' % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
"""
return "tokenbf_v1(%d, %d, %d)" % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)
@staticmethod
def bloom_filter(false_positive=0.025):
'''
"""
An index that stores a Bloom filter containing values of the index expression.
- `false_positive` - the probability (between 0 and 1) of receiving a false positive
response from the filter
'''
return 'bloom_filter(%f)' % false_positive
"""
return "bloom_filter(%f)" % false_positive
class ModelBase(type):
'''
"""
A metaclass for ORM models. It adds the _fields list to model classes.
'''
"""
ad_hoc_model_cache = {}
def __new__(cls, name, bases, attrs):
def __new__(metacls, name, bases, attrs):
# Collect fields, constraints and indexes from parent classes
fields = {}
@ -170,90 +168,88 @@ class ModelBase(type):
_indexes=indexes,
_writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
_defaults=defaults,
_has_funcs_as_defaults=has_funcs_as_defaults
_has_funcs_as_defaults=has_funcs_as_defaults,
)
model = super(ModelBase, cls).__new__(cls, str(name), bases, attrs)
model = super(ModelBase, metacls).__new__(metacls, str(name), bases, attrs)
# Let each field, constraint and index know its parent and its own name
for n, obj in chain(fields, constraints.items(), indexes.items()):
setattr(obj, 'parent', model)
setattr(obj, 'name', n)
obj.parent = model
obj.name = n
return model
@classmethod
def create_ad_hoc_model(cls, fields, model_name='AdHocModel'):
def create_ad_hoc_model(metacls, fields, model_name="AdHocModel"):
# fields is a list of tuples (name, db_type)
# Check if model exists in cache
fields = list(fields)
cache_key = model_name + ' ' + str(fields)
if cache_key in cls.ad_hoc_model_cache:
return cls.ad_hoc_model_cache[cache_key]
cache_key = model_name + " " + str(fields)
if cache_key in metacls.ad_hoc_model_cache:
return metacls.ad_hoc_model_cache[cache_key]
# Create an ad hoc model class
attrs = {}
for name, db_type in fields:
attrs[name] = cls.create_ad_hoc_field(db_type)
model_class = cls.__new__(cls, model_name, (Model,), attrs)
attrs[name] = metacls.create_ad_hoc_field(db_type)
model_class = metacls.__new__(metacls, model_name, (Model,), attrs)
# Add the model class to the cache
cls.ad_hoc_model_cache[cache_key] = model_class
metacls.ad_hoc_model_cache[cache_key] = model_class
return model_class
@classmethod
def create_ad_hoc_field(cls, db_type):
import infi.clickhouse_orm.fields as orm_fields
def create_ad_hoc_field(metacls, db_type):
import clickhouse_orm.fields as orm_fields
# Enums
if db_type.startswith('Enum'):
if db_type.startswith("Enum"):
return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
# DateTime with timezone
if db_type.startswith('DateTime('):
if db_type.startswith("DateTime("):
timezone = db_type[9:-1]
return orm_fields.DateTimeField(
timezone=timezone[1:-1] if timezone else None
)
return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
# DateTime64
if db_type.startswith('DateTime64('):
precision, *timezone = [s.strip() for s in db_type[11:-1].split(',')]
if db_type.startswith("DateTime64("):
precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
return orm_fields.DateTime64Field(
precision=int(precision),
timezone=timezone[0][1:-1] if timezone else None
precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
)
# Arrays
if db_type.startswith('Array'):
inner_field = cls.create_ad_hoc_field(db_type[6 : -1])
if db_type.startswith("Array"):
inner_field = metacls.create_ad_hoc_field(db_type[6:-1])
return orm_fields.ArrayField(inner_field)
# Tuples (poor man's version - convert to array)
if db_type.startswith('Tuple'):
types = [s.strip() for s in db_type[6 : -1].split(',')]
assert len(set(types)) == 1, 'No support for mixed types in tuples - ' + db_type
inner_field = cls.create_ad_hoc_field(types[0])
if db_type.startswith("Tuple"):
types = [s.strip() for s in db_type[6:-1].split(",")]
assert len(set(types)) == 1, "No support for mixed types in tuples - " + db_type
inner_field = metacls.create_ad_hoc_field(types[0])
return orm_fields.ArrayField(inner_field)
# FixedString
if db_type.startswith('FixedString'):
length = int(db_type[12 : -1])
if db_type.startswith("FixedString"):
length = int(db_type[12:-1])
return orm_fields.FixedStringField(length)
# Decimal / Decimal32 / Decimal64 / Decimal128
if db_type.startswith('Decimal'):
p = db_type.index('(')
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(',')]
field_class = getattr(orm_fields, db_type[:p] + 'Field')
if db_type.startswith("Decimal"):
p = db_type.index("(")
args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
field_class = getattr(orm_fields, db_type[:p] + "Field")
return field_class(*args)
# Nullable
if db_type.startswith('Nullable'):
inner_field = cls.create_ad_hoc_field(db_type[9 : -1])
if db_type.startswith("Nullable"):
inner_field = metacls.create_ad_hoc_field(db_type[9:-1])
return orm_fields.NullableField(inner_field)
# LowCardinality
if db_type.startswith('LowCardinality'):
inner_field = cls.create_ad_hoc_field(db_type[15 : -1])
if db_type.startswith("LowCardinality"):
inner_field = metacls.create_ad_hoc_field(db_type[15:-1])
return orm_fields.LowCardinalityField(inner_field)
# Simple fields
name = db_type + 'Field'
name = db_type + "Field"
if not hasattr(orm_fields, name):
raise NotImplementedError('No field class for %s' % db_type)
raise NotImplementedError("No field class for %s" % db_type)
return getattr(orm_fields, name)()
class Model(metaclass=ModelBase):
'''
"""
A base class for ORM models. Each model class represent a ClickHouse table. For example:
class CPUStats(Model):
@ -261,7 +257,7 @@ class Model(metaclass=ModelBase):
cpu_id = UInt16Field()
cpu_percent = Float32Field()
engine = Memory()
'''
"""
engine = None
@ -274,12 +270,12 @@ class Model(metaclass=ModelBase):
_database = None
def __init__(self, **kwargs):
'''
"""
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
'''
"""
super(Model, self).__init__()
# Assign default values
self.__dict__.update(self._defaults)
@ -289,13 +285,13 @@ class Model(metaclass=ModelBase):
if field:
setattr(self, name, value)
else:
raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
raise AttributeError("%s does not have a field called %s" % (self.__class__.__name__, name))
def __setattr__(self, name, value):
'''
"""
When setting a field value, converts the value to its Pythonic type and validates it.
This may raise a `ValueError`.
'''
"""
field = self.get_field(name)
if field and (value != NO_VALUE):
try:
@ -308,77 +304,78 @@ class Model(metaclass=ModelBase):
super(Model, self).__setattr__(name, value)
def set_database(self, db):
'''
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
'''
"""
# This can not be imported globally due to circular import
from .database import Database
assert isinstance(db, Database), "database must be database.Database instance"
self._database = db
def get_database(self):
'''
"""
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
'''
"""
return self._database
def get_field(self, name):
'''
"""
Gets a `Field` instance given its name, or `None` if not found.
'''
"""
return self._fields.get(name)
@classmethod
def table_name(cls):
'''
"""
Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use
a different table name.
'''
"""
return cls.__name__.lower()
@classmethod
def has_funcs_as_defaults(cls):
'''
"""
Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances.
'''
"""
return cls._has_funcs_as_defaults
@classmethod
def create_table_sql(cls, db):
'''
"""
Returns the SQL statement for creating a table for this model.
'''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
"""
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
# Fields
items = []
for name, field in cls.fields().items():
items.append(' %s %s' % (name, field.get_sql(db=db)))
items.append(" %s %s" % (name, field.get_sql(db=db)))
# Constraints
for c in cls._constraints.values():
items.append(' %s' % c.create_table_sql())
items.append(" %s" % c.create_table_sql())
# Indexes
for i in cls._indexes.values():
items.append(' %s' % i.create_table_sql())
parts.append(',\n'.join(items))
items.append(" %s" % i.create_table_sql())
parts.append(",\n".join(items))
# Engine
parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts)
parts.append(")")
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return "\n".join(parts)
@classmethod
def drop_table_sql(cls, db):
'''
"""
Returns the SQL command for deleting this model's table.
'''
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name())
"""
return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
@classmethod
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
'''
"""
Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
@ -386,12 +383,12 @@ class Model(metaclass=ModelBase):
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to.
'''
"""
values = iter(parse_tsv(line))
kwargs = {}
for name in field_names:
field = getattr(cls, name)
field_timezone = getattr(field, 'timezone', None) or timezone_in_use
field_timezone = getattr(field, "timezone", None) or timezone_in_use
kwargs[name] = field.to_python(next(values), field_timezone)
obj = cls(**kwargs)
@ -401,45 +398,45 @@ class Model(metaclass=ModelBase):
return obj
def to_tsv(self, include_readonly=True):
'''
"""
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database.
'''
"""
data = self.__dict__
fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
return "\t".join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
def to_tskv(self, include_readonly=True):
'''
"""
Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database.
'''
"""
data = self.__dict__
fields = self.fields(writable=not include_readonly)
parts = []
for name, field in fields.items():
if data[name] != NO_VALUE:
parts.append(name + '=' + field.to_db_string(data[name], quote=False))
return '\t'.join(parts)
parts.append(name + "=" + field.to_db_string(data[name], quote=False))
return "\t".join(parts)
def to_db_string(self):
'''
"""
Returns the instance as a bytestring ready to be inserted into the database.
'''
"""
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
s += '\n'
return s.encode('utf-8')
s += "\n"
return s.encode("utf-8")
def to_dict(self, include_readonly=True, field_names=None):
'''
"""
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional)
'''
"""
fields = self.fields(writable=not include_readonly)
if field_names is not None:
@ -450,56 +447,58 @@ class Model(metaclass=ModelBase):
@classmethod
def objects_in(cls, database):
'''
"""
Returns a `QuerySet` for selecting instances of this model class.
'''
"""
return QuerySet(cls, database)
@classmethod
def fields(cls, writable=False):
'''
"""
Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included.
Callers should not modify the dictionary.
'''
"""
# noinspection PyProtectedMember,PyUnresolvedReferences
return cls._writable_fields if writable else cls._fields
@classmethod
def is_read_only(cls):
'''
"""
Returns true if the model is marked as read only.
'''
"""
return cls._readonly
@classmethod
def is_system_model(cls):
'''
"""
Returns true if the model represents a system table.
'''
"""
return cls._system
class BufferModel(Model):
@classmethod
def create_table_sql(cls, db):
'''
"""
Returns the SQL statement for creating a table for this model.
'''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name,
cls.engine.main_model.table_name())]
"""
parts = [
"CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
% (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
]
engine_str = cls.engine.create_table_sql(db)
parts.append(engine_str)
return ' '.join(parts)
return " ".join(parts)
class MergeModel(Model):
'''
"""
Model for Merge engine
Predefines virtual _table column an controls that rows can't be inserted to this table type
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
'''
"""
readonly = True
# Virtual fields can't be inserted into database
@ -507,19 +506,20 @@ class MergeModel(Model):
@classmethod
def create_table_sql(cls, db):
'''
"""
Returns the SQL statement for creating a table for this model.
'''
"""
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
cols = []
for name, field in cls.fields().items():
if name != '_table':
cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols))
parts.append(')')
parts.append('ENGINE = ' + cls.engine.create_table_sql(db))
return '\n'.join(parts)
if name != "_table":
cols.append(" %s %s" % (name, field.get_sql(db=db)))
parts.append(",\n".join(cols))
parts.append(")")
parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return "\n".join(parts)
# TODO: base class for models that require specific engine
@ -530,10 +530,10 @@ class DistributedModel(Model):
"""
def set_database(self, db):
'''
"""
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
'''
"""
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db)
return res
@ -576,33 +576,37 @@ class DistributedModel(Model):
return
# find out all the superclasses of the Model that store any data
storage_models = [b for b in cls.__bases__ if issubclass(b, Model)
and not issubclass(b, DistributedModel)]
storage_models = [b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)]
if not storage_models:
raise TypeError("When defining Distributed engine without the table_name "
"ensure that your model has a parent model")
raise TypeError(
"When defining Distributed engine without the table_name " "ensure that your model has a parent model"
)
if len(storage_models) > 1:
raise TypeError("When defining Distributed engine without the table_name "
"ensure that your model has exactly one non-distributed superclass")
raise TypeError(
"When defining Distributed engine without the table_name "
"ensure that your model has exactly one non-distributed superclass"
)
# enable correct SQL for engine
cls.engine.table = storage_models[0]
@classmethod
def create_table_sql(cls, db):
'''
"""
Returns the SQL statement for creating a table for this model.
'''
"""
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
cls.fix_engine_table()
parts = [
'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format(
db.db_name, cls.table_name(), cls.engine.table_name),
'ENGINE = ' + cls.engine.create_table_sql(db)]
return '\n'.join(parts)
"CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
db.db_name, cls.table_name(), cls.engine.table_name
),
"ENGINE = " + cls.engine.create_table_sql(db),
]
return "\n".join(parts)
# Expose only relevant classes in import *

View File

@ -1,15 +1,15 @@
from __future__ import unicode_literals
import pytz
from copy import copy, deepcopy
from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql
import pytz
from .engines import CollapsingMergeTree, ReplacingMergeTree
from .utils import Page, arg_to_sql, comma_join, string_or_func
# TODO
# - check that field names are valid
class Operator(object):
"""
Base class for filtering operators.
@ -23,9 +23,10 @@ class Operator(object):
raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True):
from infi.clickhouse_orm.funcs import F
if isinstance(value, F):
if isinstance(value, Cond):
# This is an 'in-database' value, rather than a python one
return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -41,9 +42,9 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value])
if value == "\\N" and self._sql_for_null is not None:
return " ".join([field_name, self._sql_for_null])
return " ".join([field_name, self._sql_operator, value])
class InOperator(Operator):
@ -63,7 +64,7 @@ class InOperator(Operator):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value)
return "%s IN (%s)" % (field_name, value)
class LikeOperator(Operator):
@ -79,12 +80,12 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_')
value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
pattern = self._pattern.format(value)
if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern)
return "%s LIKE '%s'" % (field_name, pattern)
else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern)
return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field_name, pattern)
class IExactOperator(Operator):
@ -95,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value)
return "lowerUTF8(%s) = lowerUTF8(%s)" % (field_name, value)
class NotOperator(Operator):
@ -108,7 +109,7 @@ class NotOperator(Operator):
def to_sql(self, model_cls, field_name, value):
# Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value)
return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
class BetweenOperator(Operator):
@ -126,35 +127,38 @@ class BetweenOperator(Operator):
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1)
return "%s BETWEEN %s AND %s" % (field_name, value0, value1)
if value0 and not value1:
return ' '.join([field_name, '>=', value0])
return " ".join([field_name, ">=", value0])
if value1 and not value0:
return ' '.join([field_name, '<=', value1])
return " ".join([field_name, "<=", value1])
# Define the set of builtin operators
_operators = {}
def register_operator(name, sql):
_operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL'))
register_operator('gt', SimpleOperator('>'))
register_operator('gte', SimpleOperator('>='))
register_operator('lt', SimpleOperator('<'))
register_operator('lte', SimpleOperator('<='))
register_operator('between', BetweenOperator())
register_operator('in', InOperator())
register_operator('not_in', NotOperator(InOperator()))
register_operator('contains', LikeOperator('%{}%'))
register_operator('startswith', LikeOperator('{}%'))
register_operator('endswith', LikeOperator('%{}'))
register_operator('icontains', LikeOperator('%{}%', False))
register_operator('istartswith', LikeOperator('{}%', False))
register_operator('iendswith', LikeOperator('%{}', False))
register_operator('iexact', IExactOperator())
register_operator("eq", SimpleOperator("=", "IS NULL"))
register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
register_operator("gt", SimpleOperator(">"))
register_operator("gte", SimpleOperator(">="))
register_operator("lt", SimpleOperator("<"))
register_operator("lte", SimpleOperator("<="))
register_operator("between", BetweenOperator())
register_operator("in", InOperator())
register_operator("not_in", NotOperator(InOperator()))
register_operator("contains", LikeOperator("%{}%"))
register_operator("startswith", LikeOperator("{}%"))
register_operator("endswith", LikeOperator("%{}"))
register_operator("icontains", LikeOperator("%{}%", False))
register_operator("istartswith", LikeOperator("{}%", False))
register_operator("iendswith", LikeOperator("%{}", False))
register_operator("iexact", IExactOperator())
class Cond(object):
@ -170,19 +174,20 @@ class FieldCond(Cond):
"""
A single query condition made up of Field + Operator + Value.
"""
def __init__(self, field_name, operator, value):
self._field_name = field_name
self._operator = _operators.get(operator)
if self._operator is None:
# The field name contains __ like my__field
self._field_name = field_name + '__' + operator
self._operator = _operators['eq']
self._field_name = field_name + "__" + operator
self._operator = _operators["eq"]
self._value = value
def to_sql(self, model_cls):
return self._operator.to_sql(model_cls, self._field_name, self._value)
def __deepcopy__(self, memodict={}):
def __deepcopy__(self, memo):
res = copy(self)
res._value = deepcopy(self._value)
return res
@ -190,8 +195,8 @@ class FieldCond(Cond):
class Q(object):
AND_MODE = 'AND'
OR_MODE = 'OR'
AND_MODE = "AND"
OR_MODE = "OR"
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
@ -205,29 +210,26 @@ class Q(object):
Checks if there are any conditions in Q object
Returns: Boolean
"""
return not bool(self._conds or self._children)
return not (self._conds or self._children)
@classmethod
def _construct_from(cls, l_child, r_child, mode):
if mode == l_child._mode:
if mode == l_child._mode and not l_child._negate:
q = deepcopy(l_child)
q._children.append(deepcopy(r_child))
elif mode == r_child._mode:
q = deepcopy(r_child)
q._children.append(deepcopy(l_child))
else:
# Different modes
q = Q()
q = cls()
q._children = [l_child, r_child]
q._mode = mode # AND/OR
q._mode = mode
return q
def _build_cond(self, key, value):
if '__' in key:
field_name, operator = key.rsplit('__', 1)
if "__" in key:
field_name, operator = key.rsplit("__", 1)
else:
field_name, operator = key, 'eq'
field_name, operator = key, "eq"
return FieldCond(field_name, operator, value)
def to_sql(self, model_cls):
@ -241,24 +243,30 @@ class Q(object):
if not condition_sql:
# Empty Q() object returns everything
sql = '1'
sql = "1"
elif len(condition_sql) == 1:
# Skip not needed brackets over single condition
sql = condition_sql[0]
else:
# Each condition must be enclosed in brackets, or order of operations may be wrong
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql)
sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
if self._negate:
sql = 'NOT (%s)' % sql
sql = "NOT (%s)" % sql
return sql
def __or__(self, other):
return Q._construct_from(self, other, self.OR_MODE)
if not isinstance(other, Q):
return NotImplemented
return self.__class__._construct_from(self, other, self.OR_MODE)
def __and__(self, other):
return Q._construct_from(self, other, self.AND_MODE)
if not isinstance(other, Q):
return NotImplemented
return self.__class__._construct_from(self, other, self.AND_MODE)
def __invert__(self):
q = copy(self)
@ -268,8 +276,8 @@ class Q(object):
def __bool__(self):
return not self.is_empty
def __deepcopy__(self, memodict={}):
q = Q()
def __deepcopy__(self, memo):
q = self.__class__()
q._conds = [deepcopy(cond) for cond in self._conds]
q._negate = self._negate
q._mode = self._mode
@ -327,17 +335,17 @@ class QuerySet(object):
def __getitem__(self, s):
if isinstance(s, int):
# Single index
assert s >= 0, 'negative indexes are not supported'
assert s >= 0, "negative indexes are not supported"
qs = copy(self)
qs._limits = (s, 1)
return next(iter(qs))
else:
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
assert s.step in (None, 1), "step is not supported in slices"
start = s.start or 0
stop = s.stop or 2**63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
assert start <= stop, 'start of slice cannot be smaller than its end'
stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, "negative indexes are not supported"
assert start <= stop, "start of slice cannot be smaller than its end"
qs = copy(self)
qs._limits = (start, stop - start)
return qs
@ -353,7 +361,7 @@ class QuerySet(object):
offset_limit = (0, offset_limit)
offset = offset_limit[0]
limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
assert offset >= 0 and limit >= 0, "negative limits are not supported"
qs = copy(self)
qs._limit_by = (offset, limit)
qs._limit_by_fields = fields_or_expr
@ -363,44 +371,44 @@ class QuerySet(object):
"""
Returns the selected fields or expressions as a SQL string.
"""
fields = '*'
fields = "*"
if self._fields:
fields = comma_join('`%s`' % field for field in self._fields)
fields = comma_join("`%s`" % field for field in self._fields)
return fields
def as_sql(self):
"""
Returns the whole query as a SQL string.
"""
distinct = 'DISTINCT ' if self._distinct else ''
final = ' FINAL' if self._final else ''
table_name = '`%s`' % self._model_cls.table_name()
distinct = "DISTINCT " if self._distinct else ""
final = " FINAL" if self._final else ""
table_name = "`%s`" % self._model_cls.table_name()
if self._model_cls.is_system_model():
table_name = '`system`.' + table_name
table_name = "`system`." + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = u'SELECT %s%s\nFROM %s%s' % params
sql = "SELECT %s%s\nFROM %s%s" % params
if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True)
sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
if self._where_q and not self._where_q.is_empty:
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False)
sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields)
sql += "\nGROUP BY %s" % comma_join("`%s`" % field for field in self._grouping_fields)
if self._grouping_with_totals:
sql += ' WITH TOTALS'
sql += " WITH TOTALS"
if self._order_by:
sql += '\nORDER BY ' + self.order_by_as_sql()
sql += "\nORDER BY " + self.order_by_as_sql()
if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields)
sql += "\nLIMIT %d, %d" % self._limit_by
sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits:
sql += '\nLIMIT %d, %d' % self._limits
sql += "\nLIMIT %d, %d" % self._limits
return sql
@ -408,10 +416,12 @@ class QuerySet(object):
"""
Returns the contents of the query's `ORDER BY` clause as a string.
"""
return comma_join([
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field)
return comma_join(
[
"%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
for field in self._order_by
])
]
)
def conditions_as_sql(self, prewhere=False):
"""
@ -426,7 +436,7 @@ class QuerySet(object):
"""
if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -453,10 +463,8 @@ class QuerySet(object):
return qs
def _filter_or_exclude(self, *q, **kwargs):
from .funcs import F
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
inverse = kwargs.pop("_inverse", False)
prewhere = kwargs.pop("prewhere", False)
qs = copy(self)
@ -464,10 +472,10 @@ class QuerySet(object):
for arg in q:
if isinstance(arg, Q):
condition &= arg
elif isinstance(arg, F):
elif isinstance(arg, Cond):
condition &= Q(arg)
else:
raise TypeError('Invalid argument "%r" to queryset filter' % arg)
raise TypeError(f"Invalid argument '{arg}' of type '{type(arg)}' to filter")
if kwargs:
condition &= Q(**kwargs)
@ -509,20 +517,19 @@ class QuerySet(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`.
"""
from .database import Page
count = self.count()
pages_total = int(ceil(count / float(page_size)))
if page_num == -1:
page_num = pages_total
elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num)
raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size
return Page(
objects=list(self[offset : offset + page_size]),
number_of_objects=count,
pages_total=pages_total,
number=page_num,
page_size=page_size
page_size=page_size,
)
def distinct(self):
@ -539,9 +546,10 @@ class QuerySet(object):
Adds a FINAL modifier to table, meaning data will be collapsed to final version.
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
"""
from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines')
raise TypeError(
"final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines"
)
qs = copy(self)
qs._final = True
@ -554,7 +562,7 @@ class QuerySet(object):
"""
self._verify_mutation_allowed()
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions)
sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
self._database.raw(sql)
return self
@ -564,22 +572,22 @@ class QuerySet(object):
Keyword arguments specify the field names and expressions to use for the update.
Note that ClickHouse performs updates in the background, so they are not immediate.
"""
assert kwargs, 'No fields specified for update'
assert kwargs, "No fields specified for update"
self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions)
sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (self._model_cls.table_name(), fields, conditions)
self._database.raw(sql)
return self
def _verify_mutation_allowed(self):
'''
"""
Checks that the queryset's state allows mutations. Raises an AssertionError if not.
'''
assert not self._limits, 'Mutations are not allowed after slicing the queryset'
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)'
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
assert not self._final, 'Mutations are not allowed after calling final()'
"""
assert not self._limits, "Mutations are not allowed after slicing the queryset"
assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
assert not self._distinct, "Mutations are not allowed after calling distinct()"
assert not self._final, "Mutations are not allowed after calling final()"
def aggregate(self, *args, **kwargs):
"""
@ -619,7 +627,7 @@ class AggregateQuerySet(QuerySet):
At least one calculated field is required.
"""
super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database)
assert calculated_fields, 'No calculated fields specified for aggregation'
assert calculated_fields, "No calculated fields specified for aggregation"
self._fields = grouping_fields
self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields
@ -636,8 +644,9 @@ class AggregateQuerySet(QuerySet):
created with.
"""
for name in args:
assert name in self._fields or name in self._calculated_fields, \
'Cannot group by `%s` since it is not included in the query' % name
assert name in self._fields or name in self._calculated_fields, (
"Cannot group by `%s` since it is not included in the query" % name
)
qs = copy(self)
qs._grouping_fields = args
return qs
@ -652,13 +661,15 @@ class AggregateQuerySet(QuerySet):
"""
This method is not supported on `AggregateQuerySet`.
"""
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet')
raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
def select_fields_as_sql(self):
"""
Returns the selected fields or expressions as a SQL string.
"""
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()])
return comma_join(
[str(f) for f in self._fields] + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
)
def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model
@ -667,7 +678,7 @@ class AggregateQuerySet(QuerySet):
"""
Returns the number of rows after aggregation.
"""
sql = u'SELECT count() FROM (%s)' % self.as_sql()
sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql)
return int(raw) if raw else 0
@ -682,7 +693,7 @@ class AggregateQuerySet(QuerySet):
return qs
def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet')
raise AssertionError("Cannot mutate an AggregateQuerySet")
# Expose only relevant classes in import *

View File

@ -2,10 +2,8 @@
This file contains system readonly models that can be got from the database
https://clickhouse.tech/docs/en/system_tables/
"""
from __future__ import unicode_literals
from .database import Database
from .fields import *
from .fields import DateTimeField, StringField, UInt8Field, UInt32Field, UInt64Field
from .models import Model
from .utils import comma_join
@ -16,7 +14,8 @@ class SystemPart(Model):
This model operates only fields, described in the reference. Other fields are ignored.
https://clickhouse.tech/docs/en/system_tables/system.parts/
"""
OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'})
OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})
_readonly = True
_system = True
@ -51,12 +50,13 @@ class SystemPart(Model):
@classmethod
def table_name(cls):
return 'parts'
return "parts"
"""
Next methods return SQL for some operations, which can be done with partitions
https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
"""
def _partition_operation_sql(self, operation, settings=None, from_part=None):
"""
Performs some operation over partition
@ -83,7 +83,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('DETACH', settings=settings)
return self._partition_operation_sql("DETACH", settings=settings)
def drop(self, settings=None):
"""
@ -93,7 +93,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('DROP', settings=settings)
return self._partition_operation_sql("DROP", settings=settings)
def attach(self, settings=None):
"""
@ -103,7 +103,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('ATTACH', settings=settings)
return self._partition_operation_sql("ATTACH", settings=settings)
def freeze(self, settings=None):
"""
@ -113,7 +113,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('FREEZE', settings=settings)
return self._partition_operation_sql("FREEZE", settings=settings)
def fetch(self, zookeeper_path, settings=None):
"""
@ -124,7 +124,7 @@ class SystemPart(Model):
Returns: SQL Query
"""
return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path)
return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)
@classmethod
def get(cls, database, conditions=""):
@ -140,9 +140,12 @@ class SystemPart(Model):
assert isinstance(conditions, str), "conditions must be a string"
if conditions:
conditions += " AND"
field_names = ','.join(cls.fields())
return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" %
(field_names, cls.table_name(), conditions, database.db_name), model_class=cls)
field_names = ",".join(cls.fields())
return database.select(
"SELECT %s FROM `system`.%s WHERE %s database='%s'"
% (field_names, cls.table_name(), conditions, database.db_name),
model_class=cls,
)
@classmethod
def get_active(cls, database, conditions=""):
@ -155,8 +158,8 @@ class SystemPart(Model):
Returns: A list of SystemPart objects
"""
if conditions:
conditions += ' AND '
conditions += 'active'
conditions += " AND "
conditions += "active"
return SystemPart.get(database, conditions=conditions)

View File

@ -1,54 +1,48 @@
import codecs
import importlib
import pkgutil
import re
from datetime import date, datetime, tzinfo, timedelta
from collections import namedtuple
from datetime import date, datetime, timedelta, tzinfo
from inspect import isclass
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Type, Union
Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")
Page.__doc__ += "\nA simple data structure for paginated results."
SPECIAL_CHARS = {
"\b" : "\\b",
"\f" : "\\f",
"\r" : "\\r",
"\n" : "\\n",
"\t" : "\\t",
"\0" : "\\0",
"\\" : "\\\\",
"'" : "\\'"
}
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
def escape(value, quote=True):
'''
def escape(value: str, quote: bool = True) -> str:
"""
If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one.
'''
def escape_one(match):
return SPECIAL_CHARS[match.group(0)]
if isinstance(value, str):
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
"""
value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
if quote:
value = "'" + value + "'"
return str(value)
return value
def unescape(value):
return codecs.escape_decode(value)[0].decode('utf-8')
def unescape(value: str) -> Optional[str]:
if value == "\\N":
return None
return codecs.escape_decode(value)[0].decode("utf-8")
def string_or_func(obj):
return obj.to_sql() if hasattr(obj, 'to_sql') else obj
return obj.to_sql() if hasattr(obj, "to_sql") else obj
def arg_to_sql(arg):
def arg_to_sql(arg: Any) -> str:
"""
Converts a function argument to SQL string according to its type.
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
None, numbers, timezones, arrays/iterables.
"""
from infi.clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet
from clickhouse_orm import DateTimeField, F, Field, QuerySet, StringField
if isinstance(arg, F):
return arg.to_sql()
if isinstance(arg, Field):
@ -66,42 +60,42 @@ def arg_to_sql(arg):
if isinstance(arg, tzinfo):
return StringField().to_db_string(arg.tzname(None))
if arg is None:
return 'NULL'
return "NULL"
if isinstance(arg, QuerySet):
return "(%s)" % arg
if isinstance(arg, tuple):
return '(' + comma_join(arg_to_sql(x) for x in arg) + ')'
return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
if is_iterable(arg):
return '[' + comma_join(arg_to_sql(x) for x in arg) + ']'
return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
return str(arg)
def parse_tsv(line):
def parse_tsv(line: Union[bytes, str]) -> List[str]:
if isinstance(line, bytes):
line = line.decode()
if line and line[-1] == '\n':
if line and line[-1] == "\n":
line = line[:-1]
return [unescape(value) for value in line.split(str('\t'))]
return [unescape(value) for value in line.split("\t")]
def parse_array(array_string):
def parse_array(array_string: str) -> List[Any]:
"""
Parse an array or tuple string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"]
"(1,2,3)" ==> [1, 2, 3]
"""
# Sanity check
if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])':
if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
raise ValueError('Invalid array string: "%s"' % array_string)
# Drop opening brace
array_string = array_string[1:]
# Go over the string, lopping off each value at the beginning until nothing is left
values = []
while True:
if array_string in '])':
if array_string in "])":
# End of array
return values
elif array_string[0] in ', ':
elif array_string[0] in ", ":
# In between values
array_string = array_string[1:]
elif array_string[0] == "'":
@ -110,37 +104,33 @@ def parse_array(array_string):
if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end():]
array_string = array_string[match.end() :]
else:
# Start of non-quoted value, find its end
match = re.search(r",|\]", array_string)
values.append(array_string[0 : match.start()])
array_string = array_string[match.end() - 1:]
array_string = array_string[match.end() - 1 :]
def import_submodules(package_name):
def import_submodules(package_name: str) -> Dict[str, ModuleType]:
"""
Import all submodules of a module.
"""
import importlib, pkgutil
package = importlib.import_module(package_name)
return {
name: importlib.import_module(package_name + '.' + name)
name: importlib.import_module(package_name + "." + name)
for _, name, _ in pkgutil.iter_modules(package.__path__)
}
def comma_join(items, stringify=False):
def comma_join(items: Iterable[str]) -> str:
"""
Joins an iterable of strings with commas.
"""
if stringify:
return ', '.join(str(item) for item in items)
else:
return ', '.join(items)
return ", ".join(items)
def is_iterable(obj):
def is_iterable(obj: Any) -> bool:
"""
Checks if the given object is iterable.
"""
@ -151,17 +141,24 @@ def is_iterable(obj):
return False
def get_subclass_names(locals, base_class):
from inspect import isclass
def get_subclass_names(locals: Dict[str, Any], base_class: Type):
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
class NoValue:
'''
"""
A sentinel for fields with an expression for a default value,
that were not assigned a value yet.
'''
"""
def __repr__(self):
return 'NO_VALUE'
return "NO_VALUE"
def __copy__(self):
return self
def __deepcopy__(self, memo):
return self
NO_VALUE = NoValue()

View File

@ -1,8 +1,8 @@
Class Reference
===============
infi.clickhouse_orm.database
----------------------------
clickhouse_orm.database
-----------------------
### Database
@ -152,8 +152,8 @@ Extends Exception
Raised when a database operation fails.
infi.clickhouse_orm.models
--------------------------
clickhouse_orm.models
---------------------
### Model
@ -811,8 +811,8 @@ separated by non-alphanumeric characters.
- `random_seed` — The seed for Bloom filter hash functions.
infi.clickhouse_orm.fields
--------------------------
clickhouse_orm.fields
---------------------
### ArrayField
@ -1046,8 +1046,8 @@ Extends Field
#### UUIDField(default=None, alias=None, materialized=None, readonly=None, codec=None)
infi.clickhouse_orm.engines
---------------------------
clickhouse_orm.engines
----------------------
### Engine
@ -1140,8 +1140,8 @@ Extends MergeTree
#### ReplacingMergeTree(date_col=None, order_by=(), ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None, primary_key=None)
infi.clickhouse_orm.query
-------------------------
clickhouse_orm.query
--------------------
### QuerySet
@ -1443,8 +1443,8 @@ https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
#### to_sql(model_cls)
infi.clickhouse_orm.funcs
-------------------------
clickhouse_orm.funcs
--------------------
### F
@ -2012,7 +2012,7 @@ Initializer.
#### floor(n=None)
#### formatDateTime(format, timezone="")
#### formatDateTime(format, timezone=NO_VALUE)
#### gcd(b)
@ -2804,13 +2804,13 @@ Initializer.
#### toDateTimeOrZero()
#### toDayOfMonth()
#### toDayOfMonth(timezone=NO_VALUE)
#### toDayOfWeek()
#### toDayOfWeek(timezone=NO_VALUE)
#### toDayOfYear()
#### toDayOfYear(timezone=NO_VALUE)
#### toDecimal128(**kwargs)
@ -2861,7 +2861,7 @@ Initializer.
#### toFloat64OrZero()
#### toHour()
#### toHour(timezone=NO_VALUE)
#### toIPv4()
@ -2870,10 +2870,10 @@ Initializer.
#### toIPv6()
#### toISOWeek(timezone="")
#### toISOWeek(timezone=NO_VALUE)
#### toISOYear(timezone="")
#### toISOYear(timezone=NO_VALUE)
#### toInt16(**kwargs)
@ -2936,73 +2936,73 @@ Initializer.
#### toIntervalYear()
#### toMinute()
#### toMinute(timezone=NO_VALUE)
#### toMonday()
#### toMonday(timezone=NO_VALUE)
#### toMonth()
#### toMonth(timezone=NO_VALUE)
#### toQuarter(timezone="")
#### toQuarter(timezone=NO_VALUE)
#### toRelativeDayNum(timezone="")
#### toRelativeDayNum(timezone=NO_VALUE)
#### toRelativeHourNum(timezone="")
#### toRelativeHourNum(timezone=NO_VALUE)
#### toRelativeMinuteNum(timezone="")
#### toRelativeMinuteNum(timezone=NO_VALUE)
#### toRelativeMonthNum(timezone="")
#### toRelativeMonthNum(timezone=NO_VALUE)
#### toRelativeSecondNum(timezone="")
#### toRelativeSecondNum(timezone=NO_VALUE)
#### toRelativeWeekNum(timezone="")
#### toRelativeWeekNum(timezone=NO_VALUE)
#### toRelativeYearNum(timezone="")
#### toRelativeYearNum(timezone=NO_VALUE)
#### toSecond()
#### toSecond(timezone=NO_VALUE)
#### toStartOfDay()
#### toStartOfDay(timezone=NO_VALUE)
#### toStartOfFifteenMinutes()
#### toStartOfFifteenMinutes(timezone=NO_VALUE)
#### toStartOfFiveMinute()
#### toStartOfFiveMinute(timezone=NO_VALUE)
#### toStartOfHour()
#### toStartOfHour(timezone=NO_VALUE)
#### toStartOfISOYear()
#### toStartOfISOYear(timezone=NO_VALUE)
#### toStartOfMinute()
#### toStartOfMinute(timezone=NO_VALUE)
#### toStartOfMonth()
#### toStartOfMonth(timezone=NO_VALUE)
#### toStartOfQuarter()
#### toStartOfQuarter(timezone=NO_VALUE)
#### toStartOfTenMinutes()
#### toStartOfTenMinutes(timezone=NO_VALUE)
#### toStartOfWeek(mode=0)
#### toStartOfWeek(timezone=NO_VALUE)
#### toStartOfYear()
#### toStartOfYear(timezone=NO_VALUE)
#### toString()
@ -3011,7 +3011,7 @@ Initializer.
#### toStringCutToZero()
#### toTime(timezone="")
#### toTime(timezone=NO_VALUE)
#### toTimeZone(timezone)
@ -3056,22 +3056,22 @@ Initializer.
#### toUUID()
#### toUnixTimestamp(timezone="")
#### toUnixTimestamp(timezone=NO_VALUE)
#### toWeek(mode=0, timezone="")
#### toWeek(mode=0, timezone=NO_VALUE)
#### toYYYYMM(timezone="")
#### toYYYYMM(timezone=NO_VALUE)
#### toYYYYMMDD(timezone="")
#### toYYYYMMDD(timezone=NO_VALUE)
#### toYYYYMMDDhhmmss(timezone="")
#### toYYYYMMDDhhmmss(timezone=NO_VALUE)
#### toYear()
#### toYear(timezone=NO_VALUE)
#### to_sql(*args)
@ -3144,3 +3144,308 @@ For other functions:
#### uniqExact(**kwargs)
#### uniqExactIf(*args)
#### uniqExactOrDefault()
#### uniqExactOrDefaultIf(*args)
#### uniqExactOrNull()
#### uniqExactOrNullIf(*args)
#### uniqHLL12(**kwargs)
#### uniqHLL12If(*args)
#### uniqHLL12OrDefault()
#### uniqHLL12OrDefaultIf(*args)
#### uniqHLL12OrNull()
#### uniqHLL12OrNullIf(*args)
#### uniqIf(*args)
#### uniqOrDefault()
#### uniqOrDefaultIf(*args)
#### uniqOrNull()
#### uniqOrNullIf(*args)
#### upper(**kwargs)
#### upperUTF8()
#### varPop(**kwargs)
#### varPopIf(cond)
#### varPopOrDefault()
#### varPopOrDefaultIf(cond)
#### varPopOrNull()
#### varPopOrNullIf(cond)
#### varSamp(**kwargs)
#### varSampIf(cond)
#### varSampOrDefault()
#### varSampOrDefaultIf(cond)
#### varSampOrNull()
#### varSampOrNullIf(cond)
#### xxHash32()
#### xxHash64()
#### yesterday()
clickhouse_orm.system_models
----------------------------
### SystemPart
Extends Model
Contains information about parts of a table in the MergeTree family.
This model operates only fields, described in the reference. Other fields are ignored.
https://clickhouse.tech/docs/en/system_tables/system.parts/
#### SystemPart(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
#### attach(settings=None)
Add a new part or partition from the 'detached' directory to the table.
- `settings`: Settings for executing request to ClickHouse over db.raw() method
Returns: SQL Query
#### SystemPart.create_table_sql(db)
Returns the SQL statement for creating a table for this model.
#### detach(settings=None)
Move a partition to the 'detached' directory and forget it.
- `settings`: Settings for executing request to ClickHouse over db.raw() method
Returns: SQL Query
#### drop(settings=None)
Delete a partition
- `settings`: Settings for executing request to ClickHouse over db.raw() method
Returns: SQL Query
#### SystemPart.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### fetch(zookeeper_path, settings=None)
Download a partition from another server.
- `zookeeper_path`: Path in zookeeper to fetch from
- `settings`: Settings for executing request to ClickHouse over db.raw() method
Returns: SQL Query
#### SystemPart.fields(writable=False)
Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included.
Callers should not modify the dictionary.
#### freeze(settings=None)
Create a backup of a partition.
- `settings`: Settings for executing request to ClickHouse over db.raw() method
Returns: SQL Query
#### SystemPart.from_tsv(line, field_names, timezone_in_use=UTC, database=None)
Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to.
#### SystemPart.get(database, conditions="")
Get all data from system.parts table
- `database`: A database object to fetch data from.
- `conditions`: WHERE clause conditions. Database condition is added automatically
Returns: A list of SystemPart objects
#### SystemPart.get_active(database, conditions="")
Gets active data from system.parts table
- `database`: A database object to fetch data from.
- `conditions`: WHERE clause conditions. Database and active conditions are added automatically
Returns: A list of SystemPart objects
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### SystemPart.has_funcs_as_defaults()
Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances.
#### SystemPart.is_read_only()
Returns true if the model is marked as read only.
#### SystemPart.is_system_model()
Returns true if the model represents a system table.
#### SystemPart.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### SystemPart.table_name()
#### to_db_string()
Returns the instance as a bytestring ready to be inserted into the database.
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional)
#### to_tskv(include_readonly=True)
Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database.
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database.

View File

@ -1,7 +1,7 @@
Contributing
============
This project is hosted on GitHub - [https://github.com/Infinidat/infi.clickhouse_orm/](https://github.com/Infinidat/infi.clickhouse_orm/).
This project is hosted on GitHub - [https://github.com/Infinidat/clickhouse_orm/](https://github.com/Infinidat/clickhouse_orm/).
Please open an issue there if you encounter a bug or want to request a feature.
Pull requests are also welcome.
@ -12,7 +12,7 @@ Building
After cloning the project, run the following commands:
easy_install -U infi.projector
cd infi.clickhouse_orm
cd clickhouse_orm
projector devenv build
A `setup.py` file will be generated, which you can use to install the development version of the package:
@ -28,7 +28,7 @@ To run the tests, ensure that the ClickHouse server is running on <http://localh
To see test coverage information run:
bin/nosetests --with-coverage --cover-package=infi.clickhouse_orm
bin/nosetests --with-coverage --cover-package=clickhouse_orm
To test with tox, ensure that the setup.py is present (otherwise run `bin/buildout buildout:develop= setup.py`) and run:

View File

@ -13,7 +13,7 @@ Using Expressions
Expressions usually include ClickHouse database functions, which are made available by the `F` class. Here's a simple function:
```python
from infi.clickhouse_orm import F
from clickhouse_orm import F
expr = F.today()
```

View File

@ -25,7 +25,7 @@ class Event(Model):
engine = Memory()
...
```
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` instead. For example:
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `clickhouse_orm.utils.NO_VALUE` instead. For example:
```python
>>> event = Event()
>>> print(event.to_dict())
@ -63,7 +63,7 @@ db.select('SELECT created, created_date, username, name FROM $db.event', model_c
# created_date and username will contain a default value
db.select('SELECT * FROM $db.event', model_class=Event)
```
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database.
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database.
## codec

View File

@ -166,7 +166,7 @@ For example, we can create a BooleanField which will hold `True` and `False` val
Here's the full implementation:
```python
from infi.clickhouse_orm import Field
from clickhouse_orm import Field
class BooleanField(Field):

View File

@ -7,24 +7,24 @@ The ORM supports different styles of importing and referring to its classes, so
Importing Everything
--------------------
It is safe to use `import *` from `infi.clickhouse_orm` or its submodules. Only classes that are needed by users of the ORM will get imported, and nothing else:
It is safe to use `import *` from `clickhouse_orm` or its submodules. Only classes that are needed by users of the ORM will get imported, and nothing else:
```python
from infi.clickhouse_orm import *
from clickhouse_orm import *
```
This is exactly equivalent to the following import statements:
```python
from infi.clickhouse_orm.database import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.funcs import *
from infi.clickhouse_orm.migrations import *
from infi.clickhouse_orm.models import *
from infi.clickhouse_orm.query import *
from infi.clickhouse_orm.system_models import *
from clickhouse_orm.database import *
from clickhouse_orm.engines import *
from clickhouse_orm.fields import *
from clickhouse_orm.funcs import *
from clickhouse_orm.migrations import *
from clickhouse_orm.models import *
from clickhouse_orm.query import *
from clickhouse_orm.system_models import *
```
By importing everything, all of the ORM's public classes can be used directly. For example:
```python
from infi.clickhouse_orm import *
from clickhouse_orm import *
class Event(Model):
@ -40,8 +40,8 @@ Importing Everything into a Namespace
To prevent potential name clashes and to make the code more readable, you can import the ORM's classes into a namespace of your choosing, e.g. `orm`. For brevity, it is recommended to import the `F` class explicitly:
```python
import infi.clickhouse_orm as orm
from infi.clickhouse_orm import F
import clickhouse_orm as orm
from clickhouse_orm import F
class Event(orm.Model):
@ -57,7 +57,7 @@ Importing Specific Submodules
It is possible to import only the submodules you need, and use their names to qualify the ORM's class names. This option is more verbose, but makes it clear where each class comes from. For example:
```python
from infi.clickhouse_orm import models, fields, engines, F
from clickhouse_orm import models, fields, engines, F
class Event(models.Model):
@ -71,9 +71,9 @@ class Event(models.Model):
Importing Specific Classes
--------------------------
If you prefer, you can import only the specific ORM classes that you need directly from `infi.clickhouse_orm`:
If you prefer, you can import only the specific ORM classes that you need directly from `clickhouse_orm`:
```python
from infi.clickhouse_orm import Model, StringField, UInt32Field, DateTimeField, F, Memory
from clickhouse_orm import Model, StringField, UInt32Field, DateTimeField, F, Memory
class Event(Model):

View File

@ -8,9 +8,9 @@ Version 1.x supports Python 2.7 and 3.5+. Version 2.x dropped support for Python
Installation
------------
To install infi.clickhouse_orm:
To install clickhouse_orm:
pip install infi.clickhouse_orm
pip install clickhouse_orm
---

View File

@ -10,7 +10,7 @@ Defining Models
Models are defined in a way reminiscent of Django's ORM, by subclassing `Model`:
```python
from infi.clickhouse_orm import Model, StringField, DateField, Float32Field, MergeTree
from clickhouse_orm import Model, StringField, DateField, Float32Field, MergeTree
class Person(Model):
@ -133,7 +133,7 @@ Inserting to the Database
To write your instances to ClickHouse, you need a `Database` instance:
from infi.clickhouse_orm import Database
from clickhouse_orm import Database
db = Database('my_test_db')

View File

@ -1,7 +1,7 @@
Class Reference
===============
infi.clickhouse_orm.database
clickhouse_orm.database
----------------------------
### Database
@ -104,7 +104,7 @@ Extends Exception
Raised when a database operation fails.
infi.clickhouse_orm.models
clickhouse_orm.models
--------------------------
### Model
@ -263,7 +263,7 @@ Returns the instance's column values as a tab-separated line. A newline is not i
- `include_readonly`: if false, returns only fields that can be inserted into database.
infi.clickhouse_orm.fields
clickhouse_orm.fields
--------------------------
### Field
@ -419,7 +419,7 @@ Extends BaseEnumField
#### Enum16Field(enum_cls, default=None, alias=None, materialized=None)
infi.clickhouse_orm.engines
clickhouse_orm.engines
---------------------------
### Engine
@ -474,7 +474,7 @@ Extends MergeTree
#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
infi.clickhouse_orm.query
clickhouse_orm.query
-------------------------
### QuerySet

View File

@ -22,7 +22,7 @@ To write migrations, create a Python package. Then create a python file for the
Each migration file is expected to contain a list of `operations`, for example:
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from analytics import models
operations = [

View File

@ -30,7 +30,7 @@ A partition in a table is data for a single calendar month. Table "system.parts"
Usage example:
from infi.clickhouse_orm import Database, SystemPart
from clickhouse_orm import Database, SystemPart
db = Database('my_test_db', db_url='http://192.168.1.1:8050', username='scott', password='tiger')
partitions = SystemPart.get_active(db, conditions='') # Getting all active partitions of the database
if len(partitions) > 0:

View File

@ -78,17 +78,17 @@
* [Tests](contributing.md#tests)
* [Class Reference](class_reference.md#class-reference)
* [infi.clickhouse_orm.database](class_reference.md#inficlickhouse_ormdatabase)
* [clickhouse_orm.database](class_reference.md#clickhouse_ormdatabase)
* [Database](class_reference.md#database)
* [DatabaseException](class_reference.md#databaseexception)
* [infi.clickhouse_orm.models](class_reference.md#inficlickhouse_ormmodels)
* [clickhouse_orm.models](class_reference.md#clickhouse_ormmodels)
* [Model](class_reference.md#model)
* [BufferModel](class_reference.md#buffermodel)
* [MergeModel](class_reference.md#mergemodel)
* [DistributedModel](class_reference.md#distributedmodel)
* [Constraint](class_reference.md#constraint)
* [Index](class_reference.md#index)
* [infi.clickhouse_orm.fields](class_reference.md#inficlickhouse_ormfields)
* [clickhouse_orm.fields](class_reference.md#clickhouse_ormfields)
* [ArrayField](class_reference.md#arrayfield)
* [BaseEnumField](class_reference.md#baseenumfield)
* [BaseFloatField](class_reference.md#basefloatfield)
@ -120,7 +120,7 @@
* [UInt64Field](class_reference.md#uint64field)
* [UInt8Field](class_reference.md#uint8field)
* [UUIDField](class_reference.md#uuidfield)
* [infi.clickhouse_orm.engines](class_reference.md#inficlickhouse_ormengines)
* [clickhouse_orm.engines](class_reference.md#clickhouse_ormengines)
* [Engine](class_reference.md#engine)
* [TinyLog](class_reference.md#tinylog)
* [Log](class_reference.md#log)
@ -132,10 +132,12 @@
* [CollapsingMergeTree](class_reference.md#collapsingmergetree)
* [SummingMergeTree](class_reference.md#summingmergetree)
* [ReplacingMergeTree](class_reference.md#replacingmergetree)
* [infi.clickhouse_orm.query](class_reference.md#inficlickhouse_ormquery)
* [clickhouse_orm.query](class_reference.md#clickhouse_ormquery)
* [QuerySet](class_reference.md#queryset)
* [AggregateQuerySet](class_reference.md#aggregatequeryset)
* [Q](class_reference.md#q)
* [infi.clickhouse_orm.funcs](class_reference.md#inficlickhouse_ormfuncs)
* [clickhouse_orm.funcs](class_reference.md#clickhouse_ormfuncs)
* [F](class_reference.md#f)
* [clickhouse_orm.system_models](class_reference.md#clickhouse_ormsystem_models)
* [SystemPart](class_reference.md#systempart)

View File

@ -50,9 +50,9 @@ for row in QueryLog.objects_in(db).filter(QueryLog.query_duration_ms > 10000):
## Convenient ways to import ORM classes
You can now import all ORM classes directly from `infi.clickhouse_orm`, without worrying about sub-modules. For example:
You can now import all ORM classes directly from `clickhouse_orm`, without worrying about sub-modules. For example:
```python
from infi.clickhouse_orm import Database, Model, StringField, DateTimeField, MergeTree
from clickhouse_orm import Database, Model, StringField, DateTimeField, MergeTree
```
See [Importing ORM Classes](importing_orm_classes.md).

View File

@ -1,9 +1,12 @@
import psutil, time, datetime
from infi.clickhouse_orm import Database
import datetime
import time
import psutil
from models import CPUStats
from clickhouse_orm import Database
db = Database('demo')
db = Database("demo")
db.create_table(CPUStats)
@ -14,7 +17,9 @@ while True:
stats = psutil.cpu_percent(percpu=True)
timestamp = datetime.datetime.now()
print(timestamp)
db.insert([
db.insert(
[
CPUStats(timestamp=timestamp, cpu_id=cpu_id, cpu_percent=cpu_percent)
for cpu_id, cpu_percent in enumerate(stats)
])
]
)

View File

@ -1,4 +1,4 @@
from infi.clickhouse_orm import Model, DateTimeField, UInt16Field, Float32Field, Memory
from clickhouse_orm import DateTimeField, Float32Field, Memory, Model, UInt16Field
class CPUStats(Model):
@ -8,4 +8,3 @@ class CPUStats(Model):
cpu_percent = Float32Field()
engine = Memory()

View File

@ -1,2 +1,2 @@
infi.clickhouse_orm
clickhouse_orm
psutil

View File

@ -1,13 +1,13 @@
from infi.clickhouse_orm import Database, F
from models import CPUStats
from clickhouse_orm import Database, F
db = Database('demo')
db = Database("demo")
queryset = CPUStats.objects_in(db)
total = queryset.filter(CPUStats.cpu_id == 1).count()
busy = queryset.filter(CPUStats.cpu_id == 1, CPUStats.cpu_percent > 95).count()
print('CPU 1 was busy {:.2f}% of the time'.format(busy * 100.0 / total))
print("CPU 1 was busy {:.2f}% of the time".format(busy * 100.0 / total))
# Calculate the average usage per CPU
for row in queryset.aggregate(CPUStats.cpu_id, average=F.avg(CPUStats.cpu_percent)):
print('CPU {row.cpu_id}: {row.average:.2f}%'.format(row=row))
print("CPU {row.cpu_id}: {row.average:.2f}%".format(row=row))

View File

@ -1,62 +1,73 @@
import pygal
from pygal.style import RotateStyle
from jinja2.filters import do_filesizeformat
from pygal.style import RotateStyle
# Formatting functions
number_formatter = lambda v: '{:,}'.format(v)
bytes_formatter = lambda v: do_filesizeformat(v, True)
def number_formatter(v):
return "{:,}".format(v)
def bytes_formatter(v):
do_filesizeformat(v, True)
def tables_piechart(db, by_field, value_formatter):
'''
"""
Generate a pie chart of the top n tables in the database.
`db` - the database instance
`by_field` - the field name to sort by
`value_formatter` - a function to use for formatting the numeric values
'''
Tables = db.get_model_for_table('tables', system_table=True)
qs = Tables.objects_in(db).filter(database=db.db_name, is_temporary=False).exclude(engine='Buffer')
"""
Tables = db.get_model_for_table("tables", system_table=True)
qs = Tables.objects_in(db).filter(database=db.db_name, is_temporary=False).exclude(engine="Buffer")
tuples = [(getattr(table, by_field), table.name) for table in qs]
return _generate_piechart(tuples, value_formatter)
def columns_piechart(db, tbl_name, by_field, value_formatter):
'''
"""
Generate a pie chart of the top n columns in the table.
`db` - the database instance
`tbl_name` - the table name
`by_field` - the field name to sort by
`value_formatter` - a function to use for formatting the numeric values
'''
ColumnsTable = db.get_model_for_table('columns', system_table=True)
"""
ColumnsTable = db.get_model_for_table("columns", system_table=True)
qs = ColumnsTable.objects_in(db).filter(database=db.db_name, table=tbl_name)
tuples = [(getattr(col, by_field), col.name) for col in qs]
return _generate_piechart(tuples, value_formatter)
def _get_top_tuples(tuples, n=15):
'''
"""
Given a list of tuples (value, name), this function sorts
the list and returns only the top n results. All other tuples
are aggregated to a single "others" tuple.
'''
"""
non_zero_tuples = [t for t in tuples if t[0]]
sorted_tuples = sorted(non_zero_tuples, reverse=True)
if len(sorted_tuples) > n:
others = (sum(t[0] for t in sorted_tuples[n:]), 'others')
others = (sum(t[0] for t in sorted_tuples[n:]), "others")
sorted_tuples = sorted_tuples[:n] + [others]
return sorted_tuples
def _generate_piechart(tuples, value_formatter):
'''
"""
Generates a pie chart.
`tuples` - a list of (value, name) tuples to include in the chart
`value_formatter` - a function to use for formatting the values
'''
style = RotateStyle('#9e6ffe', background='white', legend_font_family='Roboto', legend_font_size=18, tooltip_font_family='Roboto', tooltip_font_size=24)
chart = pygal.Pie(style=style, margin=0, title=' ', value_formatter=value_formatter, truncate_legend=-1)
"""
style = RotateStyle(
"#9e6ffe",
background="white",
legend_font_family="Roboto",
legend_font_size=18,
tooltip_font_family="Roboto",
tooltip_font_size=24,
)
chart = pygal.Pie(style=style, margin=0, title=" ", value_formatter=value_formatter, truncate_legend=-1)
for t in _get_top_tuples(tuples):
chart.add(t[1], t[0])
return chart.render(is_unicode=True, disable_xml_declaration=True)

View File

@ -3,7 +3,7 @@ chardet==3.0.4
click==7.1.2
Flask==1.1.2
idna==2.9
infi.clickhouse-orm==2.0.1
clickhouse-orm==2.0.1
iso8601==0.1.12
itsdangerous==1.1.0
Jinja2==2.11.2

View File

@ -1,87 +1,93 @@
from infi.clickhouse_orm import Database, F
from charts import tables_piechart, columns_piechart, number_formatter, bytes_formatter
from flask import Flask
from flask import render_template
import sys
from charts import bytes_formatter, columns_piechart, number_formatter, tables_piechart
from flask import Flask, render_template
from clickhouse_orm import Database, F
app = Flask(__name__)
@app.route('/')
@app.route("/")
def homepage_view():
'''
"""
Root view that lists all databases.
'''
db = _get_db('system')
"""
db = _get_db("system")
# Get all databases in the system.databases table
DatabasesTable = db.get_model_for_table('databases', system_table=True)
databases = DatabasesTable.objects_in(db).exclude(name='system')
DatabasesTable = db.get_model_for_table("databases", system_table=True)
databases = DatabasesTable.objects_in(db).exclude(name="system")
databases = databases.order_by(F.lower(DatabasesTable.name))
# Generate the page
return render_template('homepage.html', db=db, databases=databases)
return render_template("homepage.html", db=db, databases=databases)
@app.route('/<db_name>/')
@app.route("/<db_name>/")
def database_view(db_name):
'''
"""
A view that displays information about a single database.
'''
"""
db = _get_db(db_name)
# Get all the tables in the database, by aggregating information from system.columns
ColumnsTable = db.get_model_for_table('columns', system_table=True)
tables = ColumnsTable.objects_in(db).filter(database=db_name).aggregate(
ColumnsTable = db.get_model_for_table("columns", system_table=True)
tables = (
ColumnsTable.objects_in(db)
.filter(database=db_name)
.aggregate(
ColumnsTable.table,
compressed_size=F.sum(ColumnsTable.data_compressed_bytes),
uncompressed_size=F.sum(ColumnsTable.data_uncompressed_bytes),
ratio=F.sum(ColumnsTable.data_uncompressed_bytes) / F.sum(ColumnsTable.data_compressed_bytes)
ratio=F.sum(ColumnsTable.data_uncompressed_bytes) / F.sum(ColumnsTable.data_compressed_bytes),
)
)
tables = tables.order_by(F.lower(ColumnsTable.table))
# Generate the page
return render_template('database.html',
return render_template(
"database.html",
db=db,
tables=tables,
tables_piechart_by_rows=tables_piechart(db, 'total_rows', value_formatter=number_formatter),
tables_piechart_by_size=tables_piechart(db, 'total_bytes', value_formatter=bytes_formatter),
tables_piechart_by_rows=tables_piechart(db, "total_rows", value_formatter=number_formatter),
tables_piechart_by_size=tables_piechart(db, "total_bytes", value_formatter=bytes_formatter),
)
@app.route('/<db_name>/<tbl_name>/')
@app.route("/<db_name>/<tbl_name>/")
def table_view(db_name, tbl_name):
'''
"""
A view that displays information about a single table.
'''
"""
db = _get_db(db_name)
# Get table information from system.tables
TablesTable = db.get_model_for_table('tables', system_table=True)
TablesTable = db.get_model_for_table("tables", system_table=True)
tbl_info = TablesTable.objects_in(db).filter(database=db_name, name=tbl_name)[0]
# Get the SQL used for creating the table
create_table_sql = db.raw('SHOW CREATE TABLE %s FORMAT TabSeparatedRaw' % tbl_name)
create_table_sql = db.raw("SHOW CREATE TABLE %s FORMAT TabSeparatedRaw" % tbl_name)
# Get all columns in the table from system.columns
ColumnsTable = db.get_model_for_table('columns', system_table=True)
ColumnsTable = db.get_model_for_table("columns", system_table=True)
columns = ColumnsTable.objects_in(db).filter(database=db_name, table=tbl_name)
# Generate the page
return render_template('table.html',
return render_template(
"table.html",
db=db,
tbl_name=tbl_name,
tbl_info=tbl_info,
create_table_sql=create_table_sql,
columns=columns,
piechart=columns_piechart(db, tbl_name, 'data_compressed_bytes', value_formatter=bytes_formatter),
piechart=columns_piechart(db, tbl_name, "data_compressed_bytes", value_formatter=bytes_formatter),
)
def _get_db(db_name):
'''
"""
Returns a Database instance using connection information
from the command line arguments (optional).
'''
db_url = sys.argv[1] if len(sys.argv) > 1 else 'http://localhost:8123/'
"""
db_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8123/"
username = sys.argv[2] if len(sys.argv) > 2 else None
password = sys.argv[3] if len(sys.argv) > 3 else None
return Database(db_name, db_url, username, password, readonly=True)
if __name__ == '__main__':
_get_db('system') # fail early on db connection problems
if __name__ == "__main__":
_get_db("system") # fail early on db connection problems
app.run(debug=True)

View File

@ -1,27 +1,28 @@
import requests
import os
import requests
def download_ebook(id):
print(id, end=' ')
print(id, end=" ")
# Download the ebook's text
r = requests.get('https://www.gutenberg.org/files/{id}/{id}-0.txt'.format(id=id))
r = requests.get("https://www.gutenberg.org/files/{id}/{id}-0.txt".format(id=id))
if r.status_code == 404:
print('NOT FOUND, SKIPPING')
print("NOT FOUND, SKIPPING")
return
r.raise_for_status()
# Find the ebook's title
text = r.content.decode('utf-8')
text = r.content.decode("utf-8")
for line in text.splitlines():
if line.startswith('Title:'):
if line.startswith("Title:"):
title = line[6:].strip()
print(title)
# Save the ebook
with open('ebooks/{}.txt'.format(title), 'wb') as f:
with open("ebooks/{}.txt".format(title), "wb") as f:
f.write(r.content)
if __name__ == "__main__":
os.makedirs('ebooks', exist_ok=True)
os.makedirs("ebooks", exist_ok=True)
for i in [1342, 11, 84, 2701, 25525, 1661, 98, 74, 43, 215, 1400, 76]:
download_ebook(i)

View File

@ -1,61 +1,64 @@
import sys
import nltk
from nltk.stem.porter import PorterStemmer
from glob import glob
from infi.clickhouse_orm import Database
import nltk
from models import Fragment
from nltk.stem.porter import PorterStemmer
from clickhouse_orm import Database
def trim_punctuation(word):
'''
"""
Trim punctuation characters from the beginning and end of the word
'''
"""
start = end = len(word)
for i in range(len(word)):
if word[i].isalnum():
start = min(start, i)
end = i + 1
return word[start : end]
return word[start:end]
def parse_file(filename):
'''
"""
Parses a text file at the give path.
Returns a generator of tuples (original_word, stemmed_word)
The original_word may include punctuation characters.
'''
"""
stemmer = PorterStemmer()
with open(filename, 'r', encoding='utf-8') as f:
with open(filename, "r", encoding="utf-8") as f:
for line in f:
for word in line.split():
yield (word, stemmer.stem(trim_punctuation(word)))
def get_fragments(filename):
'''
"""
Converts a text file at the given path to a generator
of Fragment instances.
'''
"""
from os import path
document = path.splitext(path.basename(filename))[0]
idx = 0
for word, stem in parse_file(filename):
idx += 1
yield Fragment(document=document, idx=idx, word=word, stem=stem)
print('{} - {} words'.format(filename, idx))
print("{} - {} words".format(filename, idx))
if __name__ == '__main__':
if __name__ == "__main__":
# Load NLTK data if necessary
nltk.download('punkt')
nltk.download('wordnet')
nltk.download("punkt")
nltk.download("wordnet")
# Initialize database
db = Database('default')
db = Database("default")
db.create_table(Fragment)
# Load files from the command line or everything under ebooks/
filenames = sys.argv[1:] or glob('ebooks/*.txt')
filenames = sys.argv[1:] or glob("ebooks/*.txt")
for filename in filenames:
db.insert(get_fragments(filename), batch_size=100000)

View File

@ -1,9 +1,11 @@
from infi.clickhouse_orm import *
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import LowCardinalityField, StringField, UInt64Field
from clickhouse_orm.models import Index, Model
class Fragment(Model):
language = LowCardinalityField(StringField(), default='EN')
language = LowCardinalityField(StringField(), default="EN")
document = LowCardinalityField(StringField())
idx = UInt64Field()
word = StringField()
@ -13,4 +15,4 @@ class Fragment(Model):
index = Index((document, idx), type=Index.minmax(), granularity=1)
# The primary key allows efficient lookup of stems
engine = MergeTree(order_by=(stem, document, idx), partition_key=('language',))
engine = MergeTree(order_by=(stem, document, idx), partition_key=("language",))

View File

@ -1,4 +1,4 @@
infi.clickhouse_orm
clickhouse_orm
nltk
requests
colorama

View File

@ -1,19 +1,20 @@
import sys
from colorama import init, Fore, Back, Style
from nltk.stem.porter import PorterStemmer
from infi.clickhouse_orm import Database, F
from models import Fragment
from load import trim_punctuation
from colorama import Fore, Style, init
from load import trim_punctuation
from models import Fragment
from nltk.stem.porter import PorterStemmer
from clickhouse_orm import Database, F
# The wildcard character
WILDCARD = '*'
WILDCARD = "*"
def prepare_search_terms(text):
'''
"""
Convert the text to search into a list of stemmed words.
'''
"""
stemmer = PorterStemmer()
stems = []
for word in text.split():
@ -25,10 +26,10 @@ def prepare_search_terms(text):
def build_query(db, stems):
'''
"""
Returns a queryset instance for finding sequences of Fragment instances
that matche the list of stemmed words.
'''
"""
# Start by searching for the first stemmed word
all_fragments = Fragment.objects_in(db)
query = all_fragments.filter(stem=stems[0]).only(Fragment.document, Fragment.idx)
@ -47,44 +48,44 @@ def build_query(db, stems):
def get_matching_text(db, document, from_idx, to_idx, extra=5):
'''
"""
Reconstructs the document text between the given indexes (inclusive),
plus `extra` words before and after the match. The words that are
included in the given range are highlighted in green.
'''
"""
text = []
conds = (Fragment.document == document) & (Fragment.idx >= from_idx - extra) & (Fragment.idx <= to_idx + extra)
for fragment in Fragment.objects_in(db).filter(conds).order_by('document', 'idx'):
for fragment in Fragment.objects_in(db).filter(conds).order_by("document", "idx"):
word = fragment.word
if fragment.idx == from_idx:
word = Fore.GREEN + word
if fragment.idx == to_idx:
word = word + Style.RESET_ALL
text.append(word)
return ' '.join(text)
return " ".join(text)
def find(db, text):
'''
"""
Performs the search for the given text, and prints out the matches.
'''
"""
stems = prepare_search_terms(text)
query = build_query(db, stems)
print('\n' + Fore.MAGENTA + str(query) + Style.RESET_ALL + '\n')
print("\n" + Fore.MAGENTA + str(query) + Style.RESET_ALL + "\n")
for match in query:
text = get_matching_text(db, match.document, match.idx, match.idx + len(stems) - 1)
print(Fore.CYAN + match.document + ':' + Style.RESET_ALL, text)
print(Fore.CYAN + match.document + ":" + Style.RESET_ALL, text)
if __name__ == '__main__':
if __name__ == "__main__":
# Initialize colored output
init()
# Initialize database
db = Database('default')
db = Database("default")
# Search
text = ' '.join(sys.argv[1:])
text = " ".join(sys.argv[1:])
if text:
find(db, text)

52
pyproject.toml Normal file
View File

@ -0,0 +1,52 @@
[tool.black]
line-length = 120
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
[tool.poetry]
name = "clickhouse_orm"
version = "2.2.2"
description = "A simple ORM for working with the Clickhouse database. Maintainance fork of infi.clickhouse_orm."
authors = ["olliemath <oliver.margetts@gmail.com>"]
license = "BSD"
homepage = "https://github.com/SuadeLabs/clickhouse_orm"
repository = "https://github.com/SuadeLabs/clickhouse_orm"
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Database"
]
[tool.poetry.dependencies]
python = ">=3.6.2,<4"
requests = "*"
pytz = "*"
iso8601 = "*"
[tool.poetry.dev-dependencies]
flake8 = "^3.9.2"
flake8-bugbear = "^21.4.3"
pep8-naming = "^0.12.0"
pytest = "^6.2.4"
flake8-isort = "^4.0.0"
black = {version = "^21.7b0", markers = "platform_python_implementation == 'CPython'"}
isort = "^5.9.2"
freezegun = "^1.1.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,4 +1,4 @@
#!/bin/bash
mkdir -p ../htmldocs
find ./ -iname "*.md" -type f -exec sh -c 'echo "Converting ${0}"; pandoc "${0}" -s -o "../htmldocs/${0%.md}.html"' {} \;

View File

@ -1,5 +1,6 @@
#!/bin/bash
# Class reference
../bin/python ../scripts/generate_ref.py > class_reference.md
poetry run python ../scripts/generate_ref.py > class_reference.md
# Table of contents
../scripts/generate_toc.sh

View File

@ -1,11 +1,11 @@
import inspect
from collections import namedtuple
DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value')
DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
def _get_default_arg(args, defaults, arg_index):
""" Method that determines if an argument has default value or not,
"""Method that determines if an argument has default value or not,
and if yes what is the default value for the argument
:param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
@ -25,12 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
return DefaultArgSpec(False, None)
else:
value = defaults[arg_index - args_with_no_defaults]
if (type(value) is str):
if type(value) is str:
value = '"%s"' % value
return DefaultArgSpec(True, value)
def get_method_sig(method):
""" Given a function, it returns a string that pretty much looks how the
"""Given a function, it returns a string that pretty much looks how the
function signature would be written in python.
:param method: a python method
@ -42,31 +43,37 @@ def get_method_sig(method):
# list of defaults are returned in separate array.
# eg: ArgSpec(args=['first_arg', 'second_arg', 'third_arg'],
# varargs=None, keywords=None, defaults=(42, 'something'))
argspec = inspect.getargspec(method)
arg_index=0
argspec = inspect.getfullargspec(method)
args = []
# Use the args and defaults array returned by argspec and find out
# which arguments has default
for arg in argspec.args:
default_arg = _get_default_arg(argspec.args, argspec.defaults, arg_index)
for idx, arg in enumerate(argspec.args):
default_arg = _get_default_arg(argspec.args, argspec.defaults, idx)
if default_arg.has_default:
val = default_arg.default_value
args.append("%s=%s" % (arg, val))
else:
args.append(arg)
for idx, arg in enumerate(argspec.kwonlyargs):
default_arg = _get_default_arg(argspec.kwonlyargs, argspec.kwonlydefaults, idx)
if default_arg.has_default:
val = default_arg.default_value
args.append("%s=%s" % (arg, val))
else:
args.append(arg)
arg_index += 1
if argspec.varargs:
args.append('*' + argspec.varargs)
if argspec.keywords:
args.append('**' + argspec.keywords)
args.append("*" + argspec.varargs)
if argspec.varkw:
args.append("**" + argspec.varkw)
return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
def docstring(obj):
doc = (obj.__doc__ or '').rstrip()
doc = (obj.__doc__ or "").rstrip()
if doc:
lines = doc.split('\n')
lines = doc.split("\n")
# Find the length of the whitespace prefix common to all non-empty lines
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
# Output the lines without the indentation
@ -76,30 +83,30 @@ def docstring(obj):
def class_doc(cls, list_methods=True):
bases = ', '.join([b.__name__ for b in cls.__bases__])
print('###', cls.__name__)
bases = ", ".join([b.__name__ for b in cls.__bases__])
print("###", cls.__name__)
print()
if bases != 'object':
print('Extends', bases)
if bases != "object":
print("Extends", bases)
print()
docstring(cls)
for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
if name == '__init__':
if name == "__init__":
# Initializer
print('####', get_method_sig(method).replace(name, cls.__name__))
elif name[0] == '_':
print("####", get_method_sig(method).replace(name, cls.__name__))
elif name[0] == "_":
# Private method
continue
elif hasattr(method, '__self__') and method.__self__ == cls:
elif hasattr(method, "__self__") and method.__self__ == cls:
# Class method
if not list_methods:
continue
print('#### %s.%s' % (cls.__name__, get_method_sig(method)))
print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
else:
# Regular method
if not list_methods:
continue
print('####', get_method_sig(method))
print("####", get_method_sig(method))
print()
docstring(method)
print()
@ -108,7 +115,7 @@ def class_doc(cls, list_methods=True):
def module_doc(classes, list_methods=True):
mdl = classes[0].__module__
print(mdl)
print('-' * len(mdl))
print("-" * len(mdl))
print()
for cls in classes:
class_doc(cls, list_methods)
@ -118,21 +125,17 @@ def all_subclasses(cls):
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
if __name__ == '__main__':
if __name__ == "__main__":
from infi.clickhouse_orm import database
from infi.clickhouse_orm import fields
from infi.clickhouse_orm import engines
from infi.clickhouse_orm import models
from infi.clickhouse_orm import query
from infi.clickhouse_orm import funcs
from infi.clickhouse_orm import system_models
from clickhouse_orm import database, engines, fields, funcs, models, query, system_models
print('Class Reference')
print('===============')
print("Class Reference")
print("===============")
print()
module_doc([database.Database, database.DatabaseException])
module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index])
module_doc(
[models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index]
)
module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
module_doc([engines.Engine] + all_subclasses(engines.Engine), False)
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])

View File

@ -1,7 +1,7 @@
#!/bin/bash
generate_one() {
# Converts Markdown to HTML using Pandoc, and then extracts the header tags
pandoc "$1" | python "../scripts/html_to_markdown_toc.py" "$1" >> toc.md
pandoc "$1" | poetry run python "../scripts/html_to_markdown_toc.py" "$1" >> toc.md
}
printf "# Table of Contents\n\n" > toc.md

View File

@ -1,14 +1,13 @@
from html.parser import HTMLParser
import sys
from html.parser import HTMLParser
HEADER_TAGS = ('h1', 'h2', 'h3')
HEADER_TAGS = ("h1", "h2", "h3")
class HeadersToMarkdownParser(HTMLParser):
inside = None
text = ''
text = ""
def handle_starttag(self, tag, attrs):
if tag.lower() in HEADER_TAGS:
@ -16,11 +15,11 @@ class HeadersToMarkdownParser(HTMLParser):
def handle_endtag(self, tag):
if tag.lower() in HEADER_TAGS:
indent = ' ' * int(self.inside[1])
fragment = self.text.lower().replace(' ', '-').replace('.', '')
print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment))
indent = " " * int(self.inside[1])
fragment = self.text.lower().replace(" ", "-").replace(".", "")
print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
self.inside = None
self.text = ''
self.text = ""
def handle_data(self, data):
if self.inside:
@ -28,4 +27,4 @@ class HeadersToMarkdownParser(HTMLParser):
HeadersToMarkdownParser().feed(sys.stdin.read())
print('')
print("")

View File

@ -1,11 +0,0 @@
#!/bin/bash
cd /tmp
rm -rf /tmp/orm_env*
virtualenv -p python3 /tmp/orm_env
cd /tmp/orm_env
source bin/activate
pip install infi.projector
git clone https://github.com/Infinidat/infi.clickhouse_orm.git
cd infi.clickhouse_orm
projector devenv build
bin/nosetests

19
setup.cfg Normal file
View File

@ -0,0 +1,19 @@
[flake8]
max-line-length = 120
select =
# pycodestyle
E, W
# pyflakes
F
# flake8-bugbear
B, B9
# pydocstyle
D
# isort
I
ignore =
E203 # Whitespace after ':'
W503 # Operator after new line
B950 # We use E501
exclude =
tests/sample_migrations

View File

@ -1,50 +0,0 @@
SETUP_INFO = dict(
name = '${project:name}',
version = '${infi.recipe.template.version:version}',
author = '${infi.recipe.template.version:author}',
author_email = '${infi.recipe.template.version:author_email}',
url = ${infi.recipe.template.version:homepage},
license = 'BSD',
description = """${project:description}""",
# http://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.4",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Database"
],
install_requires = ${project:install_requires},
namespace_packages = ${project:namespace_packages},
package_dir = {'': 'src'},
package_data = {'': ${project:package_data}},
include_package_data = True,
zip_safe = False,
entry_points = dict(
console_scripts = ${project:console_scripts},
gui_scripts = ${project:gui_scripts},
),
)
if SETUP_INFO['url'] is None:
_ = SETUP_INFO.pop('url')
def setup():
from setuptools import setup as _setup
from setuptools import find_packages
SETUP_INFO['packages'] = find_packages('src')
_setup(**SETUP_INFO)
if __name__ == '__main__':
setup()

View File

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -1,13 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)
from infi.clickhouse_orm.database import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.funcs import *
from infi.clickhouse_orm.migrations import *
from infi.clickhouse_orm.models import *
from infi.clickhouse_orm.query import *
from infi.clickhouse_orm.system_models import *
from inspect import isclass
__all__ = [c.__name__ for c in locals().values() if isclass(c)]

View File

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -1,19 +1,18 @@
# -*- coding: utf-8 -*-
import logging
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
from clickhouse_orm.models import Model
import logging
logging.getLogger("requests").setLevel(logging.WARNING)
class TestCaseWithData(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(Person)
def tearDown(self):
@ -35,7 +34,6 @@ class TestCaseWithData(unittest.TestCase):
yield Person(**entry)
class Person(Model):
first_name = StringField()
@ -44,16 +42,12 @@ class Person(Model):
height = Float32Field()
passport = NullableField(UInt32Field())
engine = MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
engine = MergeTree("birthday", ("first_name", "last_name", "birthday"))
data = [
{"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63",
"passport": 35052255},
{"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74",
"passport": 36052255},
{"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", "passport": 35052255},
{"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", "passport": 36052255},
{"first_name": "Adena", "last_name": "Norman", "birthday": "1979-05-14", "height": "1.66"},
{"first_name": "Aline", "last_name": "Crane", "birthday": "1988-05-01", "height": "1.62"},
{"first_name": "Althea", "last_name": "Barrett", "birthday": "2004-07-28", "height": "1.71"},
@ -151,5 +145,5 @@ data = [
{"first_name": "Whitney", "last_name": "Durham", "birthday": "1977-09-15", "height": "1.72"},
{"first_name": "Whitney", "last_name": "Scott", "birthday": "1971-07-04", "height": "1.70"},
{"first_name": "Wynter", "last_name": "Garcia", "birthday": "1975-01-10", "height": "1.69"},
{"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"}
{"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"},
]

0
tests/fields/__init__.py Normal file
View File

View File

@ -1,32 +1,29 @@
import unittest
from datetime import date
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model, NO_VALUE
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.funcs import F
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, Int32Field, StringField
from clickhouse_orm.funcs import F
from clickhouse_orm.models import NO_VALUE, Model
class AliasFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithAliasFields)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
instance = ModelWithAliasFields(
date_field='2016-08-30',
int_field=-10,
str_field='TEST'
)
instance = ModelWithAliasFields(date_field="2016-08-30", int_field=-10, str_field="TEST")
self.database.insert([instance])
# We can't select * from table, as it doesn't select materialized and alias fields
query = 'SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func' \
' FROM $db.%s ORDER BY alias_date' % ModelWithAliasFields.table_name()
query = (
"SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func"
" FROM $db.%s ORDER BY alias_date" % ModelWithAliasFields.table_name()
)
for model_cls in (ModelWithAliasFields, None):
results = list(self.database.select(query, model_cls))
self.assertEqual(len(results), 1)
@ -41,7 +38,7 @@ class AliasFieldsTest(unittest.TestCase):
def test_assignment_error(self):
# I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
instance = ModelWithAliasFields()
for value in ('x', [date.today()], ['aaa'], [None]):
for value in ("x", [date.today()], ["aaa"], [None]):
with self.assertRaises(ValueError):
instance.alias_date = value
@ -51,10 +48,10 @@ class AliasFieldsTest(unittest.TestCase):
def test_duplicate_default(self):
with self.assertRaises(AssertionError):
StringField(alias='str_field', default='with default')
StringField(alias="str_field", default="with default")
with self.assertRaises(AssertionError):
StringField(alias='str_field', materialized='str_field')
StringField(alias="str_field", materialized="str_field")
def test_default_value(self):
instance = ModelWithAliasFields()
@ -62,7 +59,7 @@ class AliasFieldsTest(unittest.TestCase):
# Check that NO_VALUE can be assigned to a field
instance.str_field = NO_VALUE
# Check that NO_VALUE can be assigned when creating a new instance
instance2 = ModelWithAliasFields(**instance.to_dict())
ModelWithAliasFields(**instance.to_dict())
class ModelWithAliasFields(Model):
@ -70,9 +67,9 @@ class ModelWithAliasFields(Model):
date_field = DateField()
str_field = StringField()
alias_str = StringField(alias=u'str_field')
alias_int = Int32Field(alias='int_field')
alias_date = DateField(alias='date_field')
alias_str = StringField(alias="str_field")
alias_int = Int32Field(alias="int_field")
alias_date = DateField(alias="date_field")
alias_func = Int32Field(alias=F.toYYYYMM(date_field))
engine = MergeTree('date_field', ('date_field',))
engine = MergeTree("date_field", ("date_field",))

View File

@ -1,16 +1,15 @@
import unittest
from datetime import date
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import ArrayField, DateField, Int32Field, StringField
from clickhouse_orm.models import Model
class ArrayFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithArrays)
def tearDown(self):
@ -18,12 +17,12 @@ class ArrayFieldsTest(unittest.TestCase):
def test_insert_and_select(self):
instance = ModelWithArrays(
date_field='2016-08-30',
arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
arr_date=['2010-01-01'],
date_field="2016-08-30",
arr_str=["goodbye,", "cruel", "world", "special chars: ,\"\\'` \n\t\\[]"],
arr_date=["2010-01-01"],
)
self.database.insert([instance])
query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
query = "SELECT * from $db.modelwitharrays ORDER BY date_field"
for model_cls in (ModelWithArrays, None):
results = list(self.database.select(query, model_cls))
self.assertEqual(len(results), 1)
@ -32,32 +31,25 @@ class ArrayFieldsTest(unittest.TestCase):
self.assertEqual(results[0].arr_date, instance.arr_date)
def test_conversion(self):
instance = ModelWithArrays(
arr_int=('1', '2', '3'),
arr_date=['2010-01-01']
)
instance = ModelWithArrays(arr_int=("1", "2", "3"), arr_date=["2010-01-01"])
self.assertEqual(instance.arr_str, [])
self.assertEqual(instance.arr_int, [1, 2, 3])
self.assertEqual(instance.arr_date, [date(2010, 1, 1)])
def test_assignment_error(self):
instance = ModelWithArrays()
for value in (7, 'x', [date.today()], ['aaa'], [None]):
for value in (7, "x", [date.today()], ["aaa"], [None]):
with self.assertRaises(ValueError):
instance.arr_int = value
def test_parse_array(self):
from infi.clickhouse_orm.utils import parse_array, unescape
from clickhouse_orm.utils import parse_array, unescape
self.assertEqual(parse_array("[]"), [])
self.assertEqual(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
self.assertEqual(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
self.assertEqual(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
for s in ("",
"[",
"]",
"[1, 2",
"3, 4]",
"['aaa', 'aaa]"):
for s in ("", "[", "]", "[1, 2", "3, 4]", "['aaa', 'aaa]"):
with self.assertRaises(ValueError):
parse_array(s)
@ -74,4 +66,4 @@ class ModelWithArrays(Model):
arr_int = ArrayField(Int32Field())
arr_date = ArrayField(DateField())
engine = MergeTree('date_field', ('date_field',))
engine = MergeTree("date_field", ("date_field",))

View File

@ -0,0 +1,174 @@
import datetime
import unittest
import pytest
import pytz
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import (
ArrayField,
DateField,
DateTimeField,
FixedStringField,
Float32Field,
Int64Field,
NullableField,
StringField,
UInt64Field,
)
from clickhouse_orm.models import NO_VALUE, Model
from clickhouse_orm.utils import parse_tsv
class CompressedFieldsTestCase(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
self.database.create_table(CompressedModel)
def tearDown(self):
self.database.drop_database()
def test_defaults(self):
# Check that all fields have their explicit or implicit defaults
instance = CompressedModel()
self.database.insert([instance])
self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
self.assertEqual(instance.string_field, "dozo")
self.assertEqual(instance.int64_field, 42)
self.assertEqual(instance.float_field, 0)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, [])
def test_assignment(self):
# Check that all fields are assigned during construction
kwargs = dict(
uint64_field=217,
date_field=datetime.date(1973, 12, 6),
datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
string_field="aloha",
int64_field=-50,
float_field=3.14,
nullable_field=-2.718281,
array_field=["123456789123456", "", "a"],
)
instance = CompressedModel(**kwargs)
self.database.insert([instance])
for name in kwargs:
self.assertEqual(kwargs[name], getattr(instance, name))
def test_string_conversion(self):
# Check field conversion from string during construction
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
nullable_field=None,
array_field="[a,b,c]",
)
self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
self.assertEqual(instance.int64_field, 100)
self.assertEqual(instance.float_field, 7)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, ["a", "b", "c"])
# Check field conversion from string during assignment
instance.int64_field = "99"
self.assertEqual(instance.int64_field, 99)
def test_to_dict(self):
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
array_field="[a,b,c]",
)
self.assertDictEqual(
instance.to_dict(),
{
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"mat_field": NO_VALUE,
"string_field": "dozo",
"nullable_field": None,
"uint64_field": 0,
"array_field": ["a", "b", "c"],
},
)
self.assertDictEqual(
instance.to_dict(include_readonly=False),
{
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"string_field": "dozo",
"nullable_field": None,
"uint64_field": 0,
"array_field": ["a", "b", "c"],
},
)
self.assertDictEqual(
instance.to_dict(
include_readonly=False,
field_names=("int64_field", "mat_field", "datetime_field"),
),
{
"int64_field": 100,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
},
)
def test_confirm_compression_codec(self):
if self.database.server_version < (19, 17):
raise unittest.SkipTest("ClickHouse version too old")
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
array_field="[a,b,c]",
)
self.database.insert([instance])
r = self.database.raw(
"select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(
instance.table_name(), self.database.db_name
)
)
lines = r.splitlines()
parse_tsv(lines[0])
parse_tsv(lines[1])
data = [tuple(parse_tsv(line)) for line in lines[2:]]
self.assertListEqual(
data,
[
("uint64_field", "CODEC(ZSTD(10))"),
("datetime_field", "CODEC(Delta(4), ZSTD(1))"),
("date_field", "CODEC(Delta(4), ZSTD(22))"),
("int64_field", "CODEC(LZ4)"),
("string_field", "CODEC(LZ4HC(10))"),
("nullable_field", "CODEC(ZSTD(1))"),
("array_field", "CODEC(Delta(2), LZ4HC(0))"),
("float_field", "CODEC(NONE)"),
("mat_field", "CODEC(ZSTD(4))"),
],
)
def test_alias_field(self):
with pytest.raises(AssertionError):
Float32Field(alias="something", codec="ZSTD(4)")
class CompressedModel(Model):
uint64_field = UInt64Field(codec="ZSTD(10)")
datetime_field = DateTimeField(codec="Delta,ZSTD")
date_field = DateField(codec="Delta(4),ZSTD(22)")
int64_field = Int64Field(default=42, codec="LZ4")
string_field = StringField(default="dozo", codec="LZ4HC(10)")
nullable_field = NullableField(Float32Field(), codec="ZSTD")
array_field = ArrayField(FixedStringField(length=15), codec="Delta(2),LZ4HC")
float_field = Float32Field(codec="NONE")
mat_field = Float32Field(materialized="float_field", codec="ZSTD(4)")
engine = MergeTree("datetime_field", ("uint64_field", "datetime_field"))

View File

@ -1,14 +1,14 @@
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.fields import Field, Int16Field
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.engines import Memory
from clickhouse_orm.database import Database
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import Field, Int16Field
from clickhouse_orm.models import Model
class CustomFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
def tearDown(self):
self.database.drop_database()
@ -19,15 +19,18 @@ class CustomFieldsTest(unittest.TestCase):
i = Int16Field()
f = BooleanField()
engine = Memory()
self.database.create_table(TestModel)
# Check valid values
for index, value in enumerate([1, '1', True, 0, '0', False]):
for index, value in enumerate([1, "1", True, 0, "0", False]):
rec = TestModel(i=index, f=value)
self.database.insert([rec])
self.assertEqual([rec.f for rec in TestModel.objects_in(self.database).order_by('i')],
[True, True, True, False, False, False])
self.assertEqual(
[rec.f for rec in TestModel.objects_in(self.database).order_by("i")],
[True, True, True, False, False, False],
)
# Check invalid values
for value in [None, 'zzz', -5, 7]:
for value in [None, "zzz", -5, 7]:
with self.assertRaises(ValueError):
TestModel(i=1, f=value)
@ -35,21 +38,20 @@ class CustomFieldsTest(unittest.TestCase):
class BooleanField(Field):
# The ClickHouse column type to use
db_type = 'UInt8'
db_type = "UInt8"
# The default value if empty
class_default = False
def to_python(self, value, timezone_in_use):
# Convert valid values to bool
if value in (1, '1', True):
if value in (1, "1", True):
return True
elif value in (0, '0', False):
elif value in (0, "0", False):
return False
else:
raise ValueError('Invalid value for BooleanField: %r' % value)
raise ValueError("Invalid value for BooleanField: %r" % value)
def to_db_string(self, value, quote=True):
# The value was already converted by to_python, so it's a bool
return '1' if value else '0'
return "1" if value else "0"

View File

@ -0,0 +1,187 @@
import datetime
import unittest
import pytz
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, DateTime64Field, DateTimeField
from clickhouse_orm.models import Model
class DateFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest("ClickHouse version too old")
self.database.create_table(ModelWithDate)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert(
[
ModelWithDate(
date_field="2016-08-30",
datetime_field="2016-08-30 03:50:00",
datetime64_field="2016-08-30 03:50:00.123456",
datetime64_3_field="2016-08-30 03:50:00.123456",
),
ModelWithDate(
date_field="2016-08-31",
datetime_field="2016-08-31 01:30:00",
datetime64_field="2016-08-31 01:30:00.123456",
datetime64_3_field="2016-08-31 01:30:00.123456",
),
]
)
# toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
query = "SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field"
results = list(self.database.select(query))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].date_field, datetime.date(2016, 8, 30))
self.assertEqual(
results[0].datetime_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].hour_start,
datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(results[1].date_field, datetime.date(2016, 8, 31))
self.assertEqual(
results[1].datetime_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].hour_start,
datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, 123456, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_3_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, 123000, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, 123456, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_3_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, 123000, tzinfo=pytz.UTC),
)
class ModelWithDate(Model):
date_field = DateField()
datetime_field = DateTimeField()
datetime64_field = DateTime64Field()
datetime64_3_field = DateTime64Field(precision=3)
engine = MergeTree("date_field", ("date_field",))
class ModelWithTz(Model):
datetime_no_tz_field = DateTimeField() # server tz
datetime_tz_field = DateTimeField(timezone="Europe/Madrid")
datetime64_tz_field = DateTime64Field(timezone="Europe/Madrid")
datetime_utc_field = DateTimeField(timezone=pytz.UTC)
engine = MergeTree("datetime_no_tz_field", ("datetime_no_tz_field",))
class DateTimeFieldWithTzTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest("ClickHouse version too old")
self.database.create_table(ModelWithTz)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert(
[
ModelWithTz(
datetime_no_tz_field="2020-06-11 04:00:00",
datetime_tz_field="2020-06-11 04:00:00",
datetime64_tz_field="2020-06-11 04:00:00",
datetime_utc_field="2020-06-11 04:00:00",
),
ModelWithTz(
datetime_no_tz_field="2020-06-11 07:00:00+0300",
datetime_tz_field="2020-06-11 07:00:00+0300",
datetime64_tz_field="2020-06-11 07:00:00+0300",
datetime_utc_field="2020-06-11 07:00:00+0300",
),
]
)
query = "SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field"
results = list(self.database.select(query))
self.assertEqual(
results[0].datetime_no_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_utc_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_no_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_utc_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_no_tz_field.tzinfo.zone,
self.database.server_timezone.zone,
)
self.assertEqual(
results[0].datetime_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(
results[0].datetime64_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)
self.assertEqual(
results[1].datetime_no_tz_field.tzinfo.zone,
self.database.server_timezone.zone,
)
self.assertEqual(
results[1].datetime_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(
results[1].datetime64_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)

View File

@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
import unittest
from decimal import Decimal
from clickhouse_orm.database import Database, ServerError
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import DateField, Decimal32Field, Decimal64Field, Decimal128Field, DecimalField
from clickhouse_orm.models import Model
class DecimalFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
try:
self.database.create_table(DecimalModel)
except ServerError as e:
# This ClickHouse version does not support decimals yet
raise unittest.SkipTest(str(e))
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert(
[
DecimalModel(date_field="2016-08-20"),
DecimalModel(date_field="2016-08-21", dec=Decimal("1.234")),
DecimalModel(date_field="2016-08-22", dec32=Decimal("12342.2345")),
DecimalModel(date_field="2016-08-23", dec64=Decimal("12342.23456")),
DecimalModel(date_field="2016-08-24", dec128=Decimal("-4545456612342.234567")),
]
)
def _assert_sample_data(self, results):
self.assertEqual(len(results), 5)
self.assertEqual(results[0].dec, Decimal(0))
self.assertEqual(results[0].dec32, Decimal(17))
self.assertEqual(results[1].dec, Decimal("1.234"))
self.assertEqual(results[2].dec32, Decimal("12342.2345"))
self.assertEqual(results[3].dec64, Decimal("12342.23456"))
self.assertEqual(results[4].dec128, Decimal("-4545456612342.234567"))
def test_insert_and_select(self):
self._insert_sample_data()
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, DecimalModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = "SELECT * from decimalmodel ORDER BY date_field"
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_rounding(self):
d = Decimal("11111.2340000000000000001")
self.database.insert([DecimalModel(date_field="2016-08-20", dec=d, dec32=d, dec64=d, dec128=d)])
m = DecimalModel.objects_in(self.database)[0]
for val in (m.dec, m.dec32, m.dec64, m.dec128):
self.assertEqual(val, Decimal("11111.234"))
def test_assignment_ok(self):
for value in (True, False, 17, 3.14, "20.5", Decimal("20.5")):
DecimalModel(dec=value)
def test_assignment_error(self):
for value in ("abc", u"זה ארוך", None, float("NaN"), Decimal("-Infinity")):
with self.assertRaises(ValueError):
DecimalModel(dec=value)
def test_aggregation(self):
self._insert_sample_data()
result = DecimalModel.objects_in(self.database).aggregate(m="min(dec)", n="max(dec)")
self.assertEqual(result[0].m, Decimal(0))
self.assertEqual(result[0].n, Decimal("1.234"))
def test_precision_and_scale(self):
# Go over all valid combinations
for precision in range(1, 39):
for scale in range(0, precision + 1):
DecimalField(precision, scale)
# Some invalid combinations
for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
with self.assertRaises(AssertionError):
DecimalField(precision, scale)
def test_min_max(self):
# In range
f = DecimalField(3, 1)
f.validate(f.to_python("99.9", None))
f.validate(f.to_python("-99.9", None))
# In range after rounding
f.validate(f.to_python("99.94", None))
f.validate(f.to_python("-99.94", None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python("99.99", None))
with self.assertRaises(ValueError):
f.validate(f.to_python("-99.99", None))
# In range
f = Decimal32Field(4)
f.validate(f.to_python("99999.9999", None))
f.validate(f.to_python("-99999.9999", None))
# In range after rounding
f.validate(f.to_python("99999.99994", None))
f.validate(f.to_python("-99999.99994", None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python("100000", None))
with self.assertRaises(ValueError):
f.validate(f.to_python("-100000", None))
class DecimalModel(Model):
date_field = DateField()
dec = DecimalField(15, 3)
dec32 = Decimal32Field(4, default=17)
dec64 = Decimal64Field(5)
dec128 = Decimal128Field(6)
engine = Memory()

View File

@ -1,17 +1,15 @@
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from enum import Enum
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import ArrayField, DateField, Enum8Field, Enum16Field
from clickhouse_orm.models import Model
class EnumFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithEnum)
self.database.create_table(ModelWithEnumArray)
@ -19,12 +17,14 @@ class EnumFieldsTest(unittest.TestCase):
self.database.drop_database()
def test_insert_and_select(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry)
])
query = 'SELECT * from $table ORDER BY date_field'
self.database.insert(
[
ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
]
)
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, ModelWithEnum))
self.assertEqual(len(results), 3)
self.assertEqual(results[0].enum_field, Fruit.apple)
@ -32,12 +32,14 @@ class EnumFieldsTest(unittest.TestCase):
self.assertEqual(results[2].enum_field, Fruit.cherry)
def test_ad_hoc_model(self):
self.database.insert([
ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange),
ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry)
])
query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
self.database.insert(
[
ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
]
)
query = "SELECT * from $db.modelwithenum ORDER BY date_field"
results = list(self.database.select(query))
self.assertEqual(len(results), 3)
self.assertEqual(results[0].enum_field.name, Fruit.apple.name)
@ -50,11 +52,11 @@ class EnumFieldsTest(unittest.TestCase):
def test_conversion(self):
self.assertEqual(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
self.assertEqual(ModelWithEnum(enum_field=-7).enum_field, Fruit.cherry)
self.assertEqual(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
self.assertEqual(ModelWithEnum(enum_field="apple").enum_field, Fruit.apple)
self.assertEqual(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
def test_assignment_error(self):
for value in (0, 17, 'pear', '', None, 99.9):
for value in (0, 17, "pear", "", None, 99.9):
with self.assertRaises(ValueError):
ModelWithEnum(enum_field=value)
@ -63,15 +65,15 @@ class EnumFieldsTest(unittest.TestCase):
self.assertEqual(instance.enum_field, Fruit.apple)
def test_enum_array(self):
instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
instance = ModelWithEnumArray(date_field="2016-08-30", enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
self.database.insert([instance])
query = 'SELECT * from $table ORDER BY date_field'
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, ModelWithEnumArray))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].enum_array, instance.enum_array)
Fruit = Enum('Fruit', [('apple', 1), ('banana', 2), ('orange', 3), ('cherry', -7)])
Fruit = Enum("Fruit", [("apple", 1), ("banana", 2), ("orange", 3), ("cherry", -7)])
class ModelWithEnum(Model):
@ -79,7 +81,7 @@ class ModelWithEnum(Model):
date_field = DateField()
enum_field = Enum8Field(Fruit)
engine = MergeTree('date_field', ('date_field',))
engine = MergeTree("date_field", ("date_field",))
class ModelWithEnumArray(Model):
@ -87,5 +89,4 @@ class ModelWithEnumArray(Model):
date_field = DateField()
enum_array = ArrayField(Enum16Field(Fruit))
engine = MergeTree('date_field', ('date_field',))
engine = MergeTree("date_field", ("date_field",))

View File

@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
import unittest
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, FixedStringField
from clickhouse_orm.models import Model
class FixedStringFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
self.database.create_table(FixedStringModel)
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert(
[
FixedStringModel(date_field="2016-08-30", fstr_field=""),
FixedStringModel(date_field="2016-08-30"),
FixedStringModel(date_field="2016-08-31", fstr_field="foo"),
FixedStringModel(date_field="2016-08-31", fstr_field=u"לילה"),
]
)
def _assert_sample_data(self, results):
self.assertEqual(len(results), 4)
self.assertEqual(results[0].fstr_field, "")
self.assertEqual(results[1].fstr_field, "ABCDEFGHIJK")
self.assertEqual(results[2].fstr_field, "foo")
self.assertEqual(results[3].fstr_field, u"לילה")
def test_insert_and_select(self):
self._insert_sample_data()
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, FixedStringModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = "SELECT * from $db.fixedstringmodel ORDER BY date_field"
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_assignment_error(self):
for value in (17, "this is too long", u"זה ארוך", None, 99.9):
with self.assertRaises(ValueError):
FixedStringModel(fstr_field=value)
class FixedStringModel(Model):
date_field = DateField()
fstr_field = FixedStringField(12, default="ABCDEFGHIJK")
engine = MergeTree("date_field", ("date_field",))

View File

@ -1,60 +1,59 @@
import unittest
from ipaddress import IPv4Address, IPv6Address
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.fields import Int16Field, IPv4Field, IPv6Field
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.engines import Memory
from clickhouse_orm.database import Database
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import Int16Field, IPv4Field, IPv6Field
from clickhouse_orm.models import Model
class IPFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
def tearDown(self):
self.database.drop_database()
def test_ipv4_field(self):
if self.database.server_version < (19, 17):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
# Create a model
class TestModel(Model):
i = Int16Field()
f = IPv4Field()
engine = Memory()
self.database.create_table(TestModel)
# Check valid values (all values are the same ip)
values = [
'1.2.3.4',
b'\x01\x02\x03\x04',
16909060,
IPv4Address('1.2.3.4')
]
values = ["1.2.3.4", b"\x01\x02\x03\x04", 16909060, IPv4Address("1.2.3.4")]
for index, value in enumerate(values):
rec = TestModel(i=index, f=value)
self.database.insert([rec])
for rec in TestModel.objects_in(self.database):
self.assertEqual(rec.f, IPv4Address(values[0]))
# Check invalid values
for value in [None, 'zzz', -1, '123']:
for value in [None, "zzz", -1, "123"]:
with self.assertRaises(ValueError):
TestModel(i=1, f=value)
def test_ipv6_field(self):
if self.database.server_version < (19, 17):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
# Create a model
class TestModel(Model):
i = Int16Field()
f = IPv6Field()
engine = Memory()
self.database.create_table(TestModel)
# Check valid values (all values are the same ip)
values = [
'2a02:e980:1e::1',
b'*\x02\xe9\x80\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01',
"2a02:e980:1e::1",
b"*\x02\xe9\x80\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
55842696359362256756849388082849382401,
IPv6Address('2a02:e980:1e::1')
IPv6Address("2a02:e980:1e::1"),
]
for index, value in enumerate(values):
rec = TestModel(i=index, f=value)
@ -62,7 +61,6 @@ class IPFieldsTest(unittest.TestCase):
for rec in TestModel.objects_in(self.database):
self.assertEqual(rec.f, IPv6Address(values[0]))
# Check invalid values
for value in [None, 'zzz', -1, '123']:
for value in [None, "zzz", -1, "123"]:
with self.assertRaises(ValueError):
TestModel(i=1, f=value)

View File

@ -1,32 +1,29 @@
import unittest
from datetime import date
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model, NO_VALUE
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.funcs import F
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, DateTimeField, Int32Field, StringField
from clickhouse_orm.funcs import F
from clickhouse_orm.models import NO_VALUE, Model
class MaterializedFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithMaterializedFields)
def tearDown(self):
self.database.drop_database()
def test_insert_and_select(self):
instance = ModelWithMaterializedFields(
date_time_field='2016-08-30 11:00:00',
int_field=-10,
str_field='TEST'
)
instance = ModelWithMaterializedFields(date_time_field="2016-08-30 11:00:00", int_field=-10, str_field="TEST")
self.database.insert([instance])
# We can't select * from table, as it doesn't select materialized and alias fields
query = 'SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str, mat_func' \
' FROM $db.%s ORDER BY mat_date' % ModelWithMaterializedFields.table_name()
query = (
"SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str, mat_func"
" FROM $db.%s ORDER BY mat_date" % ModelWithMaterializedFields.table_name()
)
for model_cls in (ModelWithMaterializedFields, None):
results = list(self.database.select(query, model_cls))
self.assertEqual(len(results), 1)
@ -41,7 +38,7 @@ class MaterializedFieldsTest(unittest.TestCase):
def test_assignment_error(self):
# I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
instance = ModelWithMaterializedFields()
for value in ('x', [date.today()], ['aaa'], [None]):
for value in ("x", [date.today()], ["aaa"], [None]):
with self.assertRaises(ValueError):
instance.mat_date = value
@ -51,10 +48,10 @@ class MaterializedFieldsTest(unittest.TestCase):
def test_duplicate_default(self):
with self.assertRaises(AssertionError):
StringField(materialized='str_field', default='with default')
StringField(materialized="str_field", default="with default")
with self.assertRaises(AssertionError):
StringField(materialized='str_field', alias='str_field')
StringField(materialized="str_field", alias="str_field")
def test_default_value(self):
instance = ModelWithMaterializedFields()
@ -66,9 +63,9 @@ class ModelWithMaterializedFields(Model):
date_time_field = DateTimeField()
str_field = StringField()
mat_str = StringField(materialized='lower(str_field)')
mat_int = Int32Field(materialized='abs(int_field)')
mat_date = DateField(materialized=u'toDate(date_time_field)')
mat_str = StringField(materialized="lower(str_field)")
mat_int = Int32Field(materialized="abs(int_field)")
mat_date = DateField(materialized=u"toDate(date_time_field)")
mat_func = StringField(materialized=F.lower(str_field))
engine = MergeTree('mat_date', ('mat_date',))
engine = MergeTree("mat_date", ("mat_date",))

View File

@ -1,19 +1,35 @@
import unittest
from datetime import date, datetime
import pytz
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.utils import comma_join
from datetime import date, datetime
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import (
BaseFloatField,
BaseIntField,
DateField,
DateTimeField,
Float32Field,
Float64Field,
Int8Field,
Int16Field,
Int32Field,
Int64Field,
NullableField,
StringField,
UInt8Field,
UInt16Field,
UInt32Field,
UInt64Field,
)
from clickhouse_orm.models import Model
from clickhouse_orm.utils import comma_join
class NullableFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
self.database.create_table(ModelWithNullable)
def tearDown(self):
@ -23,18 +39,20 @@ class NullableFieldsTest(unittest.TestCase):
f = NullableField(DateTimeField())
epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
# Valid values
for value in (date(1970, 1, 1),
for value in (
date(1970, 1, 1),
datetime(1970, 1, 1),
epoch,
epoch.astimezone(pytz.timezone('US/Eastern')),
epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
'1970-01-01 00:00:00',
'1970-01-17 00:00:17',
'0000-00-00 00:00:00',
epoch.astimezone(pytz.timezone("US/Eastern")),
epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
"1970-01-01 00:00:00",
"1970-01-17 00:00:17",
"0000-00-00 00:00:00",
0,
'\\N'):
"\\N",
):
dt = f.to_python(value, pytz.utc)
if value == '\\N':
if value == "\\N":
self.assertIsNone(dt)
else:
self.assertTrue(dt.tzinfo)
@ -42,32 +60,32 @@ class NullableFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2)
# Invalid values
for value in ('nope', '21/7/1999', 0.5):
for value in ("nope", "21/7/1999", 0.5):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
def test_nullable_uint8_field(self):
f = NullableField(UInt8Field())
# Valid values
for value in (17, '17', 17.0, '\\N'):
for value in (17, "17", 17.0, "\\N"):
python_value = f.to_python(value, pytz.utc)
if value == '\\N':
if value == "\\N":
self.assertIsNone(python_value)
self.assertEqual(value, f.to_db_string(python_value))
else:
self.assertEqual(python_value, 17)
# Invalid values
for value in ('nope', date.today()):
for value in ("nope", date.today()):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
def test_nullable_string_field(self):
f = NullableField(StringField())
# Valid values
for value in ('\\\\N', 'N', 'some text', '\\N'):
for value in ("\\\\N", "N", "some text", "\\N"):
python_value = f.to_python(value, pytz.utc)
if value == '\\N':
if value == "\\N":
self.assertIsNone(python_value)
self.assertEqual(value, f.to_db_string(python_value))
else:
@ -78,7 +96,16 @@ class NullableFieldsTest(unittest.TestCase):
f = NullableField(field())
self.assertTrue(f.isinstance(field))
self.assertTrue(f.isinstance(NullableField))
for field in (Int8Field, Int16Field, Int32Field, Int64Field, UInt8Field, UInt16Field, UInt32Field, UInt64Field):
for field in (
Int8Field,
Int16Field,
Int32Field,
Int64Field,
UInt8Field,
UInt16Field,
UInt32Field,
UInt64Field,
):
f = NullableField(field())
self.assertTrue(f.isinstance(BaseIntField))
for field in (Float32Field, Float64Field):
@ -91,12 +118,25 @@ class NullableFieldsTest(unittest.TestCase):
def _insert_sample_data(self):
dt = date(1970, 1, 1)
self.database.insert([
ModelWithNullable(date_field='2016-08-30', null_str='', null_int=42, null_date=dt),
ModelWithNullable(date_field='2016-08-30', null_str='nothing', null_int=None, null_date=None),
ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=42, null_date=dt),
ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=None, null_date=None, null_default=None)
])
self.database.insert(
[
ModelWithNullable(date_field="2016-08-30", null_str="", null_int=42, null_date=dt),
ModelWithNullable(
date_field="2016-08-30",
null_str="nothing",
null_int=None,
null_date=None,
),
ModelWithNullable(date_field="2016-08-31", null_str=None, null_int=42, null_date=dt),
ModelWithNullable(
date_field="2016-08-31",
null_str=None,
null_int=None,
null_date=None,
null_default=None,
),
]
)
def _assert_sample_data(self, results):
for r in results:
@ -110,7 +150,7 @@ class NullableFieldsTest(unittest.TestCase):
self.assertEqual(results[0].null_materialized, 420)
self.assertEqual(results[0].null_date, dt)
self.assertIsNone(results[1].null_date)
self.assertEqual(results[1].null_str, 'nothing')
self.assertEqual(results[1].null_str, "nothing")
self.assertIsNone(results[1].null_date)
self.assertIsNone(results[2].null_str)
self.assertEqual(results[2].null_date, dt)
@ -128,14 +168,14 @@ class NullableFieldsTest(unittest.TestCase):
def test_insert_and_select(self):
self._insert_sample_data()
fields = comma_join(ModelWithNullable.fields().keys())
query = 'SELECT %s from $table ORDER BY date_field' % fields
query = "SELECT %s from $table ORDER BY date_field" % fields
results = list(self.database.select(query, ModelWithNullable))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
fields = comma_join(ModelWithNullable.fields().keys())
query = 'SELECT %s from $db.modelwithnullable ORDER BY date_field' % fields
query = "SELECT %s from $db.modelwithnullable ORDER BY date_field" % fields
results = list(self.database.select(query))
self._assert_sample_data(results)
@ -143,11 +183,11 @@ class NullableFieldsTest(unittest.TestCase):
class ModelWithNullable(Model):
date_field = DateField()
null_str = NullableField(StringField(), extra_null_values={''})
null_str = NullableField(StringField(), extra_null_values={""})
null_int = NullableField(Int32Field())
null_date = NullableField(DateField())
null_default = NullableField(Int32Field(), default=7)
null_alias = NullableField(Int32Field(), alias='null_int/2')
null_materialized = NullableField(Int32Field(), alias='null_int*10')
null_alias = NullableField(Int32Field(), alias="null_int/2")
null_materialized = NullableField(Int32Field(), alias="null_int*10")
engine = MergeTree('date_field', ('date_field',))
engine = MergeTree("date_field", ("date_field",))

View File

@ -1,19 +1,30 @@
import unittest
from infi.clickhouse_orm.fields import *
from datetime import date, datetime
import pytz
from clickhouse_orm.fields import DateField, DateTime64Field, DateTimeField, UInt8Field
class SimpleFieldsTest(unittest.TestCase):
epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
# Valid values
dates = [
date(1970, 1, 1), datetime(1970, 1, 1), epoch,
epoch.astimezone(pytz.timezone('US/Eastern')), epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
'1970-01-01 00:00:00', '1970-01-17 00:00:17', '0000-00-00 00:00:00', 0,
'2017-07-26T08:31:05', '2017-07-26T08:31:05Z', '2017-07-26 08:31',
'2017-07-26T13:31:05+05', '2017-07-26 13:31:05+0500'
date(1970, 1, 1),
datetime(1970, 1, 1),
epoch,
epoch.astimezone(pytz.timezone("US/Eastern")),
epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
"1970-01-01 00:00:00",
"1970-01-17 00:00:17",
"0000-00-00 00:00:00",
0,
"2017-07-26T08:31:05",
"2017-07-26T08:31:05Z",
"2017-07-26 08:31",
"2017-07-26T13:31:05+05",
"2017-07-26 13:31:05+0500",
]
def test_datetime_field(self):
@ -25,8 +36,14 @@ class SimpleFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2)
# Invalid values
for value in ('nope', '21/7/1999', 0.5,
'2017-01 15:06:00', '2017-01-01X15:06:00', '2017-13-01T15:06:00'):
for value in (
"nope",
"21/7/1999",
0.5,
"2017-01 15:06:00",
"2017-01-01X15:06:00",
"2017-13-01T15:06:00",
):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
@ -35,10 +52,16 @@ class SimpleFieldsTest(unittest.TestCase):
# Valid values
for value in self.dates + [
datetime(1970, 1, 1, microsecond=100000),
pytz.timezone('US/Eastern').localize(datetime(1970, 1, 1, microsecond=100000)),
'1970-01-01 00:00:00.1', '1970-01-17 00:00:17.1', '0000-00-00 00:00:00.1', 0.1,
'2017-07-26T08:31:05.1', '2017-07-26T08:31:05.1Z', '2017-07-26 08:31.1',
'2017-07-26T13:31:05.1+05', '2017-07-26 13:31:05.1+0500'
pytz.timezone("US/Eastern").localize(datetime(1970, 1, 1, microsecond=100000)),
"1970-01-01 00:00:00.1",
"1970-01-17 00:00:17.1",
"0000-00-00 00:00:00.1",
0.1,
"2017-07-26T08:31:05.1",
"2017-07-26T08:31:05.1Z",
"2017-07-26 08:31.1",
"2017-07-26T13:31:05.1+05",
"2017-07-26 13:31:05.1+0500",
]:
dt = f.to_python(value, pytz.utc)
self.assertTrue(dt.tzinfo)
@ -46,8 +69,13 @@ class SimpleFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2)
# Invalid values
for value in ('nope', '21/7/1999',
'2017-01 15:06:00', '2017-01-01X15:06:00', '2017-13-01T15:06:00'):
for value in (
"nope",
"21/7/1999",
"2017-01 15:06:00",
"2017-01-01X15:06:00",
"2017-13-01T15:06:00",
):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
@ -63,14 +91,14 @@ class SimpleFieldsTest(unittest.TestCase):
f = DateField()
epoch = date(1970, 1, 1)
# Valid values
for value in (datetime(1970, 1, 1), epoch, '1970-01-01', '0000-00-00', 0):
for value in (datetime(1970, 1, 1), epoch, "1970-01-01", "0000-00-00", 0):
d = f.to_python(value, pytz.utc)
self.assertEqual(d, epoch)
# Verify that conversion to and from db string does not change value
d2 = f.to_python(f.to_db_string(d, quote=False), pytz.utc)
self.assertEqual(d, d2)
# Invalid values
for value in ('nope', '21/7/1999', 0.5):
for value in ("nope", "21/7/1999", 0.5):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
# Range check
@ -81,7 +109,7 @@ class SimpleFieldsTest(unittest.TestCase):
def test_date_field_timezone(self):
# Verify that conversion of timezone-aware datetime is correct
f = DateField()
dt = datetime(2017, 10, 5, tzinfo=pytz.timezone('Asia/Jerusalem'))
dt = datetime(2017, 10, 5, tzinfo=pytz.timezone("Asia/Jerusalem"))
self.assertEqual(f.to_python(dt, pytz.utc), date(2017, 10, 4))
def test_datetime_field_timezone(self):
@ -89,21 +117,21 @@ class SimpleFieldsTest(unittest.TestCase):
f = DateTimeField()
utc_value = datetime(2017, 7, 26, 8, 31, 5, tzinfo=pytz.UTC)
for value in (
'2017-07-26T08:31:05',
'2017-07-26T08:31:05Z',
'2017-07-26T11:31:05+03',
'2017-07-26 11:31:05+0300',
'2017-07-26T03:31:05-0500',
"2017-07-26T08:31:05",
"2017-07-26T08:31:05Z",
"2017-07-26T11:31:05+03",
"2017-07-26 11:31:05+0300",
"2017-07-26T03:31:05-0500",
):
self.assertEqual(f.to_python(value, pytz.utc), utc_value)
def test_uint8_field(self):
f = UInt8Field()
# Valid values
for value in (17, '17', 17.0):
for value in (17, "17", 17.0):
self.assertEqual(f.to_python(value, pytz.utc), 17)
# Invalid values
for value in ('nope', date.today()):
for value in ("nope", date.today()):
with self.assertRaises(ValueError):
f.to_python(value, pytz.utc)
# Range check

View File

@ -1,35 +1,37 @@
import unittest
from uuid import UUID
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.fields import Int16Field, UUIDField
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.engines import Memory
from clickhouse_orm.database import Database
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import Int16Field, UUIDField
from clickhouse_orm.models import Model
class UUIDFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
def tearDown(self):
self.database.drop_database()
def test_uuid_field(self):
if self.database.server_version < (18, 1):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
# Create a model
class TestModel(Model):
i = Int16Field()
f = UUIDField()
engine = Memory()
self.database.create_table(TestModel)
# Check valid values (all values are the same UUID)
values = [
'12345678-1234-5678-1234-567812345678',
'{12345678-1234-5678-1234-567812345678}',
'12345678123456781234567812345678',
'urn:uuid:12345678-1234-5678-1234-567812345678',
b'\x12\x34\x56\x78'*4,
"12345678-1234-5678-1234-567812345678",
"{12345678-1234-5678-1234-567812345678}",
"12345678123456781234567812345678",
"urn:uuid:12345678-1234-5678-1234-567812345678",
b"\x12\x34\x56\x78" * 4,
(0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678),
0x12345678123456781234567812345678,
UUID(int=0x12345678123456781234567812345678),
@ -40,7 +42,6 @@ class UUIDFieldsTest(unittest.TestCase):
for rec in TestModel.objects_in(self.database):
self.assertEqual(rec.f, UUID(values[0]))
# Check invalid values
for value in [None, 'zzz', -1, '123']:
for value in [None, "zzz", -1, "123"]:
with self.assertRaises(ValueError):
TestModel(i=1, f=value)

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(Model1)
]
operations = [migrations.CreateTable(Model1)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.DropTable(Model1)
]
operations = [migrations.DropTable(Model1)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(Model1)
]
operations = [migrations.CreateTable(Model1)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(Model2)
]
operations = [migrations.AlterTable(Model2)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(Model3)
]
operations = [migrations.AlterTable(Model3)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(EnumModel1)
]
operations = [migrations.CreateTable(EnumModel1)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(EnumModel2)
]
operations = [migrations.AlterTable(EnumModel2)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(MaterializedModel)
]
operations = [migrations.CreateTable(MaterializedModel)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(AliasModel)
]
operations = [migrations.CreateTable(AliasModel)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(Model4Buffer)
]
operations = [migrations.CreateTable(Model4Buffer)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTableWithBuffer(Model4Buffer_changed)
]
operations = [migrations.AlterTableWithBuffer(Model4Buffer_changed)]

View File

@ -1,9 +1,11 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
operations = [
migrations.RunSQL("INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-01', 1, 1, 'test') "),
migrations.RunSQL([
migrations.RunSQL(
[
"INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-02', 2, 2, 'test2') ",
"INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-03', 3, 3, 'test3') ",
])
]
),
]

View File

@ -1,15 +1,12 @@
import datetime
from infi.clickhouse_orm import migrations
from test_migrations import Model3
from clickhouse_orm import migrations
def forward(database):
database.insert([
Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4='test4')
])
database.insert([Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4="test4")])
operations = [
migrations.RunPython(forward)
]
operations = [migrations.RunPython(forward)]

View File

@ -1,7 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(MaterializedModel1),
migrations.AlterTable(AliasModel1)
]
operations = [migrations.AlterTable(MaterializedModel1), migrations.AlterTable(AliasModel1)]

View File

@ -1,7 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterTable(Model4_compressed),
migrations.AlterTable(Model2LowCardinality)
]
operations = [migrations.AlterTable(Model4_compressed), migrations.AlterTable(Model2LowCardinality)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(ModelWithConstraints)
]
operations = [migrations.CreateTable(ModelWithConstraints)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterConstraints(ModelWithConstraints2)
]
operations = [migrations.AlterConstraints(ModelWithConstraints2)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.CreateTable(ModelWithIndex)
]
operations = [migrations.CreateTable(ModelWithIndex)]

View File

@ -1,6 +1,5 @@
from infi.clickhouse_orm import migrations
from clickhouse_orm import migrations
from ..test_migrations import *
operations = [
migrations.AlterIndexes(ModelWithIndex2, reindex=True)
]
operations = [migrations.AlterIndexes(ModelWithIndex2, reindex=True)]

View File

@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-
import unittest
from clickhouse_orm.engines import Buffer
from clickhouse_orm.models import BufferModel
from infi.clickhouse_orm.models import BufferModel
from infi.clickhouse_orm.engines import *
from .base_test_with_data import *
from .base_test_with_data import Person, TestCaseWithData, data
class BufferTestCase(TestCaseWithData):
def _insert_and_check_buffer(self, data, count):
self.database.insert(data)
self.assertEqual(count, self.database.count(PersonBuffer))

View File

@ -1,123 +0,0 @@
import unittest
import datetime
import pytz
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model, NO_VALUE
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.utils import parse_tsv
class CompressedFieldsTestCase(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database.create_table(CompressedModel)
def tearDown(self):
self.database.drop_database()
def test_defaults(self):
# Check that all fields have their explicit or implicit defaults
instance = CompressedModel()
self.database.insert([instance])
self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
self.assertEqual(instance.string_field, 'dozo')
self.assertEqual(instance.int64_field, 42)
self.assertEqual(instance.float_field, 0)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, [])
def test_assignment(self):
# Check that all fields are assigned during construction
kwargs = dict(
uint64_field=217,
date_field=datetime.date(1973, 12, 6),
datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
string_field='aloha',
int64_field=-50,
float_field=3.14,
nullable_field=-2.718281,
array_field=['123456789123456','','a']
)
instance = CompressedModel(**kwargs)
self.database.insert([instance])
for name, value in kwargs.items():
self.assertEqual(kwargs[name], getattr(instance, name))
def test_string_conversion(self):
# Check field conversion from string during construction
instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', nullable_field=None, array_field='[a,b,c]')
self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
self.assertEqual(instance.int64_field, 100)
self.assertEqual(instance.float_field, 7)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, ['a', 'b', 'c'])
# Check field conversion from string during assignment
instance.int64_field = '99'
self.assertEqual(instance.int64_field, 99)
def test_to_dict(self):
instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
self.assertDictEqual(instance.to_dict(), {
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"alias_field": NO_VALUE,
'string_field': 'dozo',
'nullable_field': None,
'uint64_field': 0,
'array_field': ['a','b','c']
})
self.assertDictEqual(instance.to_dict(include_readonly=False), {
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
'string_field': 'dozo',
'nullable_field': None,
'uint64_field': 0,
'array_field': ['a', 'b', 'c']
})
self.assertDictEqual(
instance.to_dict(include_readonly=False, field_names=('int64_field', 'alias_field', 'datetime_field')), {
"int64_field": 100,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
})
def test_confirm_compression_codec(self):
if self.database.server_version < (19, 17):
raise unittest.SkipTest('ClickHouse version too old')
instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
self.database.insert([instance])
r = self.database.raw("select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(instance.table_name(), self.database.db_name))
lines = r.splitlines()
field_names = parse_tsv(lines[0])
field_types = parse_tsv(lines[1])
data = [tuple(parse_tsv(line)) for line in lines[2:]]
self.assertListEqual(data, [('uint64_field', 'CODEC(ZSTD(10))'),
('datetime_field', 'CODEC(Delta(4), ZSTD(1))'),
('date_field', 'CODEC(Delta(4), ZSTD(22))'),
('int64_field', 'CODEC(LZ4)'),
('string_field', 'CODEC(LZ4HC(10))'),
('nullable_field', 'CODEC(ZSTD(1))'),
('array_field', 'CODEC(Delta(2), LZ4HC(0))'),
('float_field', 'CODEC(NONE)'),
('alias_field', 'CODEC(ZSTD(4))')])
class CompressedModel(Model):
uint64_field = UInt64Field(codec='ZSTD(10)')
datetime_field = DateTimeField(codec='Delta,ZSTD')
date_field = DateField(codec='Delta(4),ZSTD(22)')
int64_field = Int64Field(default=42, codec='LZ4')
string_field = StringField(default='dozo', codec='LZ4HC(10)')
nullable_field = NullableField(Float32Field(), codec='ZSTD')
array_field = ArrayField(FixedStringField(length=15), codec='Delta(2),LZ4HC')
float_field = Float32Field(codec='NONE')
alias_field = Float32Field(alias='float_field', codec='ZSTD(4)')
engine = MergeTree('datetime_field', ('uint64_field', 'datetime_field'))

View File

@ -1,44 +1,63 @@
import unittest
from infi.clickhouse_orm import *
from clickhouse_orm import Constraint, Database, F, ServerError
from .base_test_with_data import Person
class ConstraintsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (19, 14, 3, 3):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
self.database.create_table(PersonWithConstraints)
def tearDown(self):
self.database.drop_database()
def test_insert_valid_values(self):
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)
])
self.database.insert(
[
PersonWithConstraints(
first_name="Mike",
last_name="Caruzo",
birthday="2000-01-01",
height=1.66,
)
]
)
def test_insert_invalid_values(self):
with self.assertRaises(ServerError) as e:
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)
])
self.database.insert(
[
PersonWithConstraints(
first_name="Mike",
last_name="Caruzo",
birthday="2100-01-01",
height=1.66,
)
]
)
self.assertEqual(e.code, 469)
self.assertTrue('Constraint `birthday_in_the_past`' in e.message)
self.assertTrue("Constraint `birthday_in_the_past`" in str(e))
with self.assertRaises(ServerError) as e:
self.database.insert([
PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)
])
self.database.insert(
[
PersonWithConstraints(
first_name="Mike",
last_name="Caruzo",
birthday="1970-01-01",
height=3,
)
]
)
self.assertEqual(e.code, 469)
self.assertTrue('Constraint `max_height`' in e.message)
self.assertTrue("Constraint `max_height`" in str(e))
class PersonWithConstraints(Person):
birthday_in_the_past = Constraint(Person.birthday <= F.today())
max_height = Constraint(Person.height <= 2.75)

View File

@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
import unittest
import datetime
import unittest
from infi.clickhouse_orm.database import ServerError, DatabaseException
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.engines import Memory
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.funcs import F
from infi.clickhouse_orm.query import Q
from .base_test_with_data import *
from clickhouse_orm.database import Database, DatabaseException, ServerError
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import DateField, DateTimeField, Float32Field, Int32Field, StringField
from clickhouse_orm.funcs import F
from clickhouse_orm.models import Model
from clickhouse_orm.query import Q
from .base_test_with_data import Person, TestCaseWithData, data
class DatabaseTestCase(TestCaseWithData):
def test_insert__generator(self):
self._insert_and_check(self._sample_data(), len(data))
@ -33,17 +33,19 @@ class DatabaseTestCase(TestCaseWithData):
def test_insert__funcs_as_default_values(self):
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest('Buggy in server versions before 20.1.2.4')
raise unittest.SkipTest("Buggy in server versions before 20.1.2.4")
class TestModel(Model):
a = DateTimeField(default=datetime.datetime(2020, 1, 1))
b = DateField(default=F.toDate(a))
c = Int32Field(default=7)
d = Int32Field(default=c * 5)
engine = Memory()
self.database.create_table(TestModel)
self.database.insert([TestModel()])
t = TestModel.objects_in(self.database)[0]
self.assertEqual(str(t.b), '2020-01-01')
self.assertEqual(str(t.b), "2020-01-01")
self.assertEqual(t.d, 35)
def test_count(self):
@ -63,9 +65,9 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].last_name, 'Durham')
self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 1.72)
self.assertEqual(results[1].last_name, 'Scott')
self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
@ -79,9 +81,9 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].last_name, 'Durham')
self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 0) # default value
self.assertEqual(results[1].last_name, 'Scott')
self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 0) # default value
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
@ -91,10 +93,10 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].__class__.__name__, 'AdHocModel')
self.assertEqual(results[0].last_name, 'Durham')
self.assertEqual(results[0].__class__.__name__, "AdHocModel")
self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 1.72)
self.assertEqual(results[1].last_name, 'Scott')
self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database)
@ -116,7 +118,7 @@ class DatabaseTestCase(TestCaseWithData):
page_num = 1
instances = set()
while True:
page = self.database.paginate(Person, 'first_name, last_name', page_num, page_size)
page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
self.assertEqual(page.number_of_objects, len(data))
self.assertGreater(page.pages_total, 0)
[instances.add(obj.to_tsv()) for obj in page.objects]
@ -131,15 +133,23 @@ class DatabaseTestCase(TestCaseWithData):
# Try different page sizes
for page_size in (1, 2, 7, 10, 30, 100, 150):
# Ask for the last page in two different ways and verify equality
page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
self.assertEqual(page_a[1:], page_b[1:])
self.assertEqual([obj.to_tsv() for obj in page_a.objects],
[obj.to_tsv() for obj in page_b.objects])
self.assertEqual(
[obj.to_tsv() for obj in page_a.objects],
[obj.to_tsv() for obj in page_b.objects],
)
def test_pagination_empty_page(self):
for page_num in (-1, 1, 2):
page = self.database.paginate(Person, 'first_name, last_name', page_num, 10, conditions="first_name = 'Ziggy'")
page = self.database.paginate(
Person,
"first_name, last_name",
page_num,
10,
conditions="first_name = 'Ziggy'",
)
self.assertEqual(page.number_of_objects, 0)
self.assertEqual(page.objects, [])
self.assertEqual(page.pages_total, 0)
@ -149,22 +159,28 @@ class DatabaseTestCase(TestCaseWithData):
self._insert_and_check(self._sample_data(), len(data))
for page_num in (0, -2, -100):
with self.assertRaises(ValueError):
self.database.paginate(Person, 'first_name, last_name', page_num, 100)
self.database.paginate(Person, "first_name, last_name", page_num, 100)
def test_pagination_with_conditions(self):
self._insert_and_check(self._sample_data(), len(data))
# Conditions as string
page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions="first_name < 'Ava'")
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
self.assertEqual(page.number_of_objects, 10)
# Conditions as expression
page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Person.first_name < 'Ava')
page = self.database.paginate(
Person,
"first_name, last_name",
1,
100,
conditions=Person.first_name < "Ava",
)
self.assertEqual(page.number_of_objects, 10)
# Conditions as Q object
page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Q(first_name__lt='Ava'))
page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
self.assertEqual(page.number_of_objects, 10)
def test_special_chars(self):
s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
p = Person(first_name=s)
self.database.insert([p])
p = list(self.database.select("SELECT * from $table", Person))[0]
@ -174,29 +190,32 @@ class DatabaseTestCase(TestCaseWithData):
self._insert_and_check(self._sample_data(), len(data))
query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = self.database.raw(query)
self.assertEqual(results, "Whitney\tDurham\t1977-09-15\t1.72\t\\N\nWhitney\tScott\t1971-07-04\t1.7\t\\N\n")
self.assertEqual(
results,
"Whitney\tDurham\t1977-09-15\t1.72\t\\N\nWhitney\tScott\t1971-07-04\t1.7\t\\N\n",
)
def test_invalid_user(self):
with self.assertRaises(ServerError) as cm:
Database(self.database.db_name, username='default', password='wrong')
Database(self.database.db_name, username="default", password="wrong")
exc = cm.exception
if exc.code == 193: # ClickHouse version < 20.3
self.assertTrue(exc.message.startswith('Wrong password for user default'))
self.assertTrue(exc.message.startswith("Wrong password for user default"))
elif exc.code == 516: # ClickHouse version >= 20.3
self.assertTrue(exc.message.startswith('default: Authentication failed'))
self.assertTrue(exc.message.startswith("default: Authentication failed"))
else:
raise Exception('Unexpected error code - %s' % exc.code)
raise Exception("Unexpected error code - %s" % exc.code)
def test_nonexisting_db(self):
db = Database('db_not_here', autocreate=False)
db = Database("db_not_here", autocreate=False)
with self.assertRaises(ServerError) as cm:
db.create_table(Person)
exc = cm.exception
self.assertEqual(exc.code, 81)
self.assertTrue(exc.message.startswith("Database db_not_here doesn't exist"))
# Create and delete the db twice, to ensure db_exists gets updated
for i in range(2):
for _ in range(2):
# Now create the database - should succeed
db.create_database()
self.assertTrue(db.db_exists)
@ -212,25 +231,28 @@ class DatabaseTestCase(TestCaseWithData):
def test_missing_engine(self):
class EnginelessModel(Model):
float_field = Float32Field()
with self.assertRaises(DatabaseException) as cm:
self.database.create_table(EnginelessModel)
self.assertEqual(str(cm.exception), 'EnginelessModel class must define an engine')
self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")
def test_potentially_problematic_field_names(self):
class Model1(Model):
system = StringField()
readonly = StringField()
engine = Memory()
instance = Model1(system='s', readonly='r')
self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
instance = Model1(system="s", readonly="r")
self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))
self.database.create_table(Model1)
self.database.insert([instance])
instance = Model1.objects_in(self.database)[0]
self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))
def test_does_table_exist(self):
class Person2(Person):
pass
self.assertTrue(self.database.does_table_exist(Person))
self.assertFalse(self.database.does_table_exist(Person2))
@ -239,32 +261,31 @@ class DatabaseTestCase(TestCaseWithData):
with self.assertRaises(AssertionError):
self.database.add_setting(0, 1)
# Add a setting and see that it makes the query fail
self.database.add_setting('max_columns_to_read', 1)
self.database.add_setting("max_columns_to_read", 1)
with self.assertRaises(ServerError):
list(self.database.select('SELECT * from system.tables'))
list(self.database.select("SELECT * from system.tables"))
# Remove the setting and see that now it works
self.database.add_setting('max_columns_to_read', None)
list(self.database.select('SELECT * from system.tables'))
self.database.add_setting("max_columns_to_read", None)
list(self.database.select("SELECT * from system.tables"))
def test_create_ad_hoc_field(self):
# Tests that create_ad_hoc_field works for all column types in the database
from infi.clickhouse_orm.models import ModelBase
from clickhouse_orm.models import ModelBase
query = "SELECT DISTINCT type FROM system.columns"
for row in self.database.select(query):
ModelBase.create_ad_hoc_field(row.type)
def test_get_model_for_table(self):
# Tests that get_model_for_table works for a non-system model
model = self.database.get_model_for_table('person')
model = self.database.get_model_for_table("person")
self.assertFalse(model.is_system_model())
self.assertFalse(model.is_read_only())
self.assertEqual(model.table_name(), 'person')
self.assertEqual(model.table_name(), "person")
# Read a few records
list(model.objects_in(self.database)[:10])
# Inserts should work too
self.database.insert([
model(first_name='aaa', last_name='bbb', height=1.77)
])
self.database.insert([model(first_name="aaa", last_name="bbb", height=1.77)])
def test_get_model_for_table__system(self):
# Tests that get_model_for_table works for all system tables
@ -275,11 +296,15 @@ class DatabaseTestCase(TestCaseWithData):
self.assertTrue(model.is_system_model())
self.assertTrue(model.is_read_only())
self.assertEqual(model.table_name(), row.name)
if row.name == "distributed_ddl_queue":
continue # Since zookeeper is not set up in our tests
# Read a few records
try:
list(model.objects_in(self.database)[:10])
except ServerError as e:
if 'Not enough privileges' in e.message:
if "Not enough privileges" in str(e):
pass
else:
raise

View File

@ -1,119 +0,0 @@
import unittest
import datetime
import pytz
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class DateFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest('ClickHouse version too old')
self.database.create_table(ModelWithDate)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert([
ModelWithDate(
date_field='2016-08-30',
datetime_field='2016-08-30 03:50:00',
datetime64_field='2016-08-30 03:50:00.123456',
datetime64_3_field='2016-08-30 03:50:00.123456'
),
ModelWithDate(
date_field='2016-08-31',
datetime_field='2016-08-31 01:30:00',
datetime64_field='2016-08-31 01:30:00.123456',
datetime64_3_field='2016-08-31 01:30:00.123456')
])
# toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
query = 'SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field'
results = list(self.database.select(query))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].date_field, datetime.date(2016, 8, 30))
self.assertEqual(results[0].datetime_field, datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].hour_start, datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].date_field, datetime.date(2016, 8, 31))
self.assertEqual(results[1].datetime_field, datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].hour_start, datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime64_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123456, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime64_3_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123000,
tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime64_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123456, tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime64_3_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123000,
tzinfo=pytz.UTC))
class ModelWithDate(Model):
date_field = DateField()
datetime_field = DateTimeField()
datetime64_field = DateTime64Field()
datetime64_3_field = DateTime64Field(precision=3)
engine = MergeTree('date_field', ('date_field',))
class ModelWithTz(Model):
datetime_no_tz_field = DateTimeField() # server tz
datetime_tz_field = DateTimeField(timezone='Europe/Madrid')
datetime64_tz_field = DateTime64Field(timezone='Europe/Madrid')
datetime_utc_field = DateTimeField(timezone=pytz.UTC)
engine = MergeTree('datetime_no_tz_field', ('datetime_no_tz_field',))
class DateTimeFieldWithTzTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest('ClickHouse version too old')
self.database.create_table(ModelWithTz)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert([
ModelWithTz(
datetime_no_tz_field='2020-06-11 04:00:00',
datetime_tz_field='2020-06-11 04:00:00',
datetime64_tz_field='2020-06-11 04:00:00',
datetime_utc_field='2020-06-11 04:00:00',
),
ModelWithTz(
datetime_no_tz_field='2020-06-11 07:00:00+0300',
datetime_tz_field='2020-06-11 07:00:00+0300',
datetime64_tz_field='2020-06-11 07:00:00+0300',
datetime_utc_field='2020-06-11 07:00:00+0300',
),
])
query = 'SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field'
results = list(self.database.select(query))
self.assertEqual(results[0].datetime_no_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime64_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime_utc_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime_no_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime64_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[1].datetime_utc_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
self.assertEqual(results[0].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
self.assertEqual(results[0].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
self.assertEqual(results[0].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)
self.assertEqual(results[1].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
self.assertEqual(results[1].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
self.assertEqual(results[1].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)

View File

@ -1,121 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from decimal import Decimal
from infi.clickhouse_orm.database import Database, ServerError
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class DecimalFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
try:
self.database.create_table(DecimalModel)
except ServerError as e:
# This ClickHouse version does not support decimals yet
raise unittest.SkipTest(e.message)
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert([
DecimalModel(date_field='2016-08-20'),
DecimalModel(date_field='2016-08-21', dec=Decimal('1.234')),
DecimalModel(date_field='2016-08-22', dec32=Decimal('12342.2345')),
DecimalModel(date_field='2016-08-23', dec64=Decimal('12342.23456')),
DecimalModel(date_field='2016-08-24', dec128=Decimal('-4545456612342.234567')),
])
def _assert_sample_data(self, results):
self.assertEqual(len(results), 5)
self.assertEqual(results[0].dec, Decimal(0))
self.assertEqual(results[0].dec32, Decimal(17))
self.assertEqual(results[1].dec, Decimal('1.234'))
self.assertEqual(results[2].dec32, Decimal('12342.2345'))
self.assertEqual(results[3].dec64, Decimal('12342.23456'))
self.assertEqual(results[4].dec128, Decimal('-4545456612342.234567'))
def test_insert_and_select(self):
self._insert_sample_data()
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, DecimalModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = 'SELECT * from decimalmodel ORDER BY date_field'
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_rounding(self):
d = Decimal('11111.2340000000000000001')
self.database.insert([DecimalModel(date_field='2016-08-20', dec=d, dec32=d, dec64=d, dec128=d)])
m = DecimalModel.objects_in(self.database)[0]
for val in (m.dec, m.dec32, m.dec64, m.dec128):
self.assertEqual(val, Decimal('11111.234'))
def test_assignment_ok(self):
for value in (True, False, 17, 3.14, '20.5', Decimal('20.5')):
DecimalModel(dec=value)
def test_assignment_error(self):
for value in ('abc', u'זה ארוך', None, float('NaN'), Decimal('-Infinity')):
with self.assertRaises(ValueError):
DecimalModel(dec=value)
def test_aggregation(self):
self._insert_sample_data()
result = DecimalModel.objects_in(self.database).aggregate(m='min(dec)', n='max(dec)')
self.assertEqual(result[0].m, Decimal(0))
self.assertEqual(result[0].n, Decimal('1.234'))
def test_precision_and_scale(self):
# Go over all valid combinations
for precision in range(1, 39):
for scale in range(0, precision + 1):
f = DecimalField(precision, scale)
# Some invalid combinations
for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
with self.assertRaises(AssertionError):
f = DecimalField(precision, scale)
def test_min_max(self):
# In range
f = DecimalField(3, 1)
f.validate(f.to_python('99.9', None))
f.validate(f.to_python('-99.9', None))
# In range after rounding
f.validate(f.to_python('99.94', None))
f.validate(f.to_python('-99.94', None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python('99.99', None))
with self.assertRaises(ValueError):
f.validate(f.to_python('-99.99', None))
# In range
f = Decimal32Field(4)
f.validate(f.to_python('99999.9999', None))
f.validate(f.to_python('-99999.9999', None))
# In range after rounding
f.validate(f.to_python('99999.99994', None))
f.validate(f.to_python('-99999.99994', None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python('100000', None))
with self.assertRaises(ValueError):
f.validate(f.to_python('-100000', None))
class DecimalModel(Model):
date_field = DateField()
dec = DecimalField(15, 3)
dec32 = Decimal32Field(4, default=17)
dec64 = Decimal64Field(5)
dec128 = Decimal128Field(6)
engine = Memory()

View File

@ -1,41 +1,43 @@
import unittest
import logging
import unittest
from infi.clickhouse_orm import *
from clickhouse_orm import Database, F, Memory, Model, StringField, UInt64Field
class DictionaryTestMixin:
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 11, 73):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
self._create_dictionary()
def tearDown(self):
self.database.drop_database()
def _test_func(self, func, expected_value):
sql = 'SELECT %s AS value' % func.to_sql()
def _call_func(self, func):
sql = "SELECT %s AS value" % func.to_sql()
logging.info(sql)
result = list(self.database.select(sql))
logging.info('\t==> %s', result[0].value if result else '<empty>')
print('Comparing %s to %s' % (result[0].value, expected_value))
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
logging.info("\t==> %s", result[0].value if result else "<empty>")
return result
def _test_func(self, func, expected_value):
result = self._call_func(func)
print("Comparing %s to %s" % (result[0].value, expected_value))
assert result[0].value == expected_value
class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def _create_dictionary(self):
# Create a table to be used as source for the dictionary
self.database.create_table(NumberName)
self.database.insert(
NumberName(number=i, name=name)
for i, name in enumerate('Zero One Two Three Four Five Six Seven Eight Nine Ten'.split())
for i, name in enumerate("Zero One Two Three Four Five Six Seven Eight Nine Ten".split())
)
# Create the dictionary
self.database.raw("""
self.database.raw(
"""
CREATE DICTIONARY numbers_dict(
number UInt64,
name String DEFAULT '?'
@ -46,16 +48,17 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
))
LIFETIME(100)
LAYOUT(HASHED());
""")
self.dict_name = 'test-db.numbers_dict'
"""
)
self.dict_name = "test-db.numbers_dict"
def test_dictget(self):
self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(3)), 'Three')
self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(99)), '?')
self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(3)), "Three")
self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(99)), "?")
def test_dictgetordefault(self):
self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(3), 'n/a'), 'Three')
self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(99), 'n/a'), 'n/a')
self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(3), "n/a"), "Three")
self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(99), "n/a"), "n/a")
def test_dicthas(self):
self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -63,19 +66,21 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def _create_dictionary(self):
# Create a table to be used as source for the dictionary
self.database.create_table(Region)
self.database.insert([
Region(region_id=1, parent_region=0, region_name='Russia'),
Region(region_id=2, parent_region=1, region_name='Moscow'),
Region(region_id=3, parent_region=2, region_name='Center'),
Region(region_id=4, parent_region=0, region_name='Great Britain'),
Region(region_id=5, parent_region=4, region_name='London'),
])
self.database.insert(
[
Region(region_id=1, parent_region=0, region_name="Russia"),
Region(region_id=2, parent_region=1, region_name="Moscow"),
Region(region_id=3, parent_region=2, region_name="Center"),
Region(region_id=4, parent_region=0, region_name="Great Britain"),
Region(region_id=5, parent_region=4, region_name="London"),
]
)
# Create the dictionary
self.database.raw("""
self.database.raw(
"""
CREATE DICTIONARY regions_dict(
region_id UInt64,
parent_region UInt64 HIERARCHICAL,
@ -87,17 +92,24 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
))
LIFETIME(100)
LAYOUT(HASHED());
""")
self.dict_name = 'test-db.regions_dict'
"""
)
self.dict_name = "test-db.regions_dict"
def test_dictget(self):
self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(3)), 'Center')
self._test_func(F.dictGet(self.dict_name, 'parent_region', F.toUInt64(3)), 2)
self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(99)), '?')
self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(3)), "Center")
self._test_func(F.dictGet(self.dict_name, "parent_region", F.toUInt64(3)), 2)
self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(99)), "?")
def test_dictgetordefault(self):
self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(3), 'n/a'), 'Center')
self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(99), 'n/a'), 'n/a')
self._test_func(
F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(3), "n/a"),
"Center",
)
self._test_func(
F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(99), "n/a"),
"n/a",
)
def test_dicthas(self):
self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -105,7 +117,10 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
def test_dictgethierarchy(self):
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [99])
# Default behaviour changed in CH, but we're not really testing that
default = self._call_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)))
assert isinstance(default, list)
assert len(default) <= 1 # either [] or [99]
def test_dictisin(self):
self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)
@ -114,7 +129,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class NumberName(Model):
''' A table to act as a source for the dictionary '''
"""A table to act as a source for the dictionary"""
number = UInt64Field()
name = StringField()

View File

@ -1,16 +1,29 @@
import unittest
import datetime
from infi.clickhouse_orm import *
import logging
import unittest
from clickhouse_orm.database import Database, DatabaseException, ServerError
from clickhouse_orm.engines import (
CollapsingMergeTree,
Log,
Memory,
Merge,
MergeTree,
ReplacingMergeTree,
SummingMergeTree,
TinyLog,
)
from clickhouse_orm.fields import DateField, Int8Field, UInt8Field, UInt16Field, UInt32Field
from clickhouse_orm.funcs import F
from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
from clickhouse_orm.system_models import SystemPart
logging.getLogger("requests").setLevel(logging.WARNING)
class _EnginesHelperTestCase(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
def tearDown(self):
self.database.drop_database()
@ -19,32 +32,57 @@ class _EnginesHelperTestCase(unittest.TestCase):
class EnginesTestCase(_EnginesHelperTestCase):
def _create_and_insert(self, model_class):
self.database.create_table(model_class)
self.database.insert([
model_class(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1)
])
self.database.insert(
[
model_class(
date="2017-01-01",
event_id=23423,
event_group=13,
event_count=7,
event_version=1,
)
]
)
def test_merge_tree(self):
class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group'))
engine = MergeTree("date", ("date", "event_id", "event_group"))
self._create_and_insert(TestModel)
def test_merge_tree_with_sampling(self):
class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group', 'intHash32(event_id)'), sampling_expr='intHash32(event_id)')
engine = MergeTree(
"date",
("date", "event_id", "event_group", "intHash32(event_id)"),
sampling_expr="intHash32(event_id)",
)
self._create_and_insert(TestModel)
def test_merge_tree_with_sampling__funcs(self):
class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group', F.intHash32(SampleModel.event_id)), sampling_expr=F.intHash32(SampleModel.event_id))
engine = MergeTree(
"date",
("date", "event_id", "event_group", F.intHash32(SampleModel.event_id)),
sampling_expr=F.intHash32(SampleModel.event_id),
)
self._create_and_insert(TestModel)
def test_merge_tree_with_granularity(self):
class TestModel(SampleModel):
engine = MergeTree('date', ('date', 'event_id', 'event_group'), index_granularity=4096)
engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)
self._create_and_insert(TestModel)
def test_replicated_merge_tree(self):
engine = MergeTree('date', ('date', 'event_id', 'event_group'), replica_table_path='/clickhouse/tables/{layer}-{shard}/hits', replica_name='{replica}')
engine = MergeTree(
"date",
("date", "event_id", "event_group"),
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
replica_name="{replica}",
)
# In ClickHouse 1.1.54310 custom partitioning key was introduced and new syntax is used
if self.database.server_version >= (1, 1, 54310):
expected = "ReplicatedMergeTree('/clickhouse/tables/{layer}-{shard}/hits', '{replica}') PARTITION BY (toYYYYMM(`date`)) ORDER BY (date, event_id, event_group) SETTINGS index_granularity=8192"
@ -54,38 +92,48 @@ class EnginesTestCase(_EnginesHelperTestCase):
def test_replicated_merge_tree_incomplete(self):
with self.assertRaises(AssertionError):
MergeTree('date', ('date', 'event_id', 'event_group'), replica_table_path='/clickhouse/tables/{layer}-{shard}/hits')
MergeTree(
"date",
("date", "event_id", "event_group"),
replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
)
with self.assertRaises(AssertionError):
MergeTree('date', ('date', 'event_id', 'event_group'), replica_name='{replica}')
MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")
def test_collapsing_merge_tree(self):
class TestModel(SampleModel):
engine = CollapsingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_version')
engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")
self._create_and_insert(TestModel)
def test_summing_merge_tree(self):
class TestModel(SampleModel):
engine = SummingMergeTree('date', ('date', 'event_group'), ('event_count',))
engine = SummingMergeTree("date", ("date", "event_group"), ("event_count",))
self._create_and_insert(TestModel)
def test_replacing_merge_tree(self):
class TestModel(SampleModel):
engine = ReplacingMergeTree('date', ('date', 'event_id', 'event_group'), 'event_uversion')
engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")
self._create_and_insert(TestModel)
def test_tiny_log(self):
class TestModel(SampleModel):
engine = TinyLog()
self._create_and_insert(TestModel)
def test_log(self):
class TestModel(SampleModel):
engine = Log()
self._create_and_insert(TestModel)
def test_memory(self):
class TestModel(SampleModel):
engine = Memory()
self._create_and_insert(TestModel)
def test_merge(self):
@ -96,7 +144,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
engine = TinyLog()
class TestMergeModel(MergeModel, SampleModel):
engine = Merge('^testmodel')
engine = Merge("^testmodel")
self.database.create_table(TestModel1)
self.database.create_table(TestModel2)
@ -104,54 +152,87 @@ class EnginesTestCase(_EnginesHelperTestCase):
# Insert operations are restricted for this model type
with self.assertRaises(DatabaseException):
self.database.insert([
TestMergeModel(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1)
])
self.database.insert(
[
TestMergeModel(
date="2017-01-01",
event_id=23423,
event_group=13,
event_count=7,
event_version=1,
)
]
)
# Testing select
self.database.insert([
TestModel1(date='2017-01-01', event_id=1, event_group=1, event_count=1, event_version=1)
])
self.database.insert([
TestModel2(date='2017-01-02', event_id=2, event_group=2, event_count=2, event_version=2)
])
self.database.insert(
[
TestModel1(
date="2017-01-01",
event_id=1,
event_group=1,
event_count=1,
event_version=1,
)
]
)
self.database.insert(
[
TestModel2(
date="2017-01-02",
event_id=2,
event_group=2,
event_count=2,
event_version=2,
)
]
)
# event_uversion is materialized field. So * won't select it and it will be zero
res = self.database.select('SELECT *, _table, event_uversion FROM $table ORDER BY event_id', model_class=TestMergeModel)
res = self.database.select(
"SELECT *, _table, event_uversion FROM $table ORDER BY event_id",
model_class=TestMergeModel,
)
res = list(res)
self.assertEqual(2, len(res))
self.assertDictEqual({
'_table': 'testmodel1',
'date': datetime.date(2017, 1, 1),
'event_id': 1,
'event_group': 1,
'event_count': 1,
'event_version': 1,
'event_uversion': 1
}, res[0].to_dict(include_readonly=True))
self.assertDictEqual({
'_table': 'testmodel2',
'date': datetime.date(2017, 1, 2),
'event_id': 2,
'event_group': 2,
'event_count': 2,
'event_version': 2,
'event_uversion': 2
}, res[1].to_dict(include_readonly=True))
self.assertDictEqual(
{
"_table": "testmodel1",
"date": datetime.date(2017, 1, 1),
"event_id": 1,
"event_group": 1,
"event_count": 1,
"event_version": 1,
"event_uversion": 1,
},
res[0].to_dict(include_readonly=True),
)
self.assertDictEqual(
{
"_table": "testmodel2",
"date": datetime.date(2017, 1, 2),
"event_id": 2,
"event_group": 2,
"event_count": 2,
"event_version": 2,
"event_uversion": 2,
},
res[1].to_dict(include_readonly=True),
)
def test_custom_partitioning(self):
class TestModel(SampleModel):
engine = MergeTree(
order_by=('date', 'event_id', 'event_group'),
partition_key=('toYYYYMM(date)', 'event_group')
order_by=("date", "event_id", "event_group"),
partition_key=("toYYYYMM(date)", "event_group"),
)
class TestCollapseModel(SampleModel):
sign = Int8Field()
sign = Int8Field(default=-1)
engine = CollapsingMergeTree(
sign_col='sign',
order_by=('date', 'event_id', 'event_group'),
partition_key=('toYYYYMM(date)', 'event_group')
sign_col="sign",
order_by=("date", "event_id", "event_group"),
partition_key=("toYYYYMM(date)", "event_group"),
)
self._create_and_insert(TestModel)
@ -161,30 +242,30 @@ class EnginesTestCase(_EnginesHelperTestCase):
parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
self.assertEqual(2, len(parts))
self.assertEqual('testcollapsemodel', parts[0].table)
self.assertEqual('(201701, 13)'.replace(' ', ''), parts[0].partition.replace(' ', ''))
self.assertEqual('testmodel', parts[1].table)
self.assertEqual('(201701, 13)'.replace(' ', ''), parts[1].partition.replace(' ', ''))
self.assertEqual("testcollapsemodel", parts[0].table)
self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
self.assertEqual("testmodel", parts[1].table)
self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))
def test_custom_primary_key(self):
if self.database.server_version < (18, 1):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
class TestModel(SampleModel):
engine = MergeTree(
order_by=('date', 'event_id', 'event_group'),
partition_key=('toYYYYMM(date)',),
primary_key=('date', 'event_id')
order_by=("date", "event_id", "event_group"),
partition_key=("toYYYYMM(date)",),
primary_key=("date", "event_id"),
)
class TestCollapseModel(SampleModel):
sign = Int8Field()
sign = Int8Field(default=1)
engine = CollapsingMergeTree(
sign_col='sign',
order_by=('date', 'event_id', 'event_group'),
partition_key=('toYYYYMM(date)',),
primary_key=('date', 'event_id')
sign_col="sign",
order_by=("date", "event_id", "event_group"),
partition_key=("toYYYYMM(date)",),
primary_key=("date", "event_id"),
)
self._create_and_insert(TestModel)
@ -200,23 +281,23 @@ class SampleModel(Model):
event_group = UInt32Field()
event_count = UInt16Field()
event_version = Int8Field()
event_uversion = UInt8Field(materialized='abs(event_version)')
event_uversion = UInt8Field(materialized="abs(event_version)")
class DistributedTestCase(_EnginesHelperTestCase):
def test_without_table_name(self):
engine = Distributed('my_cluster')
engine = Distributed("my_cluster")
with self.assertRaises(ValueError) as cm:
engine.create_table_sql(self.database)
exc = cm.exception
self.assertEqual(str(exc), 'Cannot create Distributed engine: specify an underlying table')
self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")
def test_with_table_name(self):
engine = Distributed('my_cluster', 'foo')
engine = Distributed("my_cluster", "foo")
sql = engine.create_table_sql(self.database)
self.assertEqual(sql, 'Distributed(`my_cluster`, `test-db`, `foo`)')
self.assertEqual(sql, "Distributed(`my_cluster`, `test-db`, `foo`)")
class TestModel(SampleModel):
engine = TinyLog()
@ -231,7 +312,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
def test_bad_cluster_name(self):
with self.assertRaises(ServerError) as cm:
d_model = self._create_distributed('cluster_name')
d_model = self._create_distributed("cluster_name")
self.database.count(d_model)
exc = cm.exception
@ -243,7 +324,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
engine = Log()
class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
engine = Distributed('test_shard_localhost', self.TestModel)
engine = Distributed("test_shard_localhost", self.TestModel)
self.database.create_table(self.TestModel)
self.database.create_table(TestDistributedModel)
@ -251,7 +332,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
def test_minimal_engine(self):
class TestDistributedModel(DistributedModel, self.TestModel):
engine = Distributed('test_shard_localhost')
engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel)
self.database.create_table(TestDistributedModel)
@ -263,64 +344,89 @@ class DistributedTestCase(_EnginesHelperTestCase):
engine = Log()
class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
engine = Distributed('test_shard_localhost')
engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel)
with self.assertRaises(TypeError) as cm:
self.database.create_table(TestDistributedModel)
exc = cm.exception
self.assertEqual(str(exc), 'When defining Distributed engine without the table_name ensure '
'that your model has exactly one non-distributed superclass')
self.assertEqual(
str(exc),
"When defining Distributed engine without the table_name ensure "
"that your model has exactly one non-distributed superclass",
)
def test_minimal_engine_no_superclasses(self):
class TestDistributedModel(DistributedModel):
engine = Distributed('test_shard_localhost')
engine = Distributed("test_shard_localhost")
self.database.create_table(self.TestModel)
with self.assertRaises(TypeError) as cm:
self.database.create_table(TestDistributedModel)
exc = cm.exception
self.assertEqual(str(exc), 'When defining Distributed engine without the table_name ensure '
'that your model has a parent model')
self.assertEqual(
str(exc),
"When defining Distributed engine without the table_name ensure " "that your model has a parent model",
)
def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
d_model = self._create_distributed('test_shard_localhost', underlying=test_model)
d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
if local_to_distributed:
to_insert, to_select = test_model, d_model
else:
to_insert, to_select = d_model, test_model
self.database.insert([
to_insert(date='2017-01-01', event_id=1, event_group=1, event_count=1, event_version=1),
to_insert(date='2017-01-02', event_id=2, event_group=2, event_count=2, event_version=2)
])
self.database.insert(
[
to_insert(
date="2017-01-01",
event_id=1,
event_group=1,
event_count=1,
event_version=1,
),
to_insert(
date="2017-01-02",
event_id=2,
event_group=2,
event_count=2,
event_version=2,
),
]
)
# event_uversion is materialized field. So * won't select it and it will be zero
res = self.database.select('SELECT *, event_uversion FROM $table ORDER BY event_id',
model_class=to_select)
res = self.database.select(
"SELECT *, event_uversion FROM $table ORDER BY event_id",
model_class=to_select,
)
res = [row for row in res]
self.assertEqual(2, len(res))
self.assertDictEqual({
'date': datetime.date(2017, 1, 1),
'event_id': 1,
'event_group': 1,
'event_count': 1,
'event_version': 1,
'event_uversion': 1
}, res[0].to_dict(include_readonly=include_readonly))
self.assertDictEqual({
'date': datetime.date(2017, 1, 2),
'event_id': 2,
'event_group': 2,
'event_count': 2,
'event_version': 2,
'event_uversion': 2
}, res[1].to_dict(include_readonly=include_readonly))
self.assertDictEqual(
{
"date": datetime.date(2017, 1, 1),
"event_id": 1,
"event_group": 1,
"event_count": 1,
"event_version": 1,
"event_uversion": 1,
},
res[0].to_dict(include_readonly=include_readonly),
)
self.assertDictEqual(
{
"date": datetime.date(2017, 1, 2),
"event_id": 2,
"event_group": 2,
"event_count": 2,
"event_version": 2,
"event_uversion": 2,
},
res[1].to_dict(include_readonly=include_readonly),
)
@unittest.skip("Bad support of materialized fields in Distributed tables "
"https://groups.google.com/forum/#!topic/clickhouse/XEYRRwZrsSc")
def test_insert_distributed_select_local(self):
return self._test_insert_select(local_to_distributed=False)

View File

@ -1,57 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
class FixedStringFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database.create_table(FixedStringModel)
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert([
FixedStringModel(date_field='2016-08-30', fstr_field=''),
FixedStringModel(date_field='2016-08-30'),
FixedStringModel(date_field='2016-08-31', fstr_field='foo'),
FixedStringModel(date_field='2016-08-31', fstr_field=u'לילה')
])
def _assert_sample_data(self, results):
self.assertEqual(len(results), 4)
self.assertEqual(results[0].fstr_field, '')
self.assertEqual(results[1].fstr_field, 'ABCDEFGHIJK')
self.assertEqual(results[2].fstr_field, 'foo')
self.assertEqual(results[3].fstr_field, u'לילה')
def test_insert_and_select(self):
self._insert_sample_data()
query = 'SELECT * from $table ORDER BY date_field'
results = list(self.database.select(query, FixedStringModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = 'SELECT * from $db.fixedstringmodel ORDER BY date_field'
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_assignment_error(self):
for value in (17, 'this is too long', u'זה ארוך', None, 99.9):
with self.assertRaises(ValueError):
FixedStringModel(fstr_field=value)
class FixedStringModel(Model):
date_field = DateField()
fstr_field = FixedStringField(12, default='ABCDEFGHIJK')
engine = MergeTree('date_field', ('date_field',))

View File

@ -1,19 +1,21 @@
import unittest
from .base_test_with_data import *
from .test_querysets import SampleModel
from datetime import date, datetime, tzinfo, timedelta
import pytz
from ipaddress import IPv4Address, IPv6Address
import logging
import unittest
from datetime import date, datetime, timedelta
from decimal import Decimal
from ipaddress import IPv4Address, IPv6Address
from infi.clickhouse_orm.database import ServerError
from infi.clickhouse_orm.utils import NO_VALUE
from infi.clickhouse_orm.funcs import F
import pytz
from clickhouse_orm.database import ServerError
from clickhouse_orm.fields import DateTimeField
from clickhouse_orm.funcs import F
from clickhouse_orm.utils import NO_VALUE
from .base_test_with_data import Person, TestCaseWithData
from .test_querysets import SampleModel
class FuncsTestCase(TestCaseWithData):
def setUp(self):
super(FuncsTestCase, self).setUp()
self.database.insert(self._sample_data())
@ -23,70 +25,75 @@ class FuncsTestCase(TestCaseWithData):
count = 0
for instance in qs:
count += 1
logging.info('\t[%d]\t%s' % (count, instance.to_dict()))
logging.info("\t[%d]\t%s" % (count, instance.to_dict()))
self.assertEqual(count, expected_count)
self.assertEqual(qs.count(), expected_count)
def _test_func(self, func, expected_value=NO_VALUE):
sql = 'SELECT %s AS value' % func.to_sql()
def _call_func(self, func):
sql = "SELECT %s AS value" % func.to_sql()
logging.info(sql)
try:
result = list(self.database.select(sql))
logging.info('\t==> %s', result[0].value if result else '<empty>')
if expected_value != NO_VALUE:
print('Comparing %s to %s' % (result[0].value, expected_value))
self.assertEqual(result[0].value, expected_value)
logging.info("\t==> %s", result[0].value if result else "<empty>")
return result[0].value if result else None
except ServerError as e:
if 'Unknown function' in e.message:
logging.warning(e.message)
if "Unknown function" in str(e):
logging.warning(str(e))
return # ignore functions that don't exist in the used ClickHouse version
raise
def _test_func(self, func, expected_value=NO_VALUE):
result = self._call_func(func)
if expected_value != NO_VALUE:
print("Comparing %s to %s" % (result, expected_value))
self.assertEqual(result, expected_value)
return result if result else None
def _test_aggr(self, func, expected_value=NO_VALUE):
qs = Person.objects_in(self.database).aggregate(value=func)
logging.info(qs.as_sql())
try:
result = list(qs)
logging.info('\t==> %s', result[0].value if result else '<empty>')
logging.info("\t==> %s", result[0].value if result else "<empty>")
if expected_value != NO_VALUE:
self.assertEqual(result[0].value, expected_value)
return result[0].value if result else None
except ServerError as e:
if 'Unknown function' in e.message:
logging.warning(e.message)
if "Unknown function" in str(e):
logging.warning(str(e))
return # ignore functions that don't exist in the used ClickHouse version
raise
def test_func_to_sql(self):
# No args
self.assertEqual(F('func').to_sql(), 'func()')
self.assertEqual(F("func").to_sql(), "func()")
# String args
self.assertEqual(F('func', "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
self.assertEqual(F("func", "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
# Numeric args
self.assertEqual(F('func', 1, 1.1, Decimal('3.3')).to_sql(), "func(1, 1.1, 3.3)")
self.assertEqual(F("func", 1, 1.1, Decimal("3.3")).to_sql(), "func(1, 1.1, 3.3)")
# Date args
self.assertEqual(F('func', date(2018, 12, 31)).to_sql(), "func(toDate('2018-12-31'))")
self.assertEqual(F("func", date(2018, 12, 31)).to_sql(), "func(toDate('2018-12-31'))")
# Datetime args
self.assertEqual(F('func', datetime(2018, 12, 31)).to_sql(), "func(toDateTime('1546214400'))")
self.assertEqual(F("func", datetime(2018, 12, 31)).to_sql(), "func(toDateTime('1546214400'))")
# Boolean args
self.assertEqual(F('func', True, False).to_sql(), "func(1, 0)")
self.assertEqual(F("func", True, False).to_sql(), "func(1, 0)")
# Timezone args
self.assertEqual(F('func', pytz.utc).to_sql(), "func('UTC')")
self.assertEqual(F('func', pytz.timezone('Europe/Athens')).to_sql(), "func('Europe/Athens')")
self.assertEqual(F("func", pytz.utc).to_sql(), "func('UTC')")
self.assertEqual(F("func", pytz.timezone("Europe/Athens")).to_sql(), "func('Europe/Athens')")
# Null args
self.assertEqual(F('func', None).to_sql(), "func(NULL)")
self.assertEqual(F("func", None).to_sql(), "func(NULL)")
# Fields as args
self.assertEqual(F('func', SampleModel.color).to_sql(), "func(`color`)")
self.assertEqual(F("func", SampleModel.color).to_sql(), "func(`color`)")
# Funcs as args
self.assertEqual(F('func', F('sqrt', 25)).to_sql(), 'func(sqrt(25))')
self.assertEqual(F("func", F("sqrt", 25)).to_sql(), "func(sqrt(25))")
# Iterables as args
x = [1, 'z', F('foo', 17)]
x = [1, "z", F("foo", 17)]
for y in [x, iter(x)]:
self.assertEqual(F('func', y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
self.assertEqual(F("func", y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
# Tuples as args
self.assertEqual(F('func', [(1, 2), (3, 4)]).to_sql(), "func([(1, 2), (3, 4)])")
self.assertEqual(F('func', tuple(x), 5).to_sql(), "func((1, 'z', foo(17)), 5)")
self.assertEqual(F("func", [(1, 2), (3, 4)]).to_sql(), "func([(1, 2), (3, 4)])")
self.assertEqual(F("func", tuple(x), 5).to_sql(), "func((1, 'z', foo(17)), 5)")
# Binary operator functions
self.assertEqual(F.plus(1, 2).to_sql(), "(1 + 2)")
self.assertEqual(F.lessOrEquals(1, 2).to_sql(), "(1 <= 2)")
@ -106,32 +113,32 @@ class FuncsTestCase(TestCaseWithData):
def test_filter_date_field(self):
qs = Person.objects_in(self.database)
# People born on the 30th
self._test_qs(qs.filter(F('equals', F('toDayOfMonth', Person.birthday), 30)), 3)
self._test_qs(qs.filter(F('toDayOfMonth', Person.birthday) == 30), 3)
self._test_qs(qs.filter(F("equals", F("toDayOfMonth", Person.birthday), 30)), 3)
self._test_qs(qs.filter(F("toDayOfMonth", Person.birthday) == 30), 3)
self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3)
# People born on Sunday
self._test_qs(qs.filter(F('equals', F('toDayOfWeek', Person.birthday), 7)), 18)
self._test_qs(qs.filter(F('toDayOfWeek', Person.birthday) == 7), 18)
self._test_qs(qs.filter(F("equals", F("toDayOfWeek", Person.birthday), 7)), 18)
self._test_qs(qs.filter(F("toDayOfWeek", Person.birthday) == 7), 18)
self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18)
# People born on 1976-10-01
self._test_qs(qs.filter(F('equals', Person.birthday, '1976-10-01')), 1)
self._test_qs(qs.filter(F('equals', Person.birthday, date(1976, 10, 1))), 1)
self._test_qs(qs.filter(F("equals", Person.birthday, "1976-10-01")), 1)
self._test_qs(qs.filter(F("equals", Person.birthday, date(1976, 10, 1))), 1)
self._test_qs(qs.filter(Person.birthday == date(1976, 10, 1)), 1)
def test_func_as_field_value(self):
qs = Person.objects_in(self.database)
self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96)
self._test_qs(qs.exclude(birthday=F.today()), 100)
self._test_qs(qs.filter(birthday__between=['1970-01-01', F.today()]), 100)
self._test_qs(qs.filter(birthday__between=["1970-01-01", F.today()]), 100)
def test_in_and_not_in(self):
qs = Person.objects_in(self.database)
self._test_qs(qs.filter(Person.first_name.isIn(['Ciaran', 'Elton'])), 4)
self._test_qs(qs.filter(~Person.first_name.isIn(['Ciaran', 'Elton'])), 96)
self._test_qs(qs.filter(Person.first_name.isNotIn(['Ciaran', 'Elton'])), 96)
self._test_qs(qs.exclude(Person.first_name.isIn(['Ciaran', 'Elton'])), 96)
self._test_qs(qs.filter(Person.first_name.isIn(["Ciaran", "Elton"])), 4)
self._test_qs(qs.filter(~Person.first_name.isIn(["Ciaran", "Elton"])), 96)
self._test_qs(qs.filter(Person.first_name.isNotIn(["Ciaran", "Elton"])), 96)
self._test_qs(qs.exclude(Person.first_name.isIn(["Ciaran", "Elton"])), 96)
# In subquery
subquery = qs.filter(F.startsWith(Person.last_name, 'M')).only(Person.first_name)
subquery = qs.filter(F.startsWith(Person.last_name, "M")).only(Person.first_name)
self._test_qs(qs.filter(Person.first_name.isIn(subquery)), 4)
def test_comparison_operators(self):
@ -213,7 +220,7 @@ class FuncsTestCase(TestCaseWithData):
dt = datetime(2018, 12, 31, 11, 22, 33)
self._test_func(F.toYear(d), 2018)
self._test_func(F.toYear(dt), 2018)
self._test_func(F.toISOYear(dt, 'Europe/Athens'), 2019) # 2018-12-31 is ISO year 2019, week 1, day 1
self._test_func(F.toISOYear(dt, "Europe/Athens"), 2019) # 2018-12-31 is ISO year 2019, week 1, day 1
self._test_func(F.toQuarter(d), 4)
self._test_func(F.toQuarter(dt), 4)
self._test_func(F.toMonth(d), 12)
@ -239,189 +246,256 @@ class FuncsTestCase(TestCaseWithData):
self._test_func(F.toStartOfYear(d), date(2018, 1, 1))
self._test_func(F.toStartOfYear(dt), date(2018, 1, 1))
self._test_func(F.toStartOfMinute(dt), datetime(2018, 12, 31, 11, 22, 0, tzinfo=pytz.utc))
self._test_func(F.toStartOfFiveMinute(dt), datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc))
self._test_func(F.toStartOfFifteenMinutes(dt), datetime(2018, 12, 31, 11, 15, 0, tzinfo=pytz.utc))
self._test_func(
F.toStartOfFiveMinute(dt),
datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc),
)
self._test_func(
F.toStartOfFifteenMinutes(dt),
datetime(2018, 12, 31, 11, 15, 0, tzinfo=pytz.utc),
)
self._test_func(F.toStartOfHour(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc))
self._test_func(F.toStartOfISOYear(dt), date(2018, 12, 31))
self._test_func(F.toStartOfTenMinutes(dt), datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc))
self._test_func(
F.toStartOfTenMinutes(dt),
datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc),
)
self._test_func(F.toStartOfWeek(dt), date(2018, 12, 30))
self._test_func(F.toTime(dt), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
self._test_func(F.toUnixTimestamp(dt, 'UTC'), int(dt.replace(tzinfo=pytz.utc).timestamp()))
self._test_func(F.toUnixTimestamp(dt, "UTC"), int(dt.replace(tzinfo=pytz.utc).timestamp()))
self._test_func(F.toYYYYMM(d), 201812)
self._test_func(F.toYYYYMM(dt), 201812)
self._test_func(F.toYYYYMM(dt, 'Europe/Athens'), 201812)
self._test_func(F.toYYYYMM(dt, "Europe/Athens"), 201812)
self._test_func(F.toYYYYMMDD(d), 20181231)
self._test_func(F.toYYYYMMDD(dt), 20181231)
self._test_func(F.toYYYYMMDD(dt, 'Europe/Athens'), 20181231)
self._test_func(F.toYYYYMMDD(dt, "Europe/Athens"), 20181231)
self._test_func(F.toYYYYMMDDhhmmss(d), 20181231000000)
self._test_func(F.toYYYYMMDDhhmmss(dt, 'Europe/Athens'), 20181231132233)
self._test_func(F.toYYYYMMDDhhmmss(dt, "Europe/Athens"), 20181231132233)
self._test_func(F.toRelativeYearNum(dt), 2018)
self._test_func(F.toRelativeYearNum(dt, 'Europe/Athens'), 2018)
self._test_func(F.toRelativeYearNum(dt, "Europe/Athens"), 2018)
self._test_func(F.toRelativeMonthNum(dt), 2018 * 12 + 12)
self._test_func(F.toRelativeMonthNum(dt, 'Europe/Athens'), 2018 * 12 + 12)
self._test_func(F.toRelativeMonthNum(dt, "Europe/Athens"), 2018 * 12 + 12)
self._test_func(F.toRelativeWeekNum(dt), 2557)
self._test_func(F.toRelativeWeekNum(dt, 'Europe/Athens'), 2557)
self._test_func(F.toRelativeWeekNum(dt, "Europe/Athens"), 2557)
self._test_func(F.toRelativeDayNum(dt), 17896)
self._test_func(F.toRelativeDayNum(dt, 'Europe/Athens'), 17896)
self._test_func(F.toRelativeDayNum(dt, "Europe/Athens"), 17896)
self._test_func(F.toRelativeHourNum(dt), 429515)
self._test_func(F.toRelativeHourNum(dt, 'Europe/Athens'), 429515)
self._test_func(F.toRelativeHourNum(dt, "Europe/Athens"), 429515)
self._test_func(F.toRelativeMinuteNum(dt), 25770922)
self._test_func(F.toRelativeMinuteNum(dt, 'Europe/Athens'), 25770922)
self._test_func(F.toRelativeMinuteNum(dt, "Europe/Athens"), 25770922)
self._test_func(F.toRelativeSecondNum(dt), 1546255353)
self._test_func(F.toRelativeSecondNum(dt, 'Europe/Athens'), 1546255353)
self._test_func(F.toRelativeSecondNum(dt, "Europe/Athens"), 1546255353)
self._test_func(F.timeSlot(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc))
self._test_func(F.timeSlots(dt, 300), [datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc)])
self._test_func(F.formatDateTime(dt, '%D %T', 'Europe/Athens'), '12/31/18 13:22:33')
self._test_func(F.formatDateTime(dt, "%D %T", "Europe/Athens"), "12/31/18 13:22:33")
self._test_func(F.addDays(d, 7), date(2019, 1, 7))
self._test_func(F.addDays(dt, 7, 'Europe/Athens'))
self._test_func(F.addHours(dt, 7, 'Europe/Athens'))
self._test_func(F.addMinutes(dt, 7, 'Europe/Athens'))
self._test_func(F.addDays(dt, 7, "Europe/Athens"))
self._test_func(F.addHours(dt, 7, "Europe/Athens"))
self._test_func(F.addMinutes(dt, 7, "Europe/Athens"))
self._test_func(F.addMonths(d, 7), date(2019, 7, 31))
self._test_func(F.addMonths(dt, 7, 'Europe/Athens'))
self._test_func(F.addMonths(dt, 7, "Europe/Athens"))
self._test_func(F.addQuarters(d, 7))
self._test_func(F.addQuarters(dt, 7, 'Europe/Athens'))
self._test_func(F.addQuarters(dt, 7, "Europe/Athens"))
self._test_func(F.addSeconds(d, 7))
self._test_func(F.addSeconds(dt, 7, 'Europe/Athens'))
self._test_func(F.addSeconds(dt, 7, "Europe/Athens"))
self._test_func(F.addWeeks(d, 7))
self._test_func(F.addWeeks(dt, 7, 'Europe/Athens'))
self._test_func(F.addWeeks(dt, 7, "Europe/Athens"))
self._test_func(F.addYears(d, 7))
self._test_func(F.addYears(dt, 7, 'Europe/Athens'))
self._test_func(F.addYears(dt, 7, "Europe/Athens"))
self._test_func(F.subtractDays(d, 3))
self._test_func(F.subtractDays(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractDays(dt, 3, "Europe/Athens"))
self._test_func(F.subtractHours(d, 3))
self._test_func(F.subtractHours(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractHours(dt, 3, "Europe/Athens"))
self._test_func(F.subtractMinutes(d, 3))
self._test_func(F.subtractMinutes(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractMinutes(dt, 3, "Europe/Athens"))
self._test_func(F.subtractMonths(d, 3))
self._test_func(F.subtractMonths(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractMonths(dt, 3, "Europe/Athens"))
self._test_func(F.subtractQuarters(d, 3))
self._test_func(F.subtractQuarters(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractQuarters(dt, 3, "Europe/Athens"))
self._test_func(F.subtractSeconds(d, 3))
self._test_func(F.subtractSeconds(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractSeconds(dt, 3, "Europe/Athens"))
self._test_func(F.subtractWeeks(d, 3))
self._test_func(F.subtractWeeks(dt, 3, 'Europe/Athens'))
self._test_func(F.subtractWeeks(dt, 3, "Europe/Athens"))
self._test_func(F.subtractYears(d, 3))
self._test_func(F.subtractYears(dt, 3, 'Europe/Athens'))
self._test_func(F.now() + F.toIntervalSecond(3) + F.toIntervalMinute(3) + F.toIntervalHour(3) + F.toIntervalDay(3))
self._test_func(F.now() + F.toIntervalWeek(3) + F.toIntervalMonth(3) + F.toIntervalQuarter(3) + F.toIntervalYear(3))
self._test_func(F.now() + F.toIntervalSecond(3000) - F.toIntervalDay(3000) == F.now() + timedelta(seconds=3000, days=-3000))
self._test_func(F.subtractYears(dt, 3, "Europe/Athens"))
self._test_func(
F.now() + F.toIntervalSecond(3) + F.toIntervalMinute(3) + F.toIntervalHour(3) + F.toIntervalDay(3)
)
self._test_func(
F.now() + F.toIntervalWeek(3) + F.toIntervalMonth(3) + F.toIntervalQuarter(3) + F.toIntervalYear(3)
)
self._test_func(
F.now() + F.toIntervalSecond(3000) - F.toIntervalDay(3000) == F.now() + timedelta(seconds=3000, days=-3000)
)
def test_date_functions__utc_only(self):
def test_date_functions_utc_only(self):
if self.database.server_timezone != pytz.utc:
raise unittest.SkipTest('This test must run with UTC as the server timezone')
raise unittest.SkipTest("This test must run with UTC as the server timezone")
d = date(2018, 12, 31)
dt = datetime(2018, 12, 31, 11, 22, 33)
athens_tz = pytz.timezone('Europe/Athens')
athens_tz = pytz.timezone("Europe/Athens")
self._test_func(F.toHour(dt), 11)
self._test_func(F.toStartOfDay(dt), datetime(2018, 12, 31, 0, 0, 0, tzinfo=pytz.utc))
self._test_func(F.toTime(dt, pytz.utc), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
self._test_func(F.toTime(dt, 'Europe/Athens'), athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)))
self._test_func(F.toTime(dt, athens_tz), athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)))
self._test_func(F.toTimeZone(dt, 'Europe/Athens'), athens_tz.localize(datetime(2018, 12, 31, 13, 22, 33)))
self._test_func(F.now(), datetime.utcnow().replace(tzinfo=pytz.utc, microsecond=0)) # FIXME this may fail if the timing is just right
self._test_func(
F.toTime(dt, "Europe/Athens"),
athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)),
)
self._test_func(
F.toTime(dt, athens_tz),
athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)),
)
self._test_func(
F.toTimeZone(dt, "Europe/Athens"),
athens_tz.localize(datetime(2018, 12, 31, 13, 22, 33)),
)
self._test_func(F.today(), datetime.utcnow().date())
self._test_func(F.yesterday(), datetime.utcnow().date() - timedelta(days=1))
self._test_func(F.toYYYYMMDDhhmmss(dt), 20181231112233)
self._test_func(F.formatDateTime(dt, '%D %T'), '12/31/18 11:22:33')
self._test_func(F.formatDateTime(dt, "%D %T"), "12/31/18 11:22:33")
self._test_func(F.addHours(d, 7), datetime(2018, 12, 31, 7, 0, 0, tzinfo=pytz.utc))
self._test_func(F.addMinutes(d, 7), datetime(2018, 12, 31, 0, 7, 0, tzinfo=pytz.utc))
actual = self._call_func(F.now())
expected = datetime.utcnow().replace(tzinfo=pytz.utc, microsecond=0)
self.assertLess((actual - expected).total_seconds(), 1e-3)
def test_type_conversion_functions(self):
for f in (F.toUInt8, F.toUInt16, F.toUInt32, F.toUInt64, F.toInt8, F.toInt16, F.toInt32, F.toInt64, F.toFloat32, F.toFloat64):
for f in (
F.toUInt8,
F.toUInt16,
F.toUInt32,
F.toUInt64,
F.toInt8,
F.toInt16,
F.toInt32,
F.toInt64,
F.toFloat32,
F.toFloat64,
):
self._test_func(f(17), 17)
self._test_func(f('17'), 17)
for f in (F.toUInt8OrZero, F.toUInt16OrZero, F.toUInt32OrZero, F.toUInt64OrZero, F.toInt8OrZero, F.toInt16OrZero, F.toInt32OrZero, F.toInt64OrZero, F.toFloat32OrZero, F.toFloat64OrZero):
self._test_func(f('17'), 17)
self._test_func(f('a'), 0)
self._test_func(f("17"), 17)
for f in (
F.toUInt8OrZero,
F.toUInt16OrZero,
F.toUInt32OrZero,
F.toUInt64OrZero,
F.toInt8OrZero,
F.toInt16OrZero,
F.toInt32OrZero,
F.toInt64OrZero,
F.toFloat32OrZero,
F.toFloat64OrZero,
):
self._test_func(f("17"), 17)
self._test_func(f("a"), 0)
for f in (F.toDecimal32, F.toDecimal64, F.toDecimal128):
self._test_func(f(17.17, 2), Decimal('17.17'))
self._test_func(f('17.17', 2), Decimal('17.17'))
self._test_func(F.toDate('2018-12-31'), date(2018, 12, 31))
self._test_func(F.toString(123), '123')
self._test_func(F.toFixedString('123', 5), '123')
self._test_func(F.toStringCutToZero('123\0'), '123')
self._test_func(F.CAST(17, 'String'), '17')
self._test_func(F.parseDateTimeBestEffort('31/12/2019 10:05AM', 'Europe/Athens'))
self._test_func(f(17.17, 2), Decimal("17.17"))
self._test_func(f("17.17", 2), Decimal("17.17"))
self._test_func(F.toDate("2018-12-31"), date(2018, 12, 31))
self._test_func(F.toString(123), "123")
self._test_func(F.toFixedString("123", 5), "123")
self._test_func(F.toStringCutToZero("123\0"), "123")
self._test_func(F.CAST(17, "String"), "17")
self._test_func(F.parseDateTimeBestEffort("31/12/2019 10:05AM", "Europe/Athens"))
with self.assertRaises(ServerError):
self._test_func(F.parseDateTimeBestEffort('foo'))
self._test_func(F.parseDateTimeBestEffortOrNull('31/12/2019 10:05AM', 'Europe/Athens'))
self._test_func(F.parseDateTimeBestEffortOrNull('foo'), None)
self._test_func(F.parseDateTimeBestEffortOrZero('31/12/2019 10:05AM', 'Europe/Athens'))
self._test_func(F.parseDateTimeBestEffortOrZero('foo'), DateTimeField.class_default)
self._test_func(F.parseDateTimeBestEffort("foo"))
self._test_func(F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM", "Europe/Athens"))
self._test_func(F.parseDateTimeBestEffortOrNull("foo"), None)
self._test_func(F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM", "Europe/Athens"))
self._test_func(F.parseDateTimeBestEffortOrZero("foo"), DateTimeField.class_default)
def test_type_conversion_functions__utc_only(self):
if self.database.server_timezone != pytz.utc:
raise unittest.SkipTest('This test must run with UTC as the server timezone')
self._test_func(F.toDateTime('2018-12-31 11:22:33'), datetime(2018, 12, 31, 11, 22, 33, tzinfo=pytz.utc))
self._test_func(F.toDateTime64('2018-12-31 11:22:33.001', 6), datetime(2018, 12, 31, 11, 22, 33, 1000, tzinfo=pytz.utc))
self._test_func(F.parseDateTimeBestEffort('31/12/2019 10:05AM'), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc))
self._test_func(F.parseDateTimeBestEffortOrNull('31/12/2019 10:05AM'), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc))
self._test_func(F.parseDateTimeBestEffortOrZero('31/12/2019 10:05AM'), datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc))
raise unittest.SkipTest("This test must run with UTC as the server timezone")
self._test_func(
F.toDateTime("2018-12-31 11:22:33"),
datetime(2018, 12, 31, 11, 22, 33, tzinfo=pytz.utc),
)
self._test_func(
F.toDateTime64("2018-12-31 11:22:33.001", 6),
datetime(2018, 12, 31, 11, 22, 33, 1000, tzinfo=pytz.utc),
)
self._test_func(
F.parseDateTimeBestEffort("31/12/2019 10:05AM"),
datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
)
self._test_func(
F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM"),
datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
)
self._test_func(
F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM"),
datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
)
def test_string_functions(self):
self._test_func(F.empty(''), 1)
self._test_func(F.empty('x'), 0)
self._test_func(F.notEmpty(''), 0)
self._test_func(F.notEmpty('x'), 1)
self._test_func(F.length('x'), 1)
self._test_func(F.lengthUTF8('x'), 1)
self._test_func(F.lower('Ab'), 'ab')
self._test_func(F.upper('Ab'), 'AB')
self._test_func(F.lowerUTF8('Ab'), 'ab')
self._test_func(F.upperUTF8('Ab'), 'AB')
self._test_func(F.reverse('Ab'), 'bA')
self._test_func(F.reverseUTF8('Ab'), 'bA')
self._test_func(F.concat('Ab', 'Cd', 'Ef'), 'AbCdEf')
self._test_func(F.substring('123456', 3, 2), '34')
self._test_func(F.substringUTF8('123456', 3, 2), '34')
self._test_func(F.appendTrailingCharIfAbsent('Hello', '!'), 'Hello!')
self._test_func(F.appendTrailingCharIfAbsent('Hello!', '!'), 'Hello!')
self._test_func(F.convertCharset(F.convertCharset('Hello', 'latin1', 'utf16'), 'utf16', 'latin1'), 'Hello')
self._test_func(F.startsWith('aaa', 'aa'), True)
self._test_func(F.startsWith('aaa', 'bb'), False)
self._test_func(F.endsWith('aaa', 'aa'), True)
self._test_func(F.endsWith('aaa', 'bb'), False)
self._test_func(F.trimLeft(' abc '), 'abc ')
self._test_func(F.trimRight(' abc '), ' abc')
self._test_func(F.trimBoth(' abc '), 'abc')
self._test_func(F.CRC32('whoops'), 3361378926)
self._test_func(F.empty(""), 1)
self._test_func(F.empty("x"), 0)
self._test_func(F.notEmpty(""), 0)
self._test_func(F.notEmpty("x"), 1)
self._test_func(F.length("x"), 1)
self._test_func(F.lengthUTF8("x"), 1)
self._test_func(F.lower("Ab"), "ab")
self._test_func(F.upper("Ab"), "AB")
self._test_func(F.lowerUTF8("Ab"), "ab")
self._test_func(F.upperUTF8("Ab"), "AB")
self._test_func(F.reverse("Ab"), "bA")
self._test_func(F.reverseUTF8("Ab"), "bA")
self._test_func(F.concat("Ab", "Cd", "Ef"), "AbCdEf")
self._test_func(F.substring("123456", 3, 2), "34")
self._test_func(F.substringUTF8("123456", 3, 2), "34")
self._test_func(F.appendTrailingCharIfAbsent("Hello", "!"), "Hello!")
self._test_func(F.appendTrailingCharIfAbsent("Hello!", "!"), "Hello!")
self._test_func(
F.convertCharset(F.convertCharset("Hello", "latin1", "utf16"), "utf16", "latin1"),
"Hello",
)
self._test_func(F.startsWith("aaa", "aa"), True)
self._test_func(F.startsWith("aaa", "bb"), False)
self._test_func(F.endsWith("aaa", "aa"), True)
self._test_func(F.endsWith("aaa", "bb"), False)
self._test_func(F.trimLeft(" abc "), "abc ")
self._test_func(F.trimRight(" abc "), " abc")
self._test_func(F.trimBoth(" abc "), "abc")
self._test_func(F.CRC32("whoops"), 3361378926)
def test_string_search_functions(self):
self._test_func(F.position('Hello, world!', '!'), 13)
self._test_func(F.positionCaseInsensitive('Hello, world!', 'hello'), 1)
self._test_func(F.positionUTF8('Привет, мир!', '!'), 12)
self._test_func(F.positionCaseInsensitiveUTF8('Привет, мир!', 'Мир'), 9)
self._test_func(F.like('Hello, world!', '%ll%'), 1)
self._test_func(F.notLike('Hello, world!', '%ll%'), 0)
self._test_func(F.match('Hello, world!', '[lmnop]{3}'), 1)
self._test_func(F.extract('Hello, world!', '[lmnop]{3}'), 'llo')
self._test_func(F.extractAll('Hello, world!', '[a-z]+'), ['ello', 'world'])
self._test_func(F.ngramDistance('Hello', 'Hello'), 0)
self._test_func(F.ngramDistanceCaseInsensitive('Hello', 'hello'), 0)
self._test_func(F.ngramDistanceUTF8('Hello', 'Hello'), 0)
self._test_func(F.ngramDistanceCaseInsensitiveUTF8('Hello', 'hello'), 0)
self._test_func(F.ngramSearch('Hello', 'Hello'), 1)
self._test_func(F.ngramSearchCaseInsensitive('Hello', 'hello'), 1)
self._test_func(F.ngramSearchUTF8('Hello', 'Hello'), 1)
self._test_func(F.ngramSearchCaseInsensitiveUTF8('Hello', 'hello'), 1)
self._test_func(F.position("Hello, world!", "!"), 13)
self._test_func(F.positionCaseInsensitive("Hello, world!", "hello"), 1)
self._test_func(F.positionUTF8("Привет, мир!", "!"), 12)
self._test_func(F.positionCaseInsensitiveUTF8("Привет, мир!", "Мир"), 9)
self._test_func(F.like("Hello, world!", "%ll%"), 1)
self._test_func(F.notLike("Hello, world!", "%ll%"), 0)
self._test_func(F.match("Hello, world!", "[lmnop]{3}"), 1)
self._test_func(F.extract("Hello, world!", "[lmnop]{3}"), "llo")
self._test_func(F.extractAll("Hello, world!", "[a-z]+"), ["ello", "world"])
self._test_func(F.ngramDistance("Hello", "Hello"), 0)
self._test_func(F.ngramDistanceCaseInsensitive("Hello", "hello"), 0)
self._test_func(F.ngramDistanceUTF8("Hello", "Hello"), 0)
self._test_func(F.ngramDistanceCaseInsensitiveUTF8("Hello", "hello"), 0)
self._test_func(F.ngramSearch("Hello", "Hello"), 1)
self._test_func(F.ngramSearchCaseInsensitive("Hello", "hello"), 1)
self._test_func(F.ngramSearchUTF8("Hello", "Hello"), 1)
self._test_func(F.ngramSearchCaseInsensitiveUTF8("Hello", "hello"), 1)
def test_base64_functions(self):
try:
self._test_func(F.base64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(F.base64Encode('Hello')), 'Hello')
self._test_func(F.tryBase64Decode(':-)'))
self._test_func(F.base64Decode(F.base64Encode("Hello")), "Hello")
self._test_func(F.tryBase64Decode(F.base64Encode("Hello")), "Hello")
self._test_func(F.tryBase64Decode(":-)"))
except ServerError as e:
# ClickHouse version that doesn't support these functions
raise unittest.SkipTest(e.message)
raise unittest.SkipTest(str(e))
def test_replace_functions(self):
haystack = 'hello'
self._test_func(F.replace(haystack, 'l', 'L'), 'heLLo')
self._test_func(F.replaceAll(haystack, 'l', 'L'), 'heLLo')
self._test_func(F.replaceOne(haystack, 'l', 'L'), 'heLlo')
self._test_func(F.replaceRegexpAll(haystack, '[eo]', 'X'), 'hXllX')
self._test_func(F.replaceRegexpOne(haystack, '[eo]', 'X'), 'hXllo')
self._test_func(F.regexpQuoteMeta('[eo]'), '\\[eo\\]')
haystack = "hello"
self._test_func(F.replace(haystack, "l", "L"), "heLLo")
self._test_func(F.replaceAll(haystack, "l", "L"), "heLLo")
self._test_func(F.replaceOne(haystack, "l", "L"), "heLlo")
self._test_func(F.replaceRegexpAll(haystack, "[eo]", "X"), "hXllX")
self._test_func(F.replaceRegexpOne(haystack, "[eo]", "X"), "hXllo")
self._test_func(F.regexpQuoteMeta("[eo]"), "\\[eo\\]")
def test_math_functions(self):
x = 17
@ -515,15 +589,15 @@ class FuncsTestCase(TestCaseWithData):
self._test_func(F.arrayDifference(arr), [0, 1, 1])
self._test_func(F.arrayDistinct(arr + arr), arr)
self._test_func(F.arrayIntersect(arr, [3, 4]), [3])
self._test_func(F.arrayReduce('min', arr), 1)
self._test_func(F.arrayReduce("min", arr), 1)
self._test_func(F.arrayReverse(arr), [3, 2, 1])
def test_split_and_merge_functions(self):
self._test_func(F.splitByChar('_', 'a_b_c'), ['a', 'b', 'c'])
self._test_func(F.splitByString('__', 'a__b__c'), ['a', 'b', 'c'])
self._test_func(F.arrayStringConcat(['a', 'b', 'c']), 'abc')
self._test_func(F.arrayStringConcat(['a', 'b', 'c'], '_'), 'a_b_c')
self._test_func(F.alphaTokens('aaa.bbb.111'), ['aaa', 'bbb'])
self._test_func(F.splitByChar("_", "a_b_c"), ["a", "b", "c"])
self._test_func(F.splitByString("__", "a__b__c"), ["a", "b", "c"])
self._test_func(F.arrayStringConcat(["a", "b", "c"]), "abc")
self._test_func(F.arrayStringConcat(["a", "b", "c"], "_"), "a_b_c")
self._test_func(F.alphaTokens("aaa.bbb.111"), ["aaa", "bbb"])
def test_bit_functions(self):
x = 17
@ -546,23 +620,44 @@ class FuncsTestCase(TestCaseWithData):
def test_bitmap_functions(self):
self._test_func(F.bitmapToArray(F.bitmapBuild([1, 2, 3])), [1, 2, 3])
self._test_func(F.bitmapContains(F.bitmapBuild([1, 5, 7, 9]), F.toUInt32(9)), 1)
self._test_func(F.bitmapHasAny(F.bitmapBuild([1,2,3]), F.bitmapBuild([3,4,5])), 1)
self._test_func(F.bitmapHasAll(F.bitmapBuild([1,2,3]), F.bitmapBuild([3,4,5])), 0)
self._test_func(F.bitmapToArray(F.bitmapAnd(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [3])
self._test_func(F.bitmapToArray(F.bitmapOr(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2, 3, 4, 5])
self._test_func(F.bitmapToArray(F.bitmapXor(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2, 4, 5])
self._test_func(F.bitmapToArray(F.bitmapAndnot(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))), [1, 2])
self._test_func(F.bitmapHasAny(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 1)
self._test_func(F.bitmapHasAll(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 0)
self._test_func(
F.bitmapToArray(F.bitmapAnd(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
[3],
)
self._test_func(
F.bitmapToArray(F.bitmapOr(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
[1, 2, 3, 4, 5],
)
self._test_func(
F.bitmapToArray(F.bitmapXor(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
[1, 2, 4, 5],
)
self._test_func(
F.bitmapToArray(F.bitmapAndnot(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
[1, 2],
)
self._test_func(F.bitmapCardinality(F.bitmapBuild([1, 2, 3, 4, 5])), 5)
self._test_func(F.bitmapAndCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 1)
self._test_func(
F.bitmapAndCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
1,
)
self._test_func(F.bitmapOrCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 5)
self._test_func(F.bitmapXorCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 4)
self._test_func(F.bitmapAndnotCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 2)
self._test_func(
F.bitmapXorCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
4,
)
self._test_func(
F.bitmapAndnotCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
2,
)
def test_hash_functions(self):
args = ['x', 'y', 'z']
args = ["x", "y", "z"]
x = 17
s = 'hello'
url = 'http://example.com/a/b/c/d'
s = "hello"
url = "http://example.com/a/b/c/d"
self._test_func(F.hex(F.MD5(s)))
self._test_func(F.hex(F.sipHash128(s)))
self._test_func(F.hex(F.cityHash64(*args)))
@ -599,12 +694,13 @@ class FuncsTestCase(TestCaseWithData):
self._test_func(F.randConstant(17))
def test_encoding_functions(self):
self._test_func(F.hex(F.unhex('0FA1')), '0FA1')
self._test_func(F.hex(F.unhex("0FA1")), "0FA1")
self._test_func(F.bitmaskToArray(17))
self._test_func(F.bitmaskToList(18))
def test_uuid_functions(self):
from uuid import UUID
uuid = self._test_func(F.generateUUIDv4())
self.assertEqual(type(uuid), UUID)
s = str(uuid)
@ -612,17 +708,30 @@ class FuncsTestCase(TestCaseWithData):
self._test_func(F.UUIDNumToString(F.UUIDStringToNum(s)), s)
def test_ip_funcs(self):
self._test_func(F.IPv4NumToString(F.toUInt32(1)), '0.0.0.1')
self._test_func(F.IPv4NumToStringClassC(F.toUInt32(1)), '0.0.0.xxx')
self._test_func(F.IPv4StringToNum('0.0.0.17'), 17)
self._test_func(F.IPv6NumToString(F.IPv4ToIPv6(F.IPv4StringToNum('192.168.0.1'))), '::ffff:192.168.0.1')
self._test_func(F.IPv6NumToString(F.IPv6StringToNum('2a02:6b8::11')), '2a02:6b8::11')
self._test_func(F.toIPv4('10.20.30.40'), IPv4Address('10.20.30.40'))
self._test_func(F.toIPv6('2001:438:ffff::407d:1bc1'), IPv6Address('2001:438:ffff::407d:1bc1'))
self._test_func(F.IPv4CIDRToRange(F.toIPv4('192.168.5.2'), 16),
[IPv4Address('192.168.0.0'), IPv4Address('192.168.255.255')])
self._test_func(F.IPv6CIDRToRange(F.toIPv6('2001:0db8:0000:85a3:0000:0000:ac1f:8001'), 32),
[IPv6Address('2001:db8::'), IPv6Address('2001:db8:ffff:ffff:ffff:ffff:ffff:ffff')])
self._test_func(F.IPv4NumToString(F.toUInt32(1)), "0.0.0.1")
self._test_func(F.IPv4NumToStringClassC(F.toUInt32(1)), "0.0.0.xxx")
self._test_func(F.IPv4StringToNum("0.0.0.17"), 17)
self._test_func(
F.IPv6NumToString(F.IPv4ToIPv6(F.IPv4StringToNum("192.168.0.1"))),
"::ffff:192.168.0.1",
)
self._test_func(F.IPv6NumToString(F.IPv6StringToNum("2a02:6b8::11")), "2a02:6b8::11")
self._test_func(F.toIPv4("10.20.30.40"), IPv4Address("10.20.30.40"))
self._test_func(
F.toIPv6("2001:438:ffff::407d:1bc1"),
IPv6Address("2001:438:ffff::407d:1bc1"),
)
self._test_func(
F.IPv4CIDRToRange(F.toIPv4("192.168.5.2"), 16),
[IPv4Address("192.168.0.0"), IPv4Address("192.168.255.255")],
)
self._test_func(
F.IPv6CIDRToRange(F.toIPv6("2001:0db8:0000:85a3:0000:0000:ac1f:8001"), 32),
[
IPv6Address("2001:db8::"),
IPv6Address("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
],
)
def test_aggregate_funcs(self):
self._test_aggr(F.any(Person.first_name))
@ -630,7 +739,10 @@ class FuncsTestCase(TestCaseWithData):
self._test_aggr(F.anyLast(Person.first_name))
self._test_aggr(F.argMin(Person.first_name, Person.height))
self._test_aggr(F.argMax(Person.first_name, Person.height))
self._test_aggr(F.round(F.avg(Person.height), 4), sum(p.height for p in self._sample_data()) / 100)
self._test_aggr(
F.round(F.avg(Person.height), 4),
sum(p.height for p in self._sample_data()) / 100,
)
self._test_aggr(F.corr(Person.height, Person.height), 1)
self._test_aggr(F.count(), 100)
self._test_aggr(F.round(F.covarPop(Person.height, Person.height), 2), 0)
@ -649,32 +761,32 @@ class FuncsTestCase(TestCaseWithData):
self._test_aggr(F.varSamp(Person.height))
def test_aggregate_funcs__or_default(self):
self.database.raw('TRUNCATE TABLE person')
self.database.raw("TRUNCATE TABLE person")
self._test_aggr(F.countOrDefault(), 0)
self._test_aggr(F.maxOrDefault(Person.height), 0)
def test_aggregate_funcs__or_null(self):
self.database.raw('TRUNCATE TABLE person')
self.database.raw("TRUNCATE TABLE person")
self._test_aggr(F.countOrNull(), None)
self._test_aggr(F.maxOrNull(Person.height), None)
def test_aggregate_funcs__if(self):
self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > 'H'))
self._test_aggr(F.countIf(Person.last_name > 'H'), 57)
self._test_aggr(F.minIf(Person.height, Person.last_name > 'H'), 1.6)
self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > "H"))
self._test_aggr(F.countIf(Person.last_name > "H"), 57)
self._test_aggr(F.minIf(Person.height, Person.last_name > "H"), 1.6)
def test_aggregate_funcs__or_default_if(self):
self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > 'Z'))
self._test_aggr(F.countOrDefaultIf(Person.last_name > 'Z'), 0)
self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > 'Z'), 0)
self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > "Z"))
self._test_aggr(F.countOrDefaultIf(Person.last_name > "Z"), 0)
self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > "Z"), 0)
def test_aggregate_funcs__or_null_if(self):
self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > 'Z'))
self._test_aggr(F.countOrNullIf(Person.last_name > 'Z'), None)
self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > 'Z'), None)
self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > "Z"))
self._test_aggr(F.countOrNullIf(Person.last_name > "Z"), None)
self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > "Z"), None)
def test_quantile_funcs(self):
cond = Person.last_name > 'H'
cond = Person.last_name > "H"
weight_expr = F.toUInt32(F.round(Person.height))
# Quantile
self._test_aggr(F.quantile(0.9)(Person.height))
@ -712,13 +824,13 @@ class FuncsTestCase(TestCaseWithData):
def test_top_k_funcs(self):
self._test_aggr(F.topK(3)(Person.height))
self._test_aggr(F.topKOrDefault(3)(Person.height))
self._test_aggr(F.topKIf(3)(Person.height, Person.last_name > 'H'))
self._test_aggr(F.topKOrDefaultIf(3)(Person.height, Person.last_name > 'H'))
self._test_aggr(F.topKIf(3)(Person.height, Person.last_name > "H"))
self._test_aggr(F.topKOrDefaultIf(3)(Person.height, Person.last_name > "H"))
weight_expr = F.toUInt32(F.round(Person.height))
self._test_aggr(F.topKWeighted(3)(Person.height, weight_expr))
self._test_aggr(F.topKWeightedOrDefault(3)(Person.height, weight_expr))
self._test_aggr(F.topKWeightedIf(3)(Person.height, weight_expr, Person.last_name > 'H'))
self._test_aggr(F.topKWeightedOrDefaultIf(3)(Person.height, weight_expr, Person.last_name > 'H'))
self._test_aggr(F.topKWeightedIf(3)(Person.height, weight_expr, Person.last_name > "H"))
self._test_aggr(F.topKWeightedOrDefaultIf(3)(Person.height, weight_expr, Person.last_name > "H"))
def test_null_funcs(self):
self._test_func(F.ifNull(17, 18), 17)

View File

@ -1,14 +1,14 @@
import unittest
from infi.clickhouse_orm import *
from clickhouse_orm import Database, F, Index, MergeTree, Model
from clickhouse_orm.fields import DateField, Int32Field, StringField
class IndexesTest(unittest.TestCase):
def setUp(self):
self.database = Database('test-db', log_statements=True)
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest('ClickHouse version too old')
raise unittest.SkipTest("ClickHouse version too old")
def tearDown(self):
self.database.drop_database()
@ -29,4 +29,4 @@ class ModelWithIndexes(Model):
i4 = Index(F.lower(f2), type=Index.tokenbf_v1(256, 2, 0), granularity=2)
i5 = Index((F.toQuarter(date), f2), type=Index.bloom_filter(), granularity=3)
engine = MergeTree('date', ('date',))
engine = MergeTree("date", ("date",))

Some files were not shown because too many files have changed in this diff Show More