Beda Kosata 2021-10-04 08:08:24 +00:00 committed by GitHub
commit 4661f5ba55
111 changed files with 4739 additions and 4495 deletions

.github/workflows/python-publish.yml (new file, +40 lines)

@@ -0,0 +1,40 @@
name: Upload to PyPI

# Controls when the action will run.
on:
  # Triggers the workflow when a release is created
  release:
    types: [created]
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
  # This workflow contains a single job called "upload"
  upload:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest
    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2
      # Sets up Python
      - uses: actions/setup-python@v2
        with:
          python-version: 3.9
      # Install dependencies
      - name: Install dependencies
        run: |
          python -m pip install poetry
          poetry install --no-dev
      # Build and upload to PyPI
      - name: Build and upload to PyPI
        run: |
          poetry config pypi-token.pypi "$TWINE_TOKEN"
          poetry publish --build
        env:
          TWINE_TOKEN: ${{ secrets.TWINE_TOKEN }}

.github/workflows/python-test.yml (new file, +89 lines)

@ -0,0 +1,89 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Tests
on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
jobs:
lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
# Lint on smallest and largest active versions
python-version: [3.6, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Lint with flake8
run: |
poetry run flake8
- name: Check formatting with black
run: |
poetry run black --check .
test:
runs-on: ubuntu-latest
services:
clickhouse:
image: yandex/clickhouse-server:21.3
ports:
- 8123:8123
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9, pypy3]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Test with pytest
run: |
poetry run pytest
test_compat:
# Tests compatibility with an older LTS release of clickhouse
runs-on: ubuntu-latest
services:
clickhouse:
image: yandex/clickhouse-server:20.8
ports:
- 8123:8123
strategy:
fail-fast: false
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install poetry
poetry install
- name: Test with pytest
run: |
poetry run pytest

.gitignore (+59 lines)

@@ -63,3 +63,62 @@ cover/
# tox
.tox/
# misc
*.7z
*.apk
*.backup
*.bak
*.bk
*.bz2
*.deb
*.doc
*.docx
*.gz
*.gzip
*.img
*.iso
*.jar
*.jpeg
*.jpg
*.log
*.ods
*.part
*.pdf
*.pkg
*.png
*.pps
*.ppsx
*.ppt
*.pptx
*.ps
*.pyc
*.rar
*.swp
*.sys
*.tar
*.tgz
*.tmp
*.xls
*.xlsx
*.xz
*.zip
**/*venv/
.cache
.coverage*
.idea/
.isort.cfg
**.directory
venv
.pytest_cache/
.vscode/
*.egg-info/
.tox/
.cargo/
.expected
.hypothesis/
.mypy_cache/
**/__pycache__/
# poetry
poetry.lock

.noseids (1073 lines changed)

File diff suppressed because it is too large.

CHANGELOG.md

@@ -1,6 +1,26 @@
 Change Log
 ==========
 
+v2.2.2
+------
+- Unpinned requirements to enhance compatibility
+
+v2.2.1
+------
+- Minor tooling changes for PyPI
+- Better project description for PyPI
+
+v2.2.0
+------
+- Support up to ClickHouse 20.12, including the 20.8 LTS release
+- Fixed boolean logic for Q objects (https://github.com/Infinidat/infi.clickhouse_orm/issues/158)
+- Removed implicit use of the '\N' character (which was causing deprecation warnings in queries)
+- Tooling updates: use poetry, pytest, isort, black
+
+**Backwards incompatible changes**
+
+You can no longer supply a codec for an `alias` field. Previously this had no effect in ClickHouse, but now it explicitly returns an error.
+
 v2.1.0
 ------
 - Support for model constraints
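The backwards-incompatible change above comes from a new assertion in `Field.__init__` (visible in the `fields.py` diff below). A minimal sketch of code that used to be accepted and now fails at class-definition time; the model and codec value are illustrative:

```python
from clickhouse_orm import Model, StringField

class Event(Model):
    name = StringField()
    # Previously the codec was silently ignored for alias fields; now this
    # raises AssertionError: "Codec cannot be used for alias fields"
    name_copy = StringField(alias="name", codec="ZSTD")
```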

LICENSE

@@ -1,3 +1,4 @@
+Copyright (c) 2021 Suade Labs
 Copyright (c) 2017 INFINIDAT
 
 Redistribution and use in source and binary forms, with or without

README.md

@@ -1,3 +1,8 @@
+A fork of [infi.clickhouse_orm](https://github.com/Infinidat/infi.clickhouse_orm) aimed at more frequent maintenance and bugfixes.
+
+[![Tests](https://github.com/SuadeLabs/clickhouse_orm/actions/workflows/python-test.yml/badge.svg)](https://github.com/SuadeLabs/clickhouse_orm/actions/workflows/python-test.yml)
+![PyPI](https://img.shields.io/pypi/v/clickhouse_orm)
+
 Introduction
 ============
 
@@ -8,7 +13,7 @@ Let's jump right in with a simple example of monitoring CPU usage. First we need
 connect to the database and create a table for the model:
 
 ```python
-from infi.clickhouse_orm import Database, Model, DateTimeField, UInt16Field, Float32Field, Memory, F
+from clickhouse_orm import Database, Model, DateTimeField, UInt16Field, Float32Field, Memory, F
 
 class CPUStats(Model):

buildout.cfg (deleted file)

@@ -1,68 +0,0 @@
[buildout]
prefer-final = false
newest = false
download-cache = .cache
develop = .
parts =
relative-paths = true

[project]
name = infi.clickhouse_orm
company = Infinidat
namespace_packages = ['infi']
install_requires = [
	'iso8601 >= 0.1.12',
	'pytz',
	'requests',
	'setuptools'
	]
version_file = src/infi/clickhouse_orm/__version__.py
description = A Python library for working with the ClickHouse database
long_description = A Python library for working with the ClickHouse database
console_scripts = []
gui_scripts = []
package_data = []
upgrade_code = {58530fba-3932-11e6-a20e-7071bc32067f}
product_name = infi.clickhouse_orm
post_install_script_name = None
pre_uninstall_script_name = None
homepage = https://github.com/Infinidat/infi.clickhouse_orm

[isolated-python]
recipe = infi.recipe.python
version = v3.8.0.2

[setup.py]
recipe = infi.recipe.template.version
input = setup.in
output = setup.py

[__version__.py]
recipe = infi.recipe.template.version
output = ${project:version_file}

[development-scripts]
dependent-scripts = true
recipe = infi.recipe.console_scripts
eggs = ${project:name}
	ipython<6
	nose
	coverage
	enum-compat
	infi.unittest
	infi.traceback
	memory_profiler
	profilehooks
	psutil
	zc.buildout
scripts = ipython
	nosetests
interpreter = python

[pack]
recipe = infi.recipe.application_packager

[sublime]
recipe = corneti.recipes.codeintel
eggs = ${development-scripts:eggs}

clickhouse_orm/__init__.py (new file)

@@ -0,0 +1,12 @@
from inspect import isclass

from .database import *  # noqa: F401, F403
from .engines import *  # noqa: F401, F403
from .fields import *  # noqa: F401, F403
from .funcs import *  # noqa: F401, F403
from .migrations import *  # noqa: F401, F403
from .models import *  # noqa: F401, F403
from .query import *  # noqa: F401, F403
from .system_models import *  # noqa: F401, F403

__all__ = [c.__name__ for c in locals().values() if isclass(c)]

clickhouse_orm/database.py

@@ -1,26 +1,23 @@
-from __future__ import unicode_literals
-import re
-import requests
-from collections import namedtuple
-from .models import ModelBase
-from .utils import escape, parse_tsv, import_submodules
-from math import ceil
 import datetime
-from string import Template
-import pytz
 import logging
-
-logger = logging.getLogger('clickhouse_orm')
+import re
+from math import ceil
+from string import Template
+
+import pytz
+import requests
 
-Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
+from .models import ModelBase
+from .utils import Page, import_submodules, parse_tsv
+
+logger = logging.getLogger("clickhouse_orm")
 
 
 class DatabaseException(Exception):
-    '''
+    """
     Raised when a database operation fails.
-    '''
+    """
     pass
@@ -28,6 +25,7 @@ class ServerError(DatabaseException):
     """
     Raised when a server returns an error.
    """
+
    def __init__(self, message):
        self.code = None
        processed = self.get_error_code_msg(message)
@@ -41,16 +39,22 @@ class ServerError(DatabaseException):
     ERROR_PATTERNS = (
         # ClickHouse prior to v19.3.3
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
             \ e.what\(\)\ =\ (?P<type2>[^ \n]+)
-        ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
         # ClickHouse v19.3.3+
-        re.compile(r'''
+        re.compile(
+            r"""
             Code:\ (?P<code>\d+),
             \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+)
-        ''', re.VERBOSE | re.DOTALL),
+            """,
+            re.VERBOSE | re.DOTALL,
+        ),
     )
@@ -65,7 +69,7 @@ class ServerError(DatabaseException):
             match = pattern.match(full_error_message)
             if match:
                 # assert match.group('type1') == match.group('type2')
-                return int(match.group('code')), match.group('msg').strip()
+                return int(match.group("code")), match.group("msg").strip()
         return 0, full_error_message
@@ -75,15 +79,24 @@ class ServerError(DatabaseException):
 class Database(object):
-    '''
+    """
     Database instances connect to a specific ClickHouse database for running queries,
     inserting data and other operations.
-    '''
+    """
 
-    def __init__(self, db_name, db_url='http://localhost:8123/',
-                 username=None, password=None, readonly=False, autocreate=True,
-                 timeout=60, verify_ssl_cert=True, log_statements=False):
-        '''
+    def __init__(
+        self,
+        db_name,
+        db_url="http://localhost:8123/",
+        username=None,
+        password=None,
+        readonly=False,
+        autocreate=True,
+        timeout=60,
+        verify_ssl_cert=True,
+        log_statements=False,
+    ):
+        """
         Initializes a database instance. Unless it's readonly, the database will be
         created on the ClickHouse server if it does not already exist.
@@ -96,7 +109,7 @@ class Database(object):
         - `timeout`: the connection timeout in seconds.
         - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
         - `log_statements`: when True, all database statements are logged.
-        '''
+        """
         self.db_name = db_name
         self.db_url = db_url
         self.readonly = False
@@ -104,14 +117,14 @@ class Database(object):
         self.request_session = requests.Session()
         self.request_session.verify = verify_ssl_cert
         if username:
-            self.request_session.auth = (username, password or '')
+            self.request_session.auth = (username, password or "")
         self.log_statements = log_statements
         self.settings = {}
         self.db_exists = False  # this is required before running _is_existing_database
         self.db_exists = self._is_existing_database()
         if readonly:
             if not self.db_exists:
-                raise DatabaseException('Database does not exist, and cannot be created under readonly connection')
+                raise DatabaseException("Database does not exist, and cannot be created under readonly connection")
             self.connection_readonly = self._is_connection_readonly()
             self.readonly = True
         elif autocreate and not self.db_exists:
@@ -125,56 +138,56 @@ class Database(object):
         self.has_low_cardinality_support = self.server_version >= (19, 0)
 
     def create_database(self):
-        '''
+        """
         Creates the database on the ClickHouse server if it does not already exist.
-        '''
-        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
+        """
+        self._send("CREATE DATABASE IF NOT EXISTS `%s`" % self.db_name)
         self.db_exists = True
 
     def drop_database(self):
-        '''
+        """
         Deletes the database on the ClickHouse server.
-        '''
-        self._send('DROP DATABASE `%s`' % self.db_name)
+        """
+        self._send("DROP DATABASE `%s`" % self.db_name)
         self.db_exists = False
 
     def create_table(self, model_class):
-        '''
+        """
         Creates a table for the given model class, if it does not exist already.
-        '''
+        """
         if model_class.is_system_model():
             raise DatabaseException("You can't create system table")
-        if getattr(model_class, 'engine') is None:
+        if model_class.engine is None:
             raise DatabaseException("%s class must define an engine" % model_class.__name__)
         self._send(model_class.create_table_sql(self))
 
     def drop_table(self, model_class):
-        '''
+        """
         Drops the database table of the given model class, if it exists.
-        '''
+        """
         if model_class.is_system_model():
             raise DatabaseException("You can't drop system table")
         self._send(model_class.drop_table_sql(self))
 
     def does_table_exist(self, model_class):
-        '''
+        """
         Checks whether a table for the given model class already exists.
         Note that this only checks for existence of a table with the expected name.
-        '''
+        """
         sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
         r = self._send(sql % (self.db_name, model_class.table_name()))
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"
 
     def get_model_for_table(self, table_name, system_table=False):
-        '''
+        """
         Generates a model class from an existing table in the database.
         This can be used for querying tables which don't have a corresponding model class,
         for example system tables.
 
         - `table_name`: the table to create a model for
         - `system_table`: whether the table is a system table, or belongs to the current database
-        '''
-        db_name = 'system' if system_table else self.db_name
+        """
+        db_name = "system" if system_table else self.db_name
         sql = "DESCRIBE `%s`.`%s` FORMAT TSV" % (db_name, table_name)
         lines = self._send(sql).iter_lines()
         fields = [parse_tsv(line)[:2] for line in lines]
@@ -184,27 +197,28 @@ class Database(object):
         return model
 
     def add_setting(self, name, value):
-        '''
+        """
         Adds a database setting that will be sent with every request.
         For example, `db.add_setting("max_execution_time", 10)` will
         limit query execution time to 10 seconds.
         The name must be string, and the value is converted to string in case
         it isn't. To remove a setting, pass `None` as the value.
-        '''
-        assert isinstance(name, str), 'Setting name must be a string'
+        """
+        assert isinstance(name, str), "Setting name must be a string"
         if value is None:
             self.settings.pop(name, None)
         else:
             self.settings[name] = str(value)
 
     def insert(self, model_instances, batch_size=1000):
-        '''
+        """
         Insert records into the database.
 
         - `model_instances`: any iterable containing instances of a single model class.
         - `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
-        '''
+        """
         from io import BytesIO
+
         i = iter(model_instances)
         try:
             first_instance = next(i)
@@ -215,14 +229,13 @@ class Database(object):
         if first_instance.is_read_only() or first_instance.is_system_model():
             raise DatabaseException("You can't insert into read only and system tables")
 
-        fields_list = ','.join(
-            ['`%s`' % name for name in first_instance.fields(writable=True)])
-        fmt = 'TSKV' if model_class.has_funcs_as_defaults() else 'TabSeparated'
-        query = 'INSERT INTO $table (%s) FORMAT %s\n' % (fields_list, fmt)
+        fields_list = ",".join(["`%s`" % name for name in first_instance.fields(writable=True)])
+        fmt = "TSKV" if model_class.has_funcs_as_defaults() else "TabSeparated"
+        query = "INSERT INTO $table (%s) FORMAT %s\n" % (fields_list, fmt)
 
         def gen():
             buf = BytesIO()
-            buf.write(self._substitute(query, model_class).encode('utf-8'))
+            buf.write(self._substitute(query, model_class).encode("utf-8"))
             first_instance.set_database(self)
             buf.write(first_instance.to_db_string())
             # Collect lines in batches of batch_size
@@ -240,35 +253,37 @@ class Database(object):
             # Return any remaining lines in partial batch
             if lines:
                 yield buf.getvalue()
+
         self._send(gen())
 
     def count(self, model_class, conditions=None):
-        '''
+        """
         Counts the number of records in the model's table.
 
         - `model_class`: the model to count.
         - `conditions`: optional SQL conditions (contents of the WHERE clause).
-        '''
-        from infi.clickhouse_orm.query import Q
-        query = 'SELECT count() FROM $table'
+        """
+        from clickhouse_orm.query import Q
+
+        query = "SELECT count() FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
+            query += " WHERE " + str(conditions)
         query = self._substitute(query, model_class)
         r = self._send(query)
         return int(r.text) if r.text else 0
 
     def select(self, query, model_class=None, settings=None):
-        '''
+        """
         Performs a query and returns a generator of model instances.
 
         - `query`: the SQL query to execute.
         - `model_class`: the model class matching the query's table,
           or `None` for getting back instances of an ad-hoc model.
         - `settings`: query settings to send as HTTP GET parameters
-        '''
-        query += ' FORMAT TabSeparatedWithNamesAndTypes'
+        """
+        query += " FORMAT TabSeparatedWithNamesAndTypes"
         query = self._substitute(query, model_class)
         r = self._send(query, settings, True)
         lines = r.iter_lines()
@@ -281,18 +296,18 @@ class Database(object):
             yield model_class.from_tsv(line, field_names, self.server_timezone, self)
 
     def raw(self, query, settings=None, stream=False):
-        '''
+        """
         Performs a query and returns its output as text.
 
         - `query`: the SQL query to execute.
         - `settings`: query settings to send as HTTP GET parameters
         - `stream`: if true, the HTTP response from ClickHouse will be streamed.
-        '''
+        """
         query = self._substitute(query, None)
         return self._send(query, settings=settings, stream=stream).text
 
     def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
-        '''
+        """
         Selects records and returns a single page of model instances.
 
         - `model_class`: the model class matching the query's table,
@@ -305,54 +320,63 @@ class Database(object):
         The result is a namedtuple containing `objects` (list), `number_of_objects`,
         `pages_total`, `number` (of the current page), and `page_size`.
-        '''
-        from infi.clickhouse_orm.query import Q
+        """
+        from clickhouse_orm.query import Q
+
         count = self.count(model_class, conditions)
         pages_total = int(ceil(count / float(page_size)))
         if page_num == -1:
             page_num = max(pages_total, 1)
         elif page_num < 1:
-            raise ValueError('Invalid page number: %d' % page_num)
+            raise ValueError("Invalid page number: %d" % page_num)
         offset = (page_num - 1) * page_size
-        query = 'SELECT * FROM $table'
+        query = "SELECT * FROM $table"
         if conditions:
             if isinstance(conditions, Q):
                 conditions = conditions.to_sql(model_class)
-            query += ' WHERE ' + str(conditions)
-        query += ' ORDER BY %s' % order_by
-        query += ' LIMIT %d, %d' % (offset, page_size)
+            query += " WHERE " + str(conditions)
+        query += " ORDER BY %s" % order_by
+        query += " LIMIT %d, %d" % (offset, page_size)
         query = self._substitute(query, model_class)
         return Page(
             objects=list(self.select(query, model_class, settings)) if count else [],
             number_of_objects=count,
             pages_total=pages_total,
             number=page_num,
-            page_size=page_size
+            page_size=page_size,
         )
 
     def migrate(self, migrations_package_name, up_to=9999):
-        '''
+        """
         Executes schema migrations.
 
         - `migrations_package_name` - fully qualified name of the Python package
           containing the migrations.
         - `up_to` - number of the last migration to apply.
-        '''
+        """
         from .migrations import MigrationHistory
-        logger = logging.getLogger('migrations')
+
+        logger = logging.getLogger("migrations")
         applied_migrations = self._get_applied_migrations(migrations_package_name)
         modules = import_submodules(migrations_package_name)
         unapplied_migrations = set(modules.keys()) - applied_migrations
         for name in sorted(unapplied_migrations):
-            logger.info('Applying migration %s...', name)
+            logger.info("Applying migration %s...", name)
             for operation in modules[name].operations:
                 operation.apply(self)
-            self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
+            self.insert(
+                [
+                    MigrationHistory(
+                        package_name=migrations_package_name, module_name=name, applied=datetime.date.today()
+                    )
+                ]
+            )
             if int(name[:4]) >= up_to:
                 break
 
     def _get_applied_migrations(self, migrations_package_name):
         from .migrations import MigrationHistory
+
         self.create_table(MigrationHistory)
         query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
         query = self._substitute(query, MigrationHistory)
@@ -360,7 +384,7 @@ class Database(object):
     def _send(self, data, settings=None, stream=False):
         if isinstance(data, str):
-            data = data.encode('utf-8')
+            data = data.encode("utf-8")
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
@@ -373,50 +397,50 @@ class Database(object):
         params = dict(settings or {})
         params.update(self.settings)
         if self.db_exists:
-            params['database'] = self.db_name
+            params["database"] = self.db_name
         # Send the readonly flag, unless the connection is already readonly (to prevent db error)
         if self.readonly and not self.connection_readonly:
-            params['readonly'] = '1'
+            params["readonly"] = "1"
         return params
 
     def _substitute(self, query, model_class=None):
-        '''
+        """
         Replaces $db and $table placeholders in the query.
-        '''
-        if '$' in query:
+        """
+        if "$" in query:
             mapping = dict(db="`%s`" % self.db_name)
             if model_class:
                 if model_class.is_system_model():
-                    mapping['table'] = "`system`.`%s`" % model_class.table_name()
+                    mapping["table"] = "`system`.`%s`" % model_class.table_name()
                 else:
-                    mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
+                    mapping["table"] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
             query = Template(query).safe_substitute(mapping)
         return query
 
     def _get_server_timezone(self):
         try:
-            r = self._send('SELECT timezone()')
+            r = self._send("SELECT timezone()")
             return pytz.timezone(r.text.strip())
         except ServerError as e:
-            logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
+            logger.exception("Cannot determine server timezone (%s), assuming UTC", e)
             return pytz.utc
 
     def _get_server_version(self, as_tuple=True):
         try:
-            r = self._send('SELECT version();')
+            r = self._send("SELECT version();")
             ver = r.text
         except ServerError as e:
-            logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
-            ver = '1.1.0'
-        return tuple(int(n) for n in ver.split('.')) if as_tuple else ver
+            logger.exception("Cannot determine server version (%s), assuming 1.1.0", e)
+            ver = "1.1.0"
+        return tuple(int(n) for n in ver.split(".")) if as_tuple else ver
 
     def _is_existing_database(self):
         r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
-        return r.text.strip() == '1'
+        return r.text.strip() == "1"
 
     def _is_connection_readonly(self):
         r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
-        return r.text.strip() != '0'
+        return r.text.strip() != "0"
 
 
 # Expose only relevant classes in import *
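The docstrings above sketch the typical query flow: create a table, tweak per-request settings, then count and paginate. A short usage sketch mirroring the README's CPUStats example (the connection details are the defaults and the data is illustrative):

```python
from clickhouse_orm import Database, Model, DateTimeField, UInt16Field, Float32Field, Memory

class CPUStats(Model):
    timestamp = DateTimeField()
    cpu_id = UInt16Field()
    cpu_percent = Float32Field()
    engine = Memory()

db = Database("demo")  # connects to http://localhost:8123/ and autocreates "demo"
db.create_table(CPUStats)
db.add_setting("max_execution_time", 10)  # sent with every subsequent request
total = db.count(CPUStats)
page = db.paginate(CPUStats, order_by="timestamp", page_num=1, page_size=100)
print(total, page.pages_total)  # Page is the namedtuple now housed in utils
```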

clickhouse_orm/engines.py

@@ -1,54 +1,60 @@
-from __future__ import unicode_literals
 import logging
+
 from .utils import comma_join, get_subclass_names
 
-logger = logging.getLogger('clickhouse_orm')
+logger = logging.getLogger("clickhouse_orm")
 
 
 class Engine(object):
     def create_table_sql(self, db):
         raise NotImplementedError()  # pragma: no cover
 
 
 class TinyLog(Engine):
     def create_table_sql(self, db):
-        return 'TinyLog'
+        return "TinyLog"
 
 
 class Log(Engine):
     def create_table_sql(self, db):
-        return 'Log'
+        return "Log"
 
 
 class Memory(Engine):
     def create_table_sql(self, db):
-        return 'Memory'
+        return "Memory"
 
 
 class MergeTree(Engine):
-
-    def __init__(self, date_col=None, order_by=(), sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
-        assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
-        assert primary_key is None or type(primary_key) in (list, tuple), 'primary_key must be a list or tuple'
-        assert partition_key is None or type(partition_key) in (list, tuple),\
-            'partition_key must be tuple or list if present'
-        assert (replica_table_path is None) == (replica_name is None), \
-            'both replica_table_path and replica_name must be specified'
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+        settings=None,
+    ):
+        assert type(order_by) in (list, tuple), "order_by must be a list or tuple"
+        assert date_col is None or isinstance(date_col, str), "date_col must be string if present"
+        assert primary_key is None or type(primary_key) in (list, tuple), "primary_key must be a list or tuple"
+        assert partition_key is None or type(partition_key) in (
+            list,
+            tuple,
+        ), "partition_key must be tuple or list if present"
+        assert (replica_table_path is None) == (
+            replica_name is None
+        ), "both replica_table_path and replica_name must be specified"
+        assert settings is None or type(settings) is dict, 'settings must be dict'
         # These values conflict with each other (old and new syntax of table engines.
         # So let's control only one of them is given.
         assert date_col or partition_key, "You must set either date_col or partition_key"
         self.date_col = date_col
-        self.partition_key = partition_key if partition_key else ('toYYYYMM(`%s`)' % date_col,)
+        self.partition_key = partition_key if partition_key else ("toYYYYMM(`%s`)" % date_col,)
         self.primary_key = primary_key
         self.order_by = order_by
@@ -56,50 +62,62 @@ class MergeTree(Engine):
         self.index_granularity = index_granularity
         self.replica_table_path = replica_table_path
         self.replica_name = replica_name
+        self.settings = settings
 
     # I changed field name for new reality and syntax
     @property
     def key_cols(self):
-        logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
+        logger.warning(
+            "`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
+        )
         return self.order_by
 
     @key_cols.setter
     def key_cols(self, value):
-        logger.warning('`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead')
+        logger.warning(
+            "`key_cols` attribute is deprecated and may be removed in future. Use `order_by` attribute instead"
+        )
         self.order_by = value
 
     def create_table_sql(self, db):
         name = self.__class__.__name__
         if self.replica_name:
-            name = 'Replicated' + name
+            name = "Replicated" + name
 
         # In ClickHouse 1.1.54310 custom partitioning key was introduced
         # https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/
         # Let's check version and use new syntax if available
         if db.server_version >= (1, 1, 54310):
-            partition_sql = "PARTITION BY (%s) ORDER BY (%s)" \
-                            % (comma_join(self.partition_key, stringify=True),
-                               comma_join(self.order_by, stringify=True))
+            partition_sql = "PARTITION BY (%s) ORDER BY (%s)" % (
+                comma_join(map(str, self.partition_key)),
+                comma_join(map(str, self.order_by)),
+            )
 
             if self.primary_key:
-                partition_sql += " PRIMARY KEY (%s)" % comma_join(self.primary_key, stringify=True)
+                partition_sql += " PRIMARY KEY (%s)" % comma_join(map(str, self.primary_key))
 
             if self.sampling_expr:
                 partition_sql += " SAMPLE BY %s" % self.sampling_expr
 
             partition_sql += " SETTINGS index_granularity=%d" % self.index_granularity
 
+            if self.settings:
+                settings_sql = ", ".join('%s=%s' % (key, value) for key, value in self.settings.items())
+                partition_sql += ", " + settings_sql
+
         elif not self.date_col:
             # Can't import it globally due to circular import
-            from infi.clickhouse_orm.database import DatabaseException
-            raise DatabaseException("Custom partitioning is not supported before ClickHouse 1.1.54310. "
-                                    "Please update your server or use date_col syntax."
-                                    "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/")
+            from clickhouse_orm.database import DatabaseException
+
+            raise DatabaseException(
+                "Custom partitioning is not supported before ClickHouse 1.1.54310. "
+                "Please update your server or use date_col syntax."
+                "https://clickhouse.tech/docs/en/table_engines/custom_partitioning_key/"
+            )
 
         else:
-            partition_sql = ''
+            partition_sql = ""
 
         params = self._build_sql_params(db)
-        return '%s(%s) %s' % (name, comma_join(params), partition_sql)
+        return "%s(%s) %s" % (name, comma_join(params), partition_sql)
@@ -114,19 +132,37 @@ class MergeTree(Engine):
             params.append(self.date_col)
         if self.sampling_expr:
             params.append(self.sampling_expr)
-        params.append('(%s)' % comma_join(self.order_by, stringify=True))
+        params.append("(%s)" % comma_join(map(str, self.order_by)))
         params.append(str(self.index_granularity))
         return params
 
 
 class CollapsingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), sign_col='sign', sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(CollapsingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
-                                                  replica_table_path, replica_name, partition_key, primary_key)
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        sign_col="sign",
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+        settings=None,
+    ):
+        super(CollapsingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+            settings=settings,
+        )
         self.sign_col = sign_col
 
     def _build_sql_params(self, db):
@@ -136,29 +172,61 @@ class CollapsingMergeTree(MergeTree):
 
 class SummingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), summing_cols=None, sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(SummingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity, replica_table_path,
-                                               replica_name, partition_key, primary_key)
-        assert type is None or type(summing_cols) in (list, tuple), 'summing_cols must be a list or tuple'
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        summing_cols=None,
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        super(SummingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+        )
+        assert type is None or type(summing_cols) in (list, tuple), "summing_cols must be a list or tuple"
         self.summing_cols = summing_cols
 
     def _build_sql_params(self, db):
         params = super(SummingMergeTree, self)._build_sql_params(db)
         if self.summing_cols:
-            params.append('(%s)' % comma_join(self.summing_cols))
+            params.append("(%s)" % comma_join(self.summing_cols))
         return params
 
 
 class ReplacingMergeTree(MergeTree):
-
-    def __init__(self, date_col=None, order_by=(), ver_col=None, sampling_expr=None,
-                 index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None,
-                 primary_key=None):
-        super(ReplacingMergeTree, self).__init__(date_col, order_by, sampling_expr, index_granularity,
-                                                 replica_table_path, replica_name, partition_key, primary_key)
+    def __init__(
+        self,
+        date_col=None,
+        order_by=(),
+        ver_col=None,
+        sampling_expr=None,
+        index_granularity=8192,
+        replica_table_path=None,
+        replica_name=None,
+        partition_key=None,
+        primary_key=None,
+    ):
+        super(ReplacingMergeTree, self).__init__(
+            date_col,
+            order_by,
+            sampling_expr,
+            index_granularity,
+            replica_table_path,
+            replica_name,
+            partition_key,
+            primary_key,
+        )
         self.ver_col = ver_col
 
     def _build_sql_params(self, db):
@@ -175,9 +243,18 @@ class Buffer(Engine):
     Read more [here](https://clickhouse.tech/docs/en/engines/table-engines/special/buffer/).
     """
 
-    #Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
-    def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000,
-                 min_bytes=10000000, max_bytes=100000000):
+    # Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes)
+    def __init__(
+        self,
+        main_model,
+        num_layers=16,
+        min_time=10,
+        max_time=100,
+        min_rows=10000,
+        max_rows=1000000,
+        min_bytes=10000000,
+        max_bytes=100000000,
+    ):
         self.main_model = main_model
         self.num_layers = num_layers
         self.min_time = min_time
@@ -190,11 +267,17 @@ class Buffer(Engine):
     def create_table_sql(self, db):
         # Overriden create_table_sql example:
         # sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)'
-        sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % (
-            db.db_name, self.main_model.table_name(), self.num_layers,
-            self.min_time, self.max_time, self.min_rows,
-            self.max_rows, self.min_bytes, self.max_bytes
-        )
+        sql = "ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)" % (
+            db.db_name,
+            self.main_model.table_name(),
+            self.num_layers,
+            self.min_time,
+            self.max_time,
+            self.min_rows,
+            self.max_rows,
+            self.min_bytes,
+            self.max_bytes,
+        )
         return sql
@@ -224,6 +307,7 @@ class Distributed(Engine):
     See full documentation here
     https://clickhouse.tech/docs/en/engines/table-engines/special/distributed/
     """
+
     def __init__(self, cluster, table=None, sharding_key=None):
         """
         - `cluster`: what cluster to access data from
@@ -252,12 +336,11 @@ class Distributed(Engine):
     def create_table_sql(self, db):
         name = self.__class__.__name__
         params = self._build_sql_params(db)
-        return '%s(%s)' % (name, ', '.join(params))
+        return "%s(%s)" % (name, ", ".join(params))
 
     def _build_sql_params(self, db):
         if self.table_name is None:
-            raise ValueError("Cannot create {} engine: specify an underlying table".format(
-                self.__class__.__name__))
+            raise ValueError("Cannot create {} engine: specify an underlying table".format(self.__class__.__name__))
 
         params = ["`%s`" % p for p in [self.cluster, db.db_name, self.table_name]]
         if self.sharding_key:
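The new `settings` parameter accepted by `MergeTree` (and passed through by `CollapsingMergeTree`) is rendered into the trailing `SETTINGS` clause, after `index_granularity`. A hedged sketch under those assumptions; the model, setting name, and value are illustrative, and values are interpolated verbatim:

```python
from clickhouse_orm import Model, DateField, UInt16Field, MergeTree

class Visits(Model):  # illustrative model
    date = DateField()
    count = UInt16Field()
    engine = MergeTree(
        partition_key=("toYYYYMM(date)",),
        order_by=("date",),
        settings={"merge_with_ttl_timeout": 3600},  # illustrative setting
    )

# On a server with custom-partitioning support this should render roughly as:
#   MergeTree() PARTITION BY (toYYYYMM(date)) ORDER BY (date)
#   SETTINGS index_granularity=8192, merge_with_ttl_timeout=3600
```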

clickhouse_orm/fields.py

@ -1,39 +1,45 @@
from __future__ import unicode_literals
import datetime import datetime
import iso8601
import pytz
from calendar import timegm from calendar import timegm
from decimal import Decimal, localcontext from decimal import Decimal, localcontext
from uuid import UUID
from logging import getLogger
from pytz import BaseTzInfo
from .utils import escape, parse_array, comma_join, string_or_func, get_subclass_names
from .funcs import F, FunctionOperatorsMixin
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from logging import getLogger
from uuid import UUID
logger = getLogger('clickhouse_orm') import iso8601
import pytz
from pytz import BaseTzInfo
from .funcs import F, FunctionOperatorsMixin
from .utils import comma_join, escape, get_subclass_names, parse_array, string_or_func
logger = getLogger("clickhouse_orm")
class Field(FunctionOperatorsMixin): class Field(FunctionOperatorsMixin):
''' """
Abstract base class for all field types. Abstract base class for all field types.
''' """
name = None # this is set by the parent model
parent = None # this is set by the parent model name = None # this is set by the parent model
creation_counter = 0 # used for keeping the model fields ordered parent = None # this is set by the parent model
class_default = 0 # should be overridden by concrete subclasses creation_counter = 0 # used for keeping the model fields ordered
db_type = None # should be overridden by concrete subclasses class_default = 0 # should be overridden by concrete subclasses
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None): def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert [default, alias, materialized].count(None) >= 2, \ assert [default, alias, materialized].count(
"Only one of default, alias and materialized parameters can be given" None
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\ ) >= 2, "Only one of default, alias and materialized parameters can be given"
"Alias parameter must be a string or function object, if given" assert (
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\ alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != ""
"Materialized parameter must be a string or function object, if given" ), "Alias parameter must be a string or function object, if given"
assert (
materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != ""
), "Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given" assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, str) and codec != "", \ assert codec is None or isinstance(codec, str) and codec != "", "Codec field must be string, if given"
"Codec field must be string, if given" if alias:
assert codec is None, "Codec cannot be used for alias fields"
self.creation_counter = Field.creation_counter self.creation_counter = Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
@ -47,49 +53,51 @@ class Field(FunctionOperatorsMixin):
return self.name return self.name
def __repr__(self): def __repr__(self):
return '<%s>' % self.__class__.__name__ return "<%s>" % self.__class__.__name__
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
''' """
Converts the input value into the expected Python data type, raising ValueError if the Converts the input value into the expected Python data type, raising ValueError if the
data can't be converted. Returns the converted value. Subclasses should override this. data can't be converted. Returns the converted value. Subclasses should override this.
The timezone_in_use parameter should be consulted when parsing datetime fields. The timezone_in_use parameter should be consulted when parsing datetime fields.
''' """
return value # pragma: no cover return value # pragma: no cover
def validate(self, value): def validate(self, value):
''' """
Called after to_python to validate that the value is suitable for the field's database type. Called after to_python to validate that the value is suitable for the field's database type.
Subclasses should override this. Subclasses should override this.
''' """
pass pass
def _range_check(self, value, min_value, max_value): def _range_check(self, value, min_value, max_value):
''' """
Utility method to check that the given value is between min_value and max_value. Utility method to check that the given value is between min_value and max_value.
''' """
if value < min_value or value > max_value: if value < min_value or value > max_value:
raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value)) raise ValueError(
"%s out of range - %s is not between %s and %s" % (self.__class__.__name__, value, min_value, max_value)
)
def to_db_string(self, value, quote=True): def to_db_string(self, value, quote=True):
''' """
Returns the field's value prepared for writing to the database. Returns the field's value prepared for writing to the database.
When quote is true, strings are surrounded by single quotes. When quote is true, strings are surrounded by single quotes.
''' """
return escape(value, quote) return escape(value, quote)
def get_sql(self, with_default_expression=True, db=None): def get_sql(self, with_default_expression=True, db=None):
''' """
Returns an SQL expression describing the field (e.g. for CREATE TABLE). Returns an SQL expression describing the field (e.g. for CREATE TABLE).
- `with_default_expression`: If True, adds default value to sql. - `with_default_expression`: If True, adds default value to sql.
It doesn't affect fields with alias and materialized values. It doesn't affect fields with alias and materialized values.
- `db`: Database, used for checking supported features. - `db`: Database, used for checking supported features.
''' """
sql = self.db_type sql = self.db_type
args = self.get_db_type_args() args = self.get_db_type_args()
if args: if args:
sql += '(%s)' % comma_join(args) sql += "(%s)" % comma_join(args)
if with_default_expression: if with_default_expression:
sql += self._extra_params(db) sql += self._extra_params(db)
return sql return sql
@ -99,18 +107,18 @@ class Field(FunctionOperatorsMixin):
return [] return []
def _extra_params(self, db): def _extra_params(self, db):
sql = '' sql = ""
if self.alias: if self.alias:
sql += ' ALIAS %s' % string_or_func(self.alias) sql += " ALIAS %s" % string_or_func(self.alias)
elif self.materialized: elif self.materialized:
sql += ' MATERIALIZED %s' % string_or_func(self.materialized) sql += " MATERIALIZED %s" % string_or_func(self.materialized)
elif isinstance(self.default, F): elif isinstance(self.default, F):
sql += ' DEFAULT %s' % self.default.to_sql() sql += " DEFAULT %s" % self.default.to_sql()
elif self.default: elif self.default:
default = self.to_db_string(self.default) default = self.to_db_string(self.default)
sql += ' DEFAULT %s' % default sql += " DEFAULT %s" % default
if self.codec and db and db.has_codec_support: if self.codec and db and db.has_codec_support:
sql += ' CODEC(%s)' % self.codec sql += " CODEC(%s)" % self.codec
return sql return sql
def isinstance(self, types): def isinstance(self, types):
@ -124,43 +132,42 @@ class Field(FunctionOperatorsMixin):
""" """
if isinstance(self, types): if isinstance(self, types):
return True return True
inner_field = getattr(self, 'inner_field', None) inner_field = getattr(self, "inner_field", None)
while inner_field: while inner_field:
if isinstance(inner_field, types): if isinstance(inner_field, types):
return True return True
inner_field = getattr(inner_field, 'inner_field', None) inner_field = getattr(inner_field, "inner_field", None)
return False return False
class StringField(Field): class StringField(Field):
class_default = '' class_default = ""
db_type = 'String' db_type = "String"
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if isinstance(value, str): if isinstance(value, str):
return value return value
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode('UTF-8') return value.decode("utf-8")
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value)) raise ValueError("Invalid value for %s: %r" % (self.__class__.__name__, value))
class FixedStringField(StringField): class FixedStringField(StringField):
def __init__(self, length, default=None, alias=None, materialized=None, readonly=None): def __init__(self, length, default=None, alias=None, materialized=None, readonly=None):
self._length = length self._length = length
self.db_type = 'FixedString(%d)' % length self.db_type = "FixedString(%d)" % length
super(FixedStringField, self).__init__(default, alias, materialized, readonly) super(FixedStringField, self).__init__(default, alias, materialized, readonly)
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
value = super(FixedStringField, self).to_python(value, timezone_in_use) value = super(FixedStringField, self).to_python(value, timezone_in_use)
return value.rstrip('\0') return value.rstrip("\0")
def validate(self, value): def validate(self, value):
if isinstance(value, str): if isinstance(value, str):
value = value.encode('UTF-8') value = value.encode("utf-8")
if len(value) > self._length: if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length)) raise ValueError("Value of %d bytes is too long for FixedStringField(%d)" % (len(value), self._length))
class DateField(Field): class DateField(Field):
@ -168,7 +175,7 @@ class DateField(Field):
min_value = datetime.date(1970, 1, 1) min_value = datetime.date(1970, 1, 1)
max_value = datetime.date(2105, 12, 31) max_value = datetime.date(2105, 12, 31)
class_default = min_value class_default = min_value
db_type = 'Date' db_type = "Date"
def to_python(self, value, timezone_in_use): def to_python(self, value, timezone_in_use):
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
@ -178,10 +185,10 @@ class DateField(Field):
if isinstance(value, int): if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value) return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, str): if isinstance(value, str):
if value == '0000-00-00': if value == "0000-00-00":
return DateField.min_value return DateField.min_value
return datetime.datetime.strptime(value, '%Y-%m-%d').date() return datetime.datetime.strptime(value, "%Y-%m-%d").date()
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
def validate(self, value): def validate(self, value):
self._range_check(value, DateField.min_value, DateField.max_value) self._range_check(value, DateField.min_value, DateField.max_value)
@@ -193,10 +200,9 @@ class DateField(Field):

class DateTimeField(Field):
    class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
    db_type = "DateTime"

    def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None):
        super().__init__(default, alias, materialized, readonly, codec)
        # assert not timezone, 'Temporarily field timezone is not supported'
        if timezone:
@@ -217,7 +223,7 @@ class DateTimeField(Field):
        if isinstance(value, int):
            return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
        if isinstance(value, str):
            if value == "0000-00-00 00:00:00":
                return self.class_default
            if len(value) == 10:
                try:
@@ -235,19 +241,20 @@ class DateTimeField(Field):
        if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
            dt = timezone_in_use.localize(dt)
        return dt
        raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True):
        return escape("%010d" % timegm(value.utctimetuple()), quote)
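
# Illustrative (a sketch, not part of the commit): DateTime values are sent to
# ClickHouse as zero-padded epoch seconds, which is what the "%010d" format
# produces:
#
#     >>> dt = datetime.datetime(2009, 2, 13, 23, 31, 30, tzinfo=pytz.utc)
#     >>> DateTimeField().to_db_string(dt, quote=False)
#     '1234567890'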


class DateTime64Field(DateTimeField):
    db_type = "DateTime64"

    def __init__(
        self, default=None, alias=None, materialized=None, readonly=None, codec=None, timezone=None, precision=6
    ):
        super().__init__(default, alias, materialized, readonly, codec, timezone)
        assert precision is None or isinstance(precision, int), "Precision must be int type"
        self.precision = precision

    def get_db_type_args(self):
@@ -263,11 +270,10 @@ class DateTime64Field(DateTimeField):
        Returns a string in 0000000000.000000 format, where the number of fractional digits equals the precision
        """
        return escape(
            "{timestamp:0{width}.{precision}f}".format(
                timestamp=value.timestamp(), width=11 + self.precision, precision=self.precision
            ),
            quote,
        )

    def to_python(self, value, timezone_in_use):
@@ -277,8 +283,8 @@ class DateTime64Field(DateTimeField):
        if isinstance(value, (int, float)):
            return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
        if isinstance(value, str):
            left_part = value.split(".")[0]
            if left_part == "0000-00-00 00:00:00":
                return self.class_default
            if len(left_part) == 10:
                try:
@@ -290,14 +296,15 @@

class BaseIntField(Field):
    """
    Abstract base class for all integer-type fields.
    """

    def to_python(self, value, timezone_in_use):
        try:
            return int(value)
        except Exception:
            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True):
        # There's no need to call escape since numbers do not contain
@@ -311,69 +318,69 @@ class BaseIntField(Field):

class UInt8Field(BaseIntField):
    min_value = 0
    max_value = 2 ** 8 - 1
    db_type = "UInt8"


class UInt16Field(BaseIntField):
    min_value = 0
    max_value = 2 ** 16 - 1
    db_type = "UInt16"


class UInt32Field(BaseIntField):
    min_value = 0
    max_value = 2 ** 32 - 1
    db_type = "UInt32"


class UInt64Field(BaseIntField):
    min_value = 0
    max_value = 2 ** 64 - 1
    db_type = "UInt64"


class Int8Field(BaseIntField):
    min_value = -(2 ** 7)
    max_value = 2 ** 7 - 1
    db_type = "Int8"


class Int16Field(BaseIntField):
    min_value = -(2 ** 15)
    max_value = 2 ** 15 - 1
    db_type = "Int16"


class Int32Field(BaseIntField):
    min_value = -(2 ** 31)
    max_value = 2 ** 31 - 1
    db_type = "Int32"


class Int64Field(BaseIntField):
    min_value = -(2 ** 63)
    max_value = 2 ** 63 - 1
    db_type = "Int64"
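
# Illustrative (a sketch, not part of the commit): each concrete field pins the
# exact range ClickHouse accepts for its column type:
#
#     >>> Int8Field.min_value, Int8Field.max_value
#     (-128, 127)
#     >>> UInt16Field.max_value
#     65535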


class BaseFloatField(Field):
    """
    Abstract base class for all float-type fields.
    """

    def to_python(self, value, timezone_in_use):
        try:
            return float(value)
        except Exception:
            raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True):
        # There's no need to call escape since numbers do not contain
@@ -383,28 +390,28 @@ class BaseFloatField(Field):

class Float32Field(BaseFloatField):
    db_type = "Float32"


class Float64Field(BaseFloatField):
    db_type = "Float64"


class DecimalField(Field):
    """
    Base class for all decimal fields. Can also be used directly.
    """

    def __init__(self, precision, scale, default=None, alias=None, materialized=None, readonly=None):
        assert 1 <= precision <= 38, "Precision must be between 1 and 38"
        assert 0 <= scale <= precision, "Scale must be between 0 and the given precision"
        self.precision = precision
        self.scale = scale
        self.db_type = "Decimal(%d,%d)" % (self.precision, self.scale)
        with localcontext() as ctx:
            ctx.prec = 38
            self.exp = Decimal(10) ** -self.scale  # for rounding to the required scale
            self.max_value = Decimal(10 ** (self.precision - self.scale)) - self.exp
            self.min_value = -self.max_value
        super(DecimalField, self).__init__(default, alias, materialized, readonly)

@@ -413,10 +420,10 @@ class DecimalField(Field):
        if not isinstance(value, Decimal):
            try:
                value = Decimal(value)
            except Exception:
                raise ValueError("Invalid value for %s - %r" % (self.__class__.__name__, value))
        if not value.is_finite():
            raise ValueError("Non-finite value for %s - %r" % (self.__class__.__name__, value))
        return self._round(value)

    def to_db_string(self, value, quote=True):
@@ -432,30 +439,27 @@ class DecimalField(Field):

class Decimal32Field(DecimalField):
    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
        super(Decimal32Field, self).__init__(9, scale, default, alias, materialized, readonly)
        self.db_type = "Decimal32(%d)" % scale


class Decimal64Field(DecimalField):
    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
        super(Decimal64Field, self).__init__(18, scale, default, alias, materialized, readonly)
        self.db_type = "Decimal64(%d)" % scale


class Decimal128Field(DecimalField):
    def __init__(self, scale, default=None, alias=None, materialized=None, readonly=None):
        super(Decimal128Field, self).__init__(38, scale, default, alias, materialized, readonly)
        self.db_type = "Decimal128(%d)" % scale
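
# Illustrative (a sketch, not part of the commit): the precision/scale pair
# derived in __init__ determines the representable range and the rounding
# quantum, so values are rounded to the scale on assignment:
#
#     >>> f = DecimalField(9, 4)        # same shape as Decimal32Field(4)
#     >>> f.exp, f.max_value
#     (Decimal('0.0001'), Decimal('99999.9999'))
#     >>> f.to_python("3.14159", None)  # rounded to scale=4
#     Decimal('3.1416')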


class BaseEnumField(Field):
    """
    Abstract base class for all enum-type fields.
    """

    def __init__(self, enum_cls, default=None, alias=None, materialized=None, readonly=None, codec=None):
        self.enum_cls = enum_cls
@@ -473,7 +477,7 @@ class BaseEnumField(Field):
            except Exception:
                return self.enum_cls(value)
        if isinstance(value, bytes):
            decoded = value.decode("utf-8")
            try:
                return self.enum_cls[decoded]
            except Exception:
@@ -482,38 +486,39 @@ class BaseEnumField(Field):
                return self.enum_cls(value)
            except (KeyError, ValueError):
                pass
        raise ValueError("Invalid value for %s: %r" % (self.enum_cls.__name__, value))

    def to_db_string(self, value, quote=True):
        return escape(value.name, quote)

    def get_db_type_args(self):
        return ["%s = %d" % (escape(item.name), item.value) for item in self.enum_cls]

    @classmethod
    def create_ad_hoc_field(cls, db_type):
        """
        Given an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)",
        this method returns a matching enum field.
        """
        import re
        from enum import Enum

        members = {}
        for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type):
            members[match.group(1)] = int(match.group(2))
        enum_cls = Enum("AdHocEnum", members)
        field_class = Enum8Field if db_type.startswith("Enum8") else Enum16Field
        return field_class(enum_cls)


class Enum8Field(BaseEnumField):
    db_type = "Enum8"


class Enum16Field(BaseEnumField):
    db_type = "Enum16"
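
# Illustrative (a sketch, not part of the commit): the regex above turns a
# ClickHouse column description into an ad hoc Python Enum, so roughly:
#
#     >>> f = BaseEnumField.create_ad_hoc_field("Enum8('apple' = 1, 'banana' = 2)")
#     >>> type(f).__name__
#     'Enum8Field'
#     >>> f.to_python("banana", None).value
#     2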


class ArrayField(Field):
@@ -530,9 +535,9 @@ class ArrayField(Field):
        if isinstance(value, str):
            value = parse_array(value)
        elif isinstance(value, bytes):
            value = parse_array(value.decode("utf-8"))
        elif not isinstance(value, (list, tuple)):
            raise ValueError("ArrayField expects list or tuple, not %s" % type(value))
        return [self.inner_field.to_python(v, timezone_in_use) for v in value]

    def validate(self, value):
@@ -541,19 +546,19 @@ class ArrayField(Field):

    def to_db_string(self, value, quote=True):
        array = [self.inner_field.to_db_string(v, quote=True) for v in value]
        return "[" + comma_join(array) + "]"

    def get_sql(self, with_default_expression=True, db=None):
        sql = "Array(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
        if with_default_expression and self.codec and db and db.has_codec_support:
            sql += " CODEC(%s)" % self.codec
        return sql
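
# Illustrative (a sketch, not part of the commit): the SQL type nests the inner
# field's type, and values are serialized as a bracketed, comma-joined list in
# which each element is escaped by the inner field:
#
#     >>> f = ArrayField(StringField())
#     >>> f.get_sql(with_default_expression=False)
#     'Array(String)'
#     >>> f.to_db_string(["a", "b"])   # -> something like "['a', 'b']"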


class UUIDField(Field):
    class_default = UUID(int=0)
    db_type = "UUID"

    def to_python(self, value, timezone_in_use):
        if isinstance(value, UUID):
@@ -567,7 +572,7 @@ class UUIDField(Field):
        elif isinstance(value, tuple):
            return UUID(fields=value)
        else:
            raise ValueError("Invalid value for UUIDField: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)

@@ -576,7 +581,7 @@ class UUIDField(Field):

class IPv4Field(Field):
    class_default = 0
    db_type = "IPv4"

    def to_python(self, value, timezone_in_use):
        if isinstance(value, IPv4Address):
@@ -584,7 +589,7 @@ class IPv4Field(Field):
        elif isinstance(value, (bytes, str, int)):
            return IPv4Address(value)
        else:
            raise ValueError("Invalid value for IPv4Address: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)

@@ -593,7 +598,7 @@ class IPv4Field(Field):

class IPv6Field(Field):
    class_default = 0
    db_type = "IPv6"

    def to_python(self, value, timezone_in_use):
        if isinstance(value, IPv6Address):
@@ -601,7 +606,7 @@ class IPv6Field(Field):
        elif isinstance(value, (bytes, str, int)):
            return IPv6Address(value)
        else:
            raise ValueError("Invalid value for IPv6Address: %r" % value)

    def to_db_string(self, value, quote=True):
        return escape(str(value), quote)

@@ -611,9 +616,10 @@ class NullableField(Field):
    class_default = None

    def __init__(self, inner_field, default=None, alias=None, materialized=None, extra_null_values=None, codec=None):
        assert isinstance(
            inner_field, Field
        ), "The first argument of NullableField must be a Field instance. Not: {}".format(inner_field)
        self.inner_field = inner_field
        self._null_values = [None]
        if extra_null_values:
@@ -621,7 +627,7 @@ class NullableField(Field):
        super(NullableField, self).__init__(default, alias, materialized, readonly=None, codec=codec)

    def to_python(self, value, timezone_in_use):
        if value == "\\N" or value in self._null_values:
            return None
        return self.inner_field.to_python(value, timezone_in_use)

@@ -630,22 +636,27 @@ class NullableField(Field):

    def to_db_string(self, value, quote=True):
        if value in self._null_values:
            return "\\N"
        return self.inner_field.to_db_string(value, quote=quote)

    def get_sql(self, with_default_expression=True, db=None):
        sql = "Nullable(%s)" % self.inner_field.get_sql(with_default_expression=False, db=db)
        if with_default_expression:
            sql += self._extra_params(db)
        return sql
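
# Illustrative (a sketch, not part of the commit): NULLs round-trip through the
# TSV representation "\N":
#
#     >>> f = NullableField(Int32Field())
#     >>> f.to_python("\\N", pytz.utc)   # -> None
#     >>> f.to_db_string(None)           # -> "\\N"
#     >>> f.get_sql(with_default_expression=False)
#     'Nullable(Int32)'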


class LowCardinalityField(Field):
    def __init__(self, inner_field, default=None, alias=None, materialized=None, readonly=None, codec=None):
        assert isinstance(
            inner_field, Field
        ), "The first argument of LowCardinalityField must be a Field instance. Not: {}".format(inner_field)
        assert not isinstance(
            inner_field, LowCardinalityField
        ), "LowCardinality inner fields are not supported by the ORM"
        assert not isinstance(
            inner_field, ArrayField
        ), "Array field inside LowCardinality are not supported by the ORM. Use Array(LowCardinality) instead"
        self.inner_field = inner_field
        self.class_default = self.inner_field.class_default
        super(LowCardinalityField, self).__init__(default, alias, materialized, readonly, codec)

@@ -661,10 +672,14 @@ class LowCardinalityField(Field):

    def get_sql(self, with_default_expression=True, db=None):
        if db and db.has_low_cardinality_support:
            sql = "LowCardinality(%s)" % self.inner_field.get_sql(with_default_expression=False)
        else:
            sql = self.inner_field.get_sql(with_default_expression=False)
            logger.warning(
                "LowCardinalityField not supported on clickhouse-server version < 19.0 using {} as fallback".format(
                    self.inner_field.__class__.__name__
                )
            )
        if with_default_expression:
            sql += self._extra_params(db)
        return sql
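
# Illustrative (a sketch, not part of the commit): the generated SQL depends on
# the server support reported by the Database instance:
#
#     >>> f = LowCardinalityField(StringField())
#     >>> f.get_sql(with_default_expression=False, db=new_enough_db)  # hypothetical db handle
#     'LowCardinality(String)'
#     >>> f.get_sql(with_default_expression=False)   # no db info: falls back, logs a warning
#     'String'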

File diff suppressed because it is too large

View File

@@ -1,70 +1,71 @@
import logging

from .engines import MergeTree
from .fields import DateField, StringField
from .models import BufferModel, Model
from .utils import get_subclass_names

logger = logging.getLogger("migrations")


class Operation:
    """
    Base class for migration operations.
    """

    def apply(self, database):
        raise NotImplementedError()  # pragma: no cover


class ModelOperation(Operation):
    """
    Base class for migration operations that work on a specific model.
    """

    def __init__(self, model_class):
        """
        Initializer.
        """
        self.model_class = model_class
        self.table_name = model_class.table_name()

    def _alter_table(self, database, cmd):
        """
        Utility for running ALTER TABLE commands.
        """
        cmd = "ALTER TABLE $db.`%s` %s" % (self.table_name, cmd)
        logger.debug(cmd)
        database.raw(cmd)


class CreateTable(ModelOperation):
    """
    A migration operation that creates a table for a given model class.
    """

    def apply(self, database):
        logger.info("    Create table %s", self.table_name)
        if issubclass(self.model_class, BufferModel):
            database.create_table(self.model_class.engine.main_model)
        database.create_table(self.model_class)


class AlterTable(ModelOperation):
    """
    A migration operation that compares the table of a given model class to
    the model's fields, and alters the table to match the model. The operation can:
      - add new columns
      - drop obsolete columns
      - modify column types
    Default values are not altered by this operation.
    """

    def _get_table_fields(self, database):
        query = "DESC `%s`.`%s`" % (database.db_name, self.table_name)
        return [(row.name, row.type) for row in database.select(query)]

    def apply(self, database):
        logger.info("    Alter table %s", self.table_name)

        # Note that MATERIALIZED and ALIAS fields are always at the end of the DESC,
        # ADD COLUMN ... AFTER doesn't affect it
@@ -73,8 +74,8 @@ class AlterTable(ModelOperation):
        # Identify fields that were deleted from the model
        deleted_fields = set(table_fields.keys()) - set(self.model_class.fields())
        for name in deleted_fields:
            logger.info("        Drop column %s", name)
            self._alter_table(database, "DROP COLUMN %s" % name)
            del table_fields[name]

        # Identify fields that were added to the model
@@ -82,11 +83,11 @@ class AlterTable(ModelOperation):
        for name, field in self.model_class.fields().items():
            is_regular_field = not (field.materialized or field.alias)
            if name not in table_fields:
                logger.info("        Add column %s", name)
                assert prev_name, "Cannot add a column to the beginning of the table"
                cmd = "ADD COLUMN %s %s" % (name, field.get_sql(db=database))
                if is_regular_field:
                    cmd += " AFTER %s" % prev_name
                self._alter_table(database, cmd)

            if is_regular_field:
@@ -97,25 +98,28 @@ class AlterTable(ModelOperation):
        # Identify fields whose type was changed
        # The order of class attributes can be changed any time, so we can't count on it
        # Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save
        # attribute position. Watch https://github.com/Infinidat/clickhouse_orm/issues/47
        model_fields = {
            name: field.get_sql(with_default_expression=False, db=database)
            for name, field in self.model_class.fields().items()
        }
        for field_name, field_sql in self._get_table_fields(database):
            # All fields must have been created and dropped by this moment
            assert field_name in model_fields, "Model fields and table columns in disagreement"
            if field_sql != model_fields[field_name]:
                logger.info(
                    "        Change type of column %s from %s to %s", field_name, field_sql, model_fields[field_name]
                )
                self._alter_table(database, "MODIFY COLUMN %s %s" % (field_name, model_fields[field_name]))
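
# Illustrative (a sketch, not part of the commit): operations like these are
# listed in a migration module's `operations` attribute and applied in order;
# a hypothetical migrations/0002_sync_columns.py:
#
#     from clickhouse_orm import migrations
#     from myapp.models import CPUStats       # hypothetical app model
#
#     operations = [
#         migrations.AlterTable(CPUStats),    # add/drop/modify columns to match the model
#     ]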


class AlterTableWithBuffer(ModelOperation):
    """
    A migration operation for altering a buffer table and its underlying on-disk table.
    The buffer table is dropped, the on-disk table is altered, and then the buffer table
    is re-created.
    """

    def apply(self, database):
        if issubclass(self.model_class, BufferModel):
@@ -127,149 +131,152 @@ class AlterTableWithBuffer(ModelOperation):

class DropTable(ModelOperation):
    """
    A migration operation that drops the table of a given model class.
    """

    def apply(self, database):
        logger.info("    Drop table %s", self.table_name)
        database.drop_table(self.model_class)


class AlterConstraints(ModelOperation):
    """
    A migration operation that adds new constraints from the model to the database
    table, and drops obsolete ones. Constraints are identified by their names, so
    a change in an existing constraint will not be detected unless its name was changed too.
    ClickHouse does not check that the constraints hold for existing data in the table.
    """

    def apply(self, database):
        logger.info("    Alter constraints for %s", self.table_name)
        existing = self._get_constraint_names(database)
        # Go over constraints in the model
        for constraint in self.model_class._constraints.values():
            # Check if it's a new constraint
            if constraint.name not in existing:
                logger.info("        Add constraint %s", constraint.name)
                self._alter_table(database, "ADD %s" % constraint.create_table_sql())
            else:
                existing.remove(constraint.name)
        # Remaining constraints in `existing` are obsolete
        for name in existing:
            logger.info("        Drop constraint %s", name)
            self._alter_table(database, "DROP CONSTRAINT `%s`" % name)

    def _get_constraint_names(self, database):
        """
        Returns a set containing the names of existing constraints in the table.
        """
        import re

        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
        matches = re.findall(r"\sCONSTRAINT\s+`?(.+?)`?\s+CHECK\s", table_def)
        return set(matches)


class AlterIndexes(ModelOperation):
    """
    A migration operation that adds new indexes from the model to the database
    table, and drops obsolete ones. Indexes are identified by their names, so
    a change in an existing index will not be detected unless its name was changed too.
    """

    def __init__(self, model_class, reindex=False):
        """
        Initializer.
        By default ClickHouse does not build indexes over existing data, only for
        new data. Passing `reindex=True` will run `OPTIMIZE TABLE` in order to build
        the indexes over the existing data.
        """
        super().__init__(model_class)
        self.reindex = reindex

    def apply(self, database):
        logger.info("    Alter indexes for %s", self.table_name)
        existing = self._get_index_names(database)
        logger.info(existing)
        # Go over indexes in the model
        for index in self.model_class._indexes.values():
            # Check if it's a new index
            if index.name not in existing:
                logger.info("        Add index %s", index.name)
                self._alter_table(database, "ADD %s" % index.create_table_sql())
            else:
                existing.remove(index.name)
        # Remaining indexes in `existing` are obsolete
        for name in existing:
            logger.info("        Drop index %s", name)
            self._alter_table(database, "DROP INDEX `%s`" % name)
        # Reindex
        if self.reindex:
            logger.info("        Build indexes on table")
            database.raw("OPTIMIZE TABLE $db.`%s` FINAL" % self.table_name)

    def _get_index_names(self, database):
        """
        Returns a set containing the names of existing indexes in the table.
        """
        import re

        table_def = database.raw("SHOW CREATE TABLE $db.`%s`" % self.table_name)
        matches = re.findall(r"\sINDEX\s+`?(.+?)`?\s+", table_def)
        return set(matches)


class RunPython(Operation):
    """
    A migration operation that executes a Python function.
    """

    def __init__(self, func):
        """
        Initializer. The given Python function will be called with a single
        argument - the Database instance to apply the migration to.
        """
        assert callable(func), "'func' argument must be function"
        self._func = func

    def apply(self, database):
        logger.info("    Executing python operation %s", self._func.__name__)
        self._func(database)


class RunSQL(Operation):
    """
    A migration operation that executes arbitrary SQL statements.
    """

    def __init__(self, sql):
        """
        Initializer. The given sql argument must be a valid SQL statement or
        list of statements.
        """
        if isinstance(sql, str):
            sql = [sql]
        assert isinstance(sql, list), "'sql' argument must be string or list of strings"
        self._sql = sql

    def apply(self, database):
        logger.info("    Executing raw SQL operations")
        for item in self._sql:
            database.raw(item)
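
# Illustrative (a sketch, not part of the commit): RunPython and RunSQL cover
# data migrations that the schema operations above cannot express; a
# hypothetical migration module:
#
#     def _rebuild(database):                  # receives the Database instance
#         database.raw("OPTIMIZE TABLE $db.`cpustats` FINAL")
#
#     operations = [
#         migrations.RunPython(_rebuild),
#         migrations.RunSQL("TRUNCATE TABLE $db.`cpustats`"),
#     ]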


class MigrationHistory(Model):
    """
    A model for storing which migrations were already applied to the containing database.
    """

    package_name = StringField()
    module_name = StringField()
    applied = DateField()

    engine = MergeTree("applied", ("package_name", "module_name"))

    @classmethod
    def table_name(cls):
        return "infi_clickhouse_orm_migrations"


# Expose only relevant classes in import *

View File

@@ -1,4 +1,3 @@
import sys
from collections import OrderedDict
from itertools import chain
@@ -6,84 +5,83 @@ from logging import getLogger

import pytz

from .engines import Distributed, Merge
from .fields import Field, StringField
from .funcs import F
from .query import QuerySet
from .utils import NO_VALUE, arg_to_sql, get_subclass_names, parse_tsv

logger = getLogger("clickhouse_orm")


class Constraint:
    """
    Defines a model constraint.
    """

    name = None  # this is set by the parent model
    parent = None  # this is set by the parent model

    def __init__(self, expr):
        """
        Initializer. Expects an expression that ClickHouse will verify when inserting data.
        """
        self.expr = expr

    def create_table_sql(self):
        """
        Returns the SQL statement for defining this constraint during table creation.
        """
        return "CONSTRAINT `%s` CHECK %s" % (self.name, arg_to_sql(self.expr))


class Index:
    """
    Defines a data-skipping index.
    """

    name = None  # this is set by the parent model
    parent = None  # this is set by the parent model

    def __init__(self, expr, type, granularity):
        """
        Initializer.

        - `expr` - a column, expression, or tuple of columns and expressions to index.
        - `type` - the index type. Use one of the following methods to specify the type:
          `Index.minmax`, `Index.set`, `Index.ngrambf_v1`, `Index.tokenbf_v1` or `Index.bloom_filter`.
        - `granularity` - index block size (number of multiples of the `index_granularity` defined by the engine).
        """
        self.expr = expr
        self.type = type
        self.granularity = granularity

    def create_table_sql(self):
        """
        Returns the SQL statement for defining this index during table creation.
        """
        return "INDEX `%s` %s TYPE %s GRANULARITY %d" % (self.name, arg_to_sql(self.expr), self.type, self.granularity)

    @staticmethod
    def minmax():
        """
        An index that stores extremes of the specified expression (if the expression is a tuple, then it stores
        extremes for each element of the tuple). The stored info is used for skipping blocks of data like the primary key.
        """
        return "minmax"

    @staticmethod
    def set(max_rows):
        """
        An index that stores unique values of the specified expression (no more than max_rows rows,
        or unlimited if max_rows=0). Uses the values to check if the WHERE expression is not satisfiable
        on a block of data.
        """
        return "set(%d)" % max_rows

    @staticmethod
    def ngrambf_v1(n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
        """
        An index that stores a Bloom filter containing all ngrams from a block of data.
        Works only with strings. Can be used for optimization of equals, like, and in expressions.
@@ -92,12 +90,12 @@ class Index:
          for example 256 or 512, because it can be compressed well).
        - `number_of_hash_functions` - the number of hash functions used in the Bloom filter.
        - `random_seed` - the seed for Bloom filter hash functions.
        """
        return "ngrambf_v1(%d, %d, %d, %d)" % (n, size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)

    @staticmethod
    def tokenbf_v1(size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed):
        """
        An index that stores a Bloom filter containing string tokens. Tokens are sequences
        separated by non-alphanumeric characters.
@@ -105,28 +103,28 @@ class Index:
          for example 256 or 512, because it can be compressed well).
        - `number_of_hash_functions` - the number of hash functions used in the Bloom filter.
        - `random_seed` - the seed for Bloom filter hash functions.
        """
        return "tokenbf_v1(%d, %d, %d)" % (size_of_bloom_filter_in_bytes, number_of_hash_functions, random_seed)

    @staticmethod
    def bloom_filter(false_positive=0.025):
        """
        An index that stores a Bloom filter containing values of the index expression.

        - `false_positive` - the probability (between 0 and 1) of receiving a false positive
          response from the filter
        """
        return "bloom_filter(%f)" % false_positive
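
# Illustrative (a sketch, not part of the commit): the static methods only
# build the index type string; a hypothetical model declaring a data-skipping
# index:
#
#     >>> Index.minmax(), Index.set(1000), Index.bloom_filter()
#     ('minmax', 'set(1000)', 'bloom_filter(0.025000)')
#
#     class PageView(Model):               # hypothetical model
#         url = StringField()
#         url_index = Index(url, type=Index.tokenbf_v1(512, 3, 0), granularity=4)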


class ModelBase(type):
    """
    A metaclass for ORM models. It adds the _fields list to model classes.
    """

    ad_hoc_model_cache = {}

    def __new__(metacls, name, bases, attrs):
        # Collect fields, constraints and indexes from parent classes
        fields = {}
@@ -170,90 +168,88 @@ class ModelBase(type):
            _indexes=indexes,
            _writable_fields=OrderedDict([f for f in fields if not f[1].readonly]),
            _defaults=defaults,
            _has_funcs_as_defaults=has_funcs_as_defaults,
        )
        model = super(ModelBase, metacls).__new__(metacls, str(name), bases, attrs)

        # Let each field, constraint and index know its parent and its own name
        for n, obj in chain(fields, constraints.items(), indexes.items()):
            obj.parent = model
            obj.name = n

        return model

    @classmethod
    def create_ad_hoc_model(metacls, fields, model_name="AdHocModel"):
        # fields is a list of tuples (name, db_type)
        # Check if model exists in cache
        fields = list(fields)
        cache_key = model_name + " " + str(fields)
        if cache_key in metacls.ad_hoc_model_cache:
            return metacls.ad_hoc_model_cache[cache_key]
        # Create an ad hoc model class
        attrs = {}
        for name, db_type in fields:
            attrs[name] = metacls.create_ad_hoc_field(db_type)
        model_class = metacls.__new__(metacls, model_name, (Model,), attrs)
        # Add the model class to the cache
        metacls.ad_hoc_model_cache[cache_key] = model_class
        return model_class

    @classmethod
    def create_ad_hoc_field(metacls, db_type):
        import clickhouse_orm.fields as orm_fields

        # Enums
        if db_type.startswith("Enum"):
            return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
        # DateTime with timezone
        if db_type.startswith("DateTime("):
            timezone = db_type[9:-1]
            return orm_fields.DateTimeField(timezone=timezone[1:-1] if timezone else None)
        # DateTime64
        if db_type.startswith("DateTime64("):
            precision, *timezone = [s.strip() for s in db_type[11:-1].split(",")]
            return orm_fields.DateTime64Field(
                precision=int(precision), timezone=timezone[0][1:-1] if timezone else None
            )
        # Arrays
        if db_type.startswith("Array"):
            inner_field = metacls.create_ad_hoc_field(db_type[6:-1])
            return orm_fields.ArrayField(inner_field)
        # Tuples (poor man's version - convert to array)
        if db_type.startswith("Tuple"):
            types = [s.strip() for s in db_type[6:-1].split(",")]
            assert len(set(types)) == 1, "No support for mixed types in tuples - " + db_type
            inner_field = metacls.create_ad_hoc_field(types[0])
            return orm_fields.ArrayField(inner_field)
        # FixedString
        if db_type.startswith("FixedString"):
            length = int(db_type[12:-1])
            return orm_fields.FixedStringField(length)
        # Decimal / Decimal32 / Decimal64 / Decimal128
        if db_type.startswith("Decimal"):
            p = db_type.index("(")
            args = [int(n.strip()) for n in db_type[p + 1 : -1].split(",")]
            field_class = getattr(orm_fields, db_type[:p] + "Field")
            return field_class(*args)
        # Nullable
        if db_type.startswith("Nullable"):
            inner_field = metacls.create_ad_hoc_field(db_type[9:-1])
            return orm_fields.NullableField(inner_field)
        # LowCardinality
        if db_type.startswith("LowCardinality"):
            inner_field = metacls.create_ad_hoc_field(db_type[15:-1])
            return orm_fields.LowCardinalityField(inner_field)
        # Simple fields
        name = db_type + "Field"
        if not hasattr(orm_fields, name):
            raise NotImplementedError("No field class for %s" % db_type)
        return getattr(orm_fields, name)()
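
# Illustrative (a sketch, not part of the commit): the type-string dispatch
# above recurses through wrapper types until it reaches a simple field class:
#
#     >>> ModelBase.create_ad_hoc_field("Array(UInt32)").get_sql(with_default_expression=False)
#     'Array(UInt32)'
#     >>> ModelBase.create_ad_hoc_field("Nullable(FixedString(8))")   # NullableField wrapping FixedStringField(8)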
class Model(metaclass=ModelBase): class Model(metaclass=ModelBase):
''' """
A base class for ORM models. Each model class represent a ClickHouse table. For example: A base class for ORM models. Each model class represent a ClickHouse table. For example:
class CPUStats(Model): class CPUStats(Model):
@ -261,7 +257,7 @@ class Model(metaclass=ModelBase):
cpu_id = UInt16Field() cpu_id = UInt16Field()
cpu_percent = Float32Field() cpu_percent = Float32Field()
engine = Memory() engine = Memory()
''' """
engine = None engine = None
@ -274,12 +270,12 @@ class Model(metaclass=ModelBase):
_database = None _database = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
''' """
Creates a model instance, using keyword arguments as field values. Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type, Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised. invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`. Unrecognized field names will cause an `AttributeError`.
''' """
super(Model, self).__init__() super(Model, self).__init__()
# Assign default values # Assign default values
self.__dict__.update(self._defaults) self.__dict__.update(self._defaults)
@ -289,13 +285,13 @@ class Model(metaclass=ModelBase):
if field: if field:
setattr(self, name, value) setattr(self, name, value)
else: else:
raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name)) raise AttributeError("%s does not have a field called %s" % (self.__class__.__name__, name))
def __setattr__(self, name, value): def __setattr__(self, name, value):
''' """
When setting a field value, converts the value to its Pythonic type and validates it. When setting a field value, converts the value to its Pythonic type and validates it.
This may raise a `ValueError`. This may raise a `ValueError`.
''' """
field = self.get_field(name) field = self.get_field(name)
if field and (value != NO_VALUE): if field and (value != NO_VALUE):
try: try:
@ -308,77 +304,78 @@ class Model(metaclass=ModelBase):
super(Model, self).__setattr__(name, value) super(Model, self).__setattr__(name, value)
def set_database(self, db): def set_database(self, db):
''' """
Sets the `Database` that this model instance belongs to. Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it. This is done automatically when the instance is read from the database or written to it.
''' """
# This can not be imported globally due to circular import # This can not be imported globally due to circular import
from .database import Database from .database import Database
assert isinstance(db, Database), "database must be database.Database instance" assert isinstance(db, Database), "database must be database.Database instance"
self._database = db self._database = db
def get_database(self): def get_database(self):
''' """
Gets the `Database` that this model instance belongs to. Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it. Returns `None` unless the instance was read from the database or written to it.
''' """
return self._database return self._database
def get_field(self, name): def get_field(self, name):
''' """
Gets a `Field` instance given its name, or `None` if not found. Gets a `Field` instance given its name, or `None` if not found.
''' """
return self._fields.get(name) return self._fields.get(name)
@classmethod @classmethod
def table_name(cls): def table_name(cls):
''' """
Returns the model's database table name. By default this is the Returns the model's database table name. By default this is the
class name converted to lowercase. Override this if you want to use class name converted to lowercase. Override this if you want to use
a different table name. a different table name.
''' """
return cls.__name__.lower() return cls.__name__.lower()
@classmethod @classmethod
def has_funcs_as_defaults(cls): def has_funcs_as_defaults(cls):
''' """
Return True if some of the model's fields use a function expression Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances. as a default value. This requires special handling when inserting instances.
''' """
return cls._has_funcs_as_defaults return cls._has_funcs_as_defaults
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
# Fields # Fields
items = [] items = []
for name, field in cls.fields().items(): for name, field in cls.fields().items():
items.append(' %s %s' % (name, field.get_sql(db=db))) items.append(" %s %s" % (name, field.get_sql(db=db)))
# Constraints # Constraints
for c in cls._constraints.values(): for c in cls._constraints.values():
items.append(' %s' % c.create_table_sql()) items.append(" %s" % c.create_table_sql())
# Indexes # Indexes
for i in cls._indexes.values(): for i in cls._indexes.values():
items.append(' %s' % i.create_table_sql()) items.append(" %s" % i.create_table_sql())
parts.append(',\n'.join(items)) parts.append(",\n".join(items))
# Engine # Engine
parts.append(')') parts.append(")")
parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return '\n'.join(parts) return "\n".join(parts)
@classmethod @classmethod
def drop_table_sql(cls, db): def drop_table_sql(cls, db):
''' """
Returns the SQL command for deleting this model's table. Returns the SQL command for deleting this model's table.
''' """
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db.db_name, cls.table_name()) return "DROP TABLE IF EXISTS `%s`.`%s`" % (db.db_name, cls.table_name())
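As a usage sketch (the Person model, its fields and the 'demo' database name below are illustrative assumptions, not part of this commit), these two methods produce DDL roughly as follows:

    from clickhouse_orm import models, fields, engines

    class Person(models.Model):
        first_name = fields.StringField()
        birthday = fields.DateField()
        engine = engines.MergeTree('birthday', ('first_name', 'birthday'))

    # Person.create_table_sql(db) against a database named 'demo' yields roughly:
    #   CREATE TABLE IF NOT EXISTS `demo`.`person` (
    #       first_name String,
    #       birthday Date
    #   )
    #   ENGINE = MergeTree(birthday, (first_name, birthday), 8192)
    # Person.drop_table_sql(db) yields:
    #   DROP TABLE IF EXISTS `demo`.`person`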
@classmethod @classmethod
def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None): def from_tsv(cls, line, field_names, timezone_in_use=pytz.utc, database=None):
''' """
Create a model instance from a tab-separated line. The line may or may not include a newline. Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them. The `field_names` list must match the fields defined in the model, but does not have to include all of them.
@ -386,12 +383,12 @@ class Model(metaclass=ModelBase):
- `field_names`: names of the model fields in the data. - `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones. - `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to. - `database`: if given, sets the database that this instance belongs to.
''' """
values = iter(parse_tsv(line)) values = iter(parse_tsv(line))
kwargs = {} kwargs = {}
for name in field_names: for name in field_names:
field = getattr(cls, name) field = getattr(cls, name)
field_timezone = getattr(field, 'timezone', None) or timezone_in_use field_timezone = getattr(field, "timezone", None) or timezone_in_use
kwargs[name] = field.to_python(next(values), field_timezone) kwargs[name] = field.to_python(next(values), field_timezone)
obj = cls(**kwargs) obj = cls(**kwargs)
@ -401,45 +398,45 @@ class Model(metaclass=ModelBase):
return obj return obj
def to_tsv(self, include_readonly=True): def to_tsv(self, include_readonly=True):
''' """
Returns the instance's column values as a tab-separated line. A newline is not included. Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
''' """
data = self.__dict__ data = self.__dict__
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items()) return "\t".join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
def to_tskv(self, include_readonly=True): def to_tskv(self, include_readonly=True):
''' """
Returns the instance's column keys and values as a tab-separated line. A newline is not included. Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted. Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
''' """
data = self.__dict__ data = self.__dict__
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
parts = [] parts = []
for name, field in fields.items(): for name, field in fields.items():
if data[name] != NO_VALUE: if data[name] != NO_VALUE:
parts.append(name + '=' + field.to_db_string(data[name], quote=False)) parts.append(name + "=" + field.to_db_string(data[name], quote=False))
return '\t'.join(parts) return "\t".join(parts)
def to_db_string(self): def to_db_string(self):
''' """
Returns the instance as a bytestring ready to be inserted into the database. Returns the instance as a bytestring ready to be inserted into the database.
''' """
s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False) s = self.to_tskv(False) if self._has_funcs_as_defaults else self.to_tsv(False)
s += '\n' s += "\n"
return s.encode('utf-8') return s.encode("utf-8")
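A minimal round-trip sketch for the TSV helpers above, reusing the hypothetical Person model from the earlier sketch:

    line = "Vera\t1976-08-25"
    person = Person.from_tsv(line, ("first_name", "birthday"))
    person.to_tsv(include_readonly=False)   # "Vera\t1976-08-25"
    person.to_db_string()                   # b"Vera\t1976-08-25\n" -- newline added, utf-8 encoded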
def to_dict(self, include_readonly=True, field_names=None): def to_dict(self, include_readonly=True, field_names=None):
''' """
Returns the instance's column values as a dict. Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional) - `field_names`: an iterable of field names to return (optional)
''' """
fields = self.fields(writable=not include_readonly) fields = self.fields(writable=not include_readonly)
if field_names is not None: if field_names is not None:
@ -450,56 +447,58 @@ class Model(metaclass=ModelBase):
@classmethod @classmethod
def objects_in(cls, database): def objects_in(cls, database):
''' """
Returns a `QuerySet` for selecting instances of this model class. Returns a `QuerySet` for selecting instances of this model class.
''' """
return QuerySet(cls, database) return QuerySet(cls, database)
@classmethod @classmethod
def fields(cls, writable=False): def fields(cls, writable=False):
''' """
Returns an `OrderedDict` of the model's fields (from name to `Field` instance). Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included. If `writable` is true, only writable fields are included.
Callers should not modify the dictionary. Callers should not modify the dictionary.
''' """
# noinspection PyProtectedMember,PyUnresolvedReferences # noinspection PyProtectedMember,PyUnresolvedReferences
return cls._writable_fields if writable else cls._fields return cls._writable_fields if writable else cls._fields
@classmethod @classmethod
def is_read_only(cls): def is_read_only(cls):
''' """
Returns true if the model is marked as read only. Returns true if the model is marked as read only.
''' """
return cls._readonly return cls._readonly
@classmethod @classmethod
def is_system_model(cls): def is_system_model(cls):
''' """
Returns true if the model represents a system table. Returns true if the model represents a system table.
''' """
return cls._system return cls._system
class BufferModel(Model): class BufferModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`' % (db.db_name, cls.table_name(), db.db_name, parts = [
cls.engine.main_model.table_name())] "CREATE TABLE IF NOT EXISTS `%s`.`%s` AS `%s`.`%s`"
% (db.db_name, cls.table_name(), db.db_name, cls.engine.main_model.table_name())
]
engine_str = cls.engine.create_table_sql(db) engine_str = cls.engine.create_table_sql(db)
parts.append(engine_str) parts.append(engine_str)
return ' '.join(parts) return " ".join(parts)
class MergeModel(Model): class MergeModel(Model):
''' """
Model for Merge engine Model for Merge engine
Predefines a virtual _table column and prevents rows from being inserted into this table type Predefines a virtual _table column and prevents rows from being inserted into this table type
https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge https://clickhouse.tech/docs/en/single/index.html#document-table_engines/merge
''' """
readonly = True readonly = True
# Virtual fields can't be inserted into database # Virtual fields can't be inserted into database
@ -507,19 +506,20 @@ class MergeModel(Model):
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge" assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())] parts = ["CREATE TABLE IF NOT EXISTS `%s`.`%s` (" % (db.db_name, cls.table_name())]
cols = [] cols = []
for name, field in cls.fields().items(): for name, field in cls.fields().items():
if name != '_table': if name != "_table":
cols.append(' %s %s' % (name, field.get_sql(db=db))) cols.append(" %s %s" % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols)) parts.append(",\n".join(cols))
parts.append(')') parts.append(")")
parts.append('ENGINE = ' + cls.engine.create_table_sql(db)) parts.append("ENGINE = " + cls.engine.create_table_sql(db))
return '\n'.join(parts) return "\n".join(parts)
# TODO: base class for models that require specific engine # TODO: base class for models that require specific engine
@ -530,10 +530,10 @@ class DistributedModel(Model):
""" """
def set_database(self, db): def set_database(self, db):
''' """
Sets the `Database` that this model instance belongs to. Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it. This is done automatically when the instance is read from the database or written to it.
''' """
assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed" assert isinstance(self.engine, Distributed), "engine must be an instance of engines.Distributed"
res = super(DistributedModel, self).set_database(db) res = super(DistributedModel, self).set_database(db)
return res return res
@ -576,33 +576,37 @@ class DistributedModel(Model):
return return
# find out all the superclasses of the Model that store any data # find out all the superclasses of the Model that store any data
storage_models = [b for b in cls.__bases__ if issubclass(b, Model) storage_models = [b for b in cls.__bases__ if issubclass(b, Model) and not issubclass(b, DistributedModel)]
and not issubclass(b, DistributedModel)]
if not storage_models: if not storage_models:
raise TypeError("When defining Distributed engine without the table_name " raise TypeError(
"ensure that your model has a parent model") "When defining Distributed engine without the table_name " "ensure that your model has a parent model"
)
if len(storage_models) > 1: if len(storage_models) > 1:
raise TypeError("When defining Distributed engine without the table_name " raise TypeError(
"ensure that your model has exactly one non-distributed superclass") "When defining Distributed engine without the table_name "
"ensure that your model has exactly one non-distributed superclass"
)
# enable correct SQL for engine # enable correct SQL for engine
cls.engine.table = storage_models[0] cls.engine.table = storage_models[0]
@classmethod @classmethod
def create_table_sql(cls, db): def create_table_sql(cls, db):
''' """
Returns the SQL statement for creating a table for this model. Returns the SQL statement for creating a table for this model.
''' """
assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance" assert isinstance(cls.engine, Distributed), "engine must be engines.Distributed instance"
cls.fix_engine_table() cls.fix_engine_table()
parts = [ parts = [
'CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`'.format( "CREATE TABLE IF NOT EXISTS `{0}`.`{1}` AS `{0}`.`{2}`".format(
db.db_name, cls.table_name(), cls.engine.table_name), db.db_name, cls.table_name(), cls.engine.table_name
'ENGINE = ' + cls.engine.create_table_sql(db)] ),
return '\n'.join(parts) "ENGINE = " + cls.engine.create_table_sql(db),
]
return "\n".join(parts)
# Expose only relevant classes in import * # Expose only relevant classes in import *
View File
@ -1,15 +1,15 @@
from __future__ import unicode_literals
import pytz
from copy import copy, deepcopy from copy import copy, deepcopy
from math import ceil from math import ceil
from datetime import date, datetime
from .utils import comma_join, string_or_func, arg_to_sql
import pytz
from .engines import CollapsingMergeTree, ReplacingMergeTree
from .utils import Page, arg_to_sql, comma_join, string_or_func
# TODO # TODO
# - check that field names are valid # - check that field names are valid
class Operator(object): class Operator(object):
""" """
Base class for filtering operators. Base class for filtering operators.
@ -20,12 +20,13 @@ class Operator(object):
Subclasses should implement this method. It returns an SQL string Subclasses should implement this method. It returns an SQL string
that applies this operator on the given field and value. that applies this operator on the given field and value.
""" """
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _value_to_sql(self, field, value, quote=True): def _value_to_sql(self, field, value, quote=True):
from infi.clickhouse_orm.funcs import F if isinstance(value, Cond):
if isinstance(value, F): # This is an 'in-database' value, rather than a python one
return value.to_sql() return value.to_sql()
return field.to_db_string(field.to_python(value, pytz.utc), quote) return field.to_db_string(field.to_python(value, pytz.utc), quote)
@ -41,9 +42,9 @@ class SimpleOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
if value == '\\N' and self._sql_for_null is not None: if value == "\\N" and self._sql_for_null is not None:
return ' '.join([field_name, self._sql_for_null]) return " ".join([field_name, self._sql_for_null])
return ' '.join([field_name, self._sql_operator, value]) return " ".join([field_name, self._sql_operator, value])
class InOperator(Operator): class InOperator(Operator):
@ -63,7 +64,7 @@ class InOperator(Operator):
pass pass
else: else:
value = comma_join([self._value_to_sql(field, v) for v in value]) value = comma_join([self._value_to_sql(field, v) for v in value])
return '%s IN (%s)' % (field_name, value) return "%s IN (%s)" % (field_name, value)
class LikeOperator(Operator): class LikeOperator(Operator):
@ -79,12 +80,12 @@ class LikeOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value, quote=False) value = self._value_to_sql(field, value, quote=False)
value = value.replace('\\', '\\\\').replace('%', '\\\\%').replace('_', '\\\\_') value = value.replace("\\", "\\\\").replace("%", "\\\\%").replace("_", "\\\\_")
pattern = self._pattern.format(value) pattern = self._pattern.format(value)
if self._case_sensitive: if self._case_sensitive:
return '%s LIKE \'%s\'' % (field_name, pattern) return "%s LIKE '%s'" % (field_name, pattern)
else: else:
return 'lowerUTF8(%s) LIKE lowerUTF8(\'%s\')' % (field_name, pattern) return "lowerUTF8(%s) LIKE lowerUTF8('%s')" % (field_name, pattern)
class IExactOperator(Operator): class IExactOperator(Operator):
@ -95,7 +96,7 @@ class IExactOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
field = getattr(model_cls, field_name) field = getattr(model_cls, field_name)
value = self._value_to_sql(field, value) value = self._value_to_sql(field, value)
return 'lowerUTF8(%s) = lowerUTF8(%s)' % (field_name, value) return "lowerUTF8(%s) = lowerUTF8(%s)" % (field_name, value)
class NotOperator(Operator): class NotOperator(Operator):
@ -108,7 +109,7 @@ class NotOperator(Operator):
def to_sql(self, model_cls, field_name, value): def to_sql(self, model_cls, field_name, value):
# Negate the base operator # Negate the base operator
return 'NOT (%s)' % self._base_operator.to_sql(model_cls, field_name, value) return "NOT (%s)" % self._base_operator.to_sql(model_cls, field_name, value)
class BetweenOperator(Operator): class BetweenOperator(Operator):
@ -126,35 +127,38 @@ class BetweenOperator(Operator):
value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None value0 = self._value_to_sql(field, value[0]) if value[0] is not None or len(str(value[0])) > 0 else None
value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None value1 = self._value_to_sql(field, value[1]) if value[1] is not None or len(str(value[1])) > 0 else None
if value0 and value1: if value0 and value1:
return '%s BETWEEN %s AND %s' % (field_name, value0, value1) return "%s BETWEEN %s AND %s" % (field_name, value0, value1)
if value0 and not value1: if value0 and not value1:
return ' '.join([field_name, '>=', value0]) return " ".join([field_name, ">=", value0])
if value1 and not value0: if value1 and not value0:
return ' '.join([field_name, '<=', value1]) return " ".join([field_name, "<=", value1])
# Define the set of builtin operators # Define the set of builtin operators
_operators = {} _operators = {}
def register_operator(name, sql): def register_operator(name, sql):
_operators[name] = sql _operators[name] = sql
register_operator('eq', SimpleOperator('=', 'IS NULL'))
register_operator('ne', SimpleOperator('!=', 'IS NOT NULL')) register_operator("eq", SimpleOperator("=", "IS NULL"))
register_operator('gt', SimpleOperator('>')) register_operator("ne", SimpleOperator("!=", "IS NOT NULL"))
register_operator('gte', SimpleOperator('>=')) register_operator("gt", SimpleOperator(">"))
register_operator('lt', SimpleOperator('<')) register_operator("gte", SimpleOperator(">="))
register_operator('lte', SimpleOperator('<=')) register_operator("lt", SimpleOperator("<"))
register_operator('between', BetweenOperator()) register_operator("lte", SimpleOperator("<="))
register_operator('in', InOperator()) register_operator("between", BetweenOperator())
register_operator('not_in', NotOperator(InOperator())) register_operator("in", InOperator())
register_operator('contains', LikeOperator('%{}%')) register_operator("not_in", NotOperator(InOperator()))
register_operator('startswith', LikeOperator('{}%')) register_operator("contains", LikeOperator("%{}%"))
register_operator('endswith', LikeOperator('%{}')) register_operator("startswith", LikeOperator("{}%"))
register_operator('icontains', LikeOperator('%{}%', False)) register_operator("endswith", LikeOperator("%{}"))
register_operator('istartswith', LikeOperator('{}%', False)) register_operator("icontains", LikeOperator("%{}%", False))
register_operator('iendswith', LikeOperator('%{}', False)) register_operator("istartswith", LikeOperator("{}%", False))
register_operator('iexact', IExactOperator()) register_operator("iendswith", LikeOperator("%{}", False))
register_operator("iexact", IExactOperator())
class Cond(object): class Cond(object):
@ -170,19 +174,20 @@ class FieldCond(Cond):
""" """
A single query condition made up of Field + Operator + Value. A single query condition made up of Field + Operator + Value.
""" """
def __init__(self, field_name, operator, value): def __init__(self, field_name, operator, value):
self._field_name = field_name self._field_name = field_name
self._operator = _operators.get(operator) self._operator = _operators.get(operator)
if self._operator is None: if self._operator is None:
# The field name contains __ like my__field # The field name contains __ like my__field
self._field_name = field_name + '__' + operator self._field_name = field_name + "__" + operator
self._operator = _operators['eq'] self._operator = _operators["eq"]
self._value = value self._value = value
def to_sql(self, model_cls): def to_sql(self, model_cls):
return self._operator.to_sql(model_cls, self._field_name, self._value) return self._operator.to_sql(model_cls, self._field_name, self._value)
def __deepcopy__(self, memodict={}): def __deepcopy__(self, memo):
res = copy(self) res = copy(self)
res._value = deepcopy(self._value) res._value = deepcopy(self._value)
return res return res
@ -190,8 +195,8 @@ class FieldCond(Cond):
class Q(object): class Q(object):
AND_MODE = 'AND' AND_MODE = "AND"
OR_MODE = 'OR' OR_MODE = "OR"
def __init__(self, *filter_funcs, **filter_fields): def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()] self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
@ -205,29 +210,26 @@ class Q(object):
Checks if there are any conditions in the Q object Checks if there are any conditions in the Q object
Returns: Boolean Returns: Boolean
""" """
return not bool(self._conds or self._children) return not (self._conds or self._children)
@classmethod @classmethod
def _construct_from(cls, l_child, r_child, mode): def _construct_from(cls, l_child, r_child, mode):
if mode == l_child._mode: if mode == l_child._mode and not l_child._negate:
q = deepcopy(l_child) q = deepcopy(l_child)
q._children.append(deepcopy(r_child)) q._children.append(deepcopy(r_child))
elif mode == r_child._mode:
q = deepcopy(r_child)
q._children.append(deepcopy(l_child))
else: else:
# Different modes q = cls()
q = Q()
q._children = [l_child, r_child] q._children = [l_child, r_child]
q._mode = mode # AND/OR q._mode = mode
return q return q
def _build_cond(self, key, value): def _build_cond(self, key, value):
if '__' in key: if "__" in key:
field_name, operator = key.rsplit('__', 1) field_name, operator = key.rsplit("__", 1)
else: else:
field_name, operator = key, 'eq' field_name, operator = key, "eq"
return FieldCond(field_name, operator, value) return FieldCond(field_name, operator, value)
def to_sql(self, model_cls): def to_sql(self, model_cls):
@ -241,24 +243,30 @@ class Q(object):
if not condition_sql: if not condition_sql:
# Empty Q() object returns everything # Empty Q() object returns everything
sql = '1' sql = "1"
elif len(condition_sql) == 1: elif len(condition_sql) == 1:
# Skip unneeded brackets around a single condition # Skip unneeded brackets around a single condition
sql = condition_sql[0] sql = condition_sql[0]
else: else:
# Each condition must be enclosed in brackets, or order of operations may be wrong # Each condition must be enclosed in brackets, or order of operations may be wrong
sql = '(%s)' % ') {} ('.format(self._mode).join(condition_sql) sql = "(%s)" % ") {} (".format(self._mode).join(condition_sql)
if self._negate: if self._negate:
sql = 'NOT (%s)' % sql sql = "NOT (%s)" % sql
return sql return sql
def __or__(self, other): def __or__(self, other):
return Q._construct_from(self, other, self.OR_MODE) if not isinstance(other, Q):
return NotImplemented
return self.__class__._construct_from(self, other, self.OR_MODE)
def __and__(self, other): def __and__(self, other):
return Q._construct_from(self, other, self.AND_MODE) if not isinstance(other, Q):
return NotImplemented
return self.__class__._construct_from(self, other, self.AND_MODE)
def __invert__(self): def __invert__(self):
q = copy(self) q = copy(self)
@ -268,8 +276,8 @@ class Q(object):
def __bool__(self): def __bool__(self):
return not self.is_empty return not self.is_empty
def __deepcopy__(self, memodict={}): def __deepcopy__(self, memo):
q = Q() q = self.__class__()
q._conds = [deepcopy(cond) for cond in self._conds] q._conds = [deepcopy(cond) for cond in self._conds]
q._negate = self._negate q._negate = self._negate
q._mode = self._mode q._mode = self._mode
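A composition sketch: with the flattening in _construct_from above, combining same-mode, non-negated Q objects merges them instead of nesting, while a mode mismatch creates a fresh parent; ~ toggles negation:

    q = Q(first_name='Vera') & Q(birthday__gte='1976-01-01')   # merged into one AND-mode Q
    q = q | Q(first_name='Victor')                             # fresh OR-mode parent
    # q.to_sql(Person) yields roughly:
    #   ((first_name = 'Vera') AND (birthday >= '1976-01-01')) OR (first_name = 'Victor')
    (~Q(first_name='Vera')).to_sql(Person)                     # NOT (first_name = 'Vera')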
@ -318,7 +326,7 @@ class QuerySet(object):
""" """
return bool(self.count()) return bool(self.count())
def __nonzero__(self): # Python 2 compatibility def __nonzero__(self): # Python 2 compatibility
return type(self).__bool__(self) return type(self).__bool__(self)
def __str__(self): def __str__(self):
@ -327,17 +335,17 @@ class QuerySet(object):
def __getitem__(self, s): def __getitem__(self, s):
if isinstance(s, int): if isinstance(s, int):
# Single index # Single index
assert s >= 0, 'negative indexes are not supported' assert s >= 0, "negative indexes are not supported"
qs = copy(self) qs = copy(self)
qs._limits = (s, 1) qs._limits = (s, 1)
return next(iter(qs)) return next(iter(qs))
else: else:
# Slice # Slice
assert s.step in (None, 1), 'step is not supported in slices' assert s.step in (None, 1), "step is not supported in slices"
start = s.start or 0 start = s.start or 0
stop = s.stop or 2**63 - 1 stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported' assert start >= 0 and stop >= 0, "negative indexes are not supported"
assert start <= stop, 'start of slice cannot be smaller than its end' assert start <= stop, "start of slice cannot be smaller than its end"
qs = copy(self) qs = copy(self)
qs._limits = (start, stop - start) qs._limits = (start, stop - start)
return qs return qs
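Slice semantics in a sketch: an int index becomes a one-row limit and a slice becomes an (offset, limit) pair:

    qs[0]         # _limits == (0, 1); returns the first row
    qs[100:110]   # _limits == (100, 10); rows 100..109
    qs[10:]       # open-ended stop defaults to 2**63 - 1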
@ -353,7 +361,7 @@ class QuerySet(object):
offset_limit = (0, offset_limit) offset_limit = (0, offset_limit)
offset = offset_limit[0] offset = offset_limit[0]
limit = offset_limit[1] limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported' assert offset >= 0 and limit >= 0, "negative limits are not supported"
qs = copy(self) qs = copy(self)
qs._limit_by = (offset, limit) qs._limit_by = (offset, limit)
qs._limit_by_fields = fields_or_expr qs._limit_by_fields = fields_or_expr
@ -363,44 +371,44 @@ class QuerySet(object):
""" """
Returns the selected fields or expressions as a SQL string. Returns the selected fields or expressions as a SQL string.
""" """
fields = '*' fields = "*"
if self._fields: if self._fields:
fields = comma_join('`%s`' % field for field in self._fields) fields = comma_join("`%s`" % field for field in self._fields)
return fields return fields
def as_sql(self): def as_sql(self):
""" """
Returns the whole query as a SQL string. Returns the whole query as a SQL string.
""" """
distinct = 'DISTINCT ' if self._distinct else '' distinct = "DISTINCT " if self._distinct else ""
final = ' FINAL' if self._final else '' final = " FINAL" if self._final else ""
table_name = '`%s`' % self._model_cls.table_name() table_name = "`%s`" % self._model_cls.table_name()
if self._model_cls.is_system_model(): if self._model_cls.is_system_model():
table_name = '`system`.' + table_name table_name = "`system`." + table_name
params = (distinct, self.select_fields_as_sql(), table_name, final) params = (distinct, self.select_fields_as_sql(), table_name, final)
sql = u'SELECT %s%s\nFROM %s%s' % params sql = "SELECT %s%s\nFROM %s%s" % params
if self._prewhere_q and not self._prewhere_q.is_empty: if self._prewhere_q and not self._prewhere_q.is_empty:
sql += '\nPREWHERE ' + self.conditions_as_sql(prewhere=True) sql += "\nPREWHERE " + self.conditions_as_sql(prewhere=True)
if self._where_q and not self._where_q.is_empty: if self._where_q and not self._where_q.is_empty:
sql += '\nWHERE ' + self.conditions_as_sql(prewhere=False) sql += "\nWHERE " + self.conditions_as_sql(prewhere=False)
if self._grouping_fields: if self._grouping_fields:
sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) sql += "\nGROUP BY %s" % comma_join("`%s`" % field for field in self._grouping_fields)
if self._grouping_with_totals: if self._grouping_with_totals:
sql += ' WITH TOTALS' sql += " WITH TOTALS"
if self._order_by: if self._order_by:
sql += '\nORDER BY ' + self.order_by_as_sql() sql += "\nORDER BY " + self.order_by_as_sql()
if self._limit_by: if self._limit_by:
sql += '\nLIMIT %d, %d' % self._limit_by sql += "\nLIMIT %d, %d" % self._limit_by
sql += ' BY %s' % comma_join(string_or_func(field) for field in self._limit_by_fields) sql += " BY %s" % comma_join(string_or_func(field) for field in self._limit_by_fields)
if self._limits: if self._limits:
sql += '\nLIMIT %d, %d' % self._limits sql += "\nLIMIT %d, %d" % self._limits
return sql return sql
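The overall shape of the statement as_sql() assembles, for an assumed filtered, ordered and sliced queryset:

    sql = Person.objects_in(db).filter(first_name__gt='V').order_by('-birthday')[:10].as_sql()
    # SELECT *
    # FROM `person`
    # WHERE first_name > 'V'
    # ORDER BY birthday DESC
    # LIMIT 0, 10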
@ -408,10 +416,12 @@ class QuerySet(object):
""" """
Returns the contents of the query's `ORDER BY` clause as a string. Returns the contents of the query's `ORDER BY` clause as a string.
""" """
return comma_join([ return comma_join(
'%s DESC' % field[1:] if isinstance(field, str) and field[0] == '-' else str(field) [
for field in self._order_by "%s DESC" % field[1:] if isinstance(field, str) and field[0] == "-" else str(field)
]) for field in self._order_by
]
)
def conditions_as_sql(self, prewhere=False): def conditions_as_sql(self, prewhere=False):
""" """
@ -426,7 +436,7 @@ class QuerySet(object):
""" """
if self._distinct or self._limits: if self._distinct or self._limits:
# Use a subquery, since a simple count won't be accurate # Use a subquery, since a simple count won't be accurate
sql = u'SELECT count() FROM (%s)' % self.as_sql() sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql) raw = self._database.raw(sql)
return int(raw) if raw else 0 return int(raw) if raw else 0
@ -453,10 +463,8 @@ class QuerySet(object):
return qs return qs
def _filter_or_exclude(self, *q, **kwargs): def _filter_or_exclude(self, *q, **kwargs):
from .funcs import F inverse = kwargs.pop("_inverse", False)
prewhere = kwargs.pop("prewhere", False)
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
qs = copy(self) qs = copy(self)
@ -464,10 +472,10 @@ class QuerySet(object):
for arg in q: for arg in q:
if isinstance(arg, Q): if isinstance(arg, Q):
condition &= arg condition &= arg
elif isinstance(arg, F): elif isinstance(arg, Cond):
condition &= Q(arg) condition &= Q(arg)
else: else:
raise TypeError('Invalid argument "%r" to queryset filter' % arg) raise TypeError(f"Invalid argument '{arg}' of type '{type(arg)}' to filter")
if kwargs: if kwargs:
condition &= Q(**kwargs) condition &= Q(**kwargs)
@ -509,20 +517,19 @@ class QuerySet(object):
The result is a namedtuple containing `objects` (list), `number_of_objects`, The result is a namedtuple containing `objects` (list), `number_of_objects`,
`pages_total`, `number` (of the current page), and `page_size`. `pages_total`, `number` (of the current page), and `page_size`.
""" """
from .database import Page
count = self.count() count = self.count()
pages_total = int(ceil(count / float(page_size))) pages_total = int(ceil(count / float(page_size)))
if page_num == -1: if page_num == -1:
page_num = pages_total page_num = pages_total
elif page_num < 1: elif page_num < 1:
raise ValueError('Invalid page number: %d' % page_num) raise ValueError("Invalid page number: %d" % page_num)
offset = (page_num - 1) * page_size offset = (page_num - 1) * page_size
return Page( return Page(
objects=list(self[offset : offset + page_size]), objects=list(self[offset : offset + page_size]),
number_of_objects=count, number_of_objects=count,
pages_total=pages_total, pages_total=pages_total,
number=page_num, number=page_num,
page_size=page_size page_size=page_size,
) )
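Pagination usage sketch: page numbers are 1-based, and -1 selects the last page:

    page = qs.paginate(page_num=1, page_size=100)
    page.number_of_objects   # total row count
    page.pages_total         # ceil(count / page_size)
    for person in page.objects:
        ...                  # model instances of the current page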
def distinct(self): def distinct(self):
@ -539,9 +546,10 @@ class QuerySet(object):
Adds a FINAL modifier to table, meaning data will be collapsed to final version. Adds a FINAL modifier to table, meaning data will be collapsed to final version.
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only. Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
""" """
from .engines import CollapsingMergeTree, ReplacingMergeTree
if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)): if not isinstance(self._model_cls.engine, (CollapsingMergeTree, ReplacingMergeTree)):
raise TypeError('final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines') raise TypeError(
"final() method can be used only with the CollapsingMergeTree and ReplacingMergeTree engines"
)
qs = copy(self) qs = copy(self)
qs._final = True qs._final = True
@ -554,7 +562,7 @@ class QuerySet(object):
""" """
self._verify_mutation_allowed() self._verify_mutation_allowed()
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` DELETE WHERE %s' % (self._model_cls.table_name(), conditions) sql = "ALTER TABLE $db.`%s` DELETE WHERE %s" % (self._model_cls.table_name(), conditions)
self._database.raw(sql) self._database.raw(sql)
return self return self
@ -564,22 +572,22 @@ class QuerySet(object):
Keyword arguments specify the field names and expressions to use for the update. Keyword arguments specify the field names and expressions to use for the update.
Note that ClickHouse performs updates in the background, so they are not immediate. Note that ClickHouse performs updates in the background, so they are not immediate.
""" """
assert kwargs, 'No fields specified for update' assert kwargs, "No fields specified for update"
self._verify_mutation_allowed() self._verify_mutation_allowed()
fields = comma_join('`%s` = %s' % (name, arg_to_sql(expr)) for name, expr in kwargs.items()) fields = comma_join("`%s` = %s" % (name, arg_to_sql(expr)) for name, expr in kwargs.items())
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls) conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
sql = 'ALTER TABLE $db.`%s` UPDATE %s WHERE %s' % (self._model_cls.table_name(), fields, conditions) sql = "ALTER TABLE $db.`%s` UPDATE %s WHERE %s" % (self._model_cls.table_name(), fields, conditions)
self._database.raw(sql) self._database.raw(sql)
return self return self
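A hedged sketch of the mutation API built on ALTER TABLE ... UPDATE/DELETE; ClickHouse applies these in the background, so the effects are not immediate:

    Person.objects_in(db).filter(first_name='Vera').update(first_name='V.')
    Person.objects_in(db).filter(birthday__lt='1970-01-01').delete()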
def _verify_mutation_allowed(self): def _verify_mutation_allowed(self):
''' """
Checks that the queryset's state allows mutations. Raises an AssertionError if not. Checks that the queryset's state allows mutations. Raises an AssertionError if not.
''' """
assert not self._limits, 'Mutations are not allowed after slicing the queryset' assert not self._limits, "Mutations are not allowed after slicing the queryset"
assert not self._limit_by, 'Mutations are not allowed after calling limit_by(...)' assert not self._limit_by, "Mutations are not allowed after calling limit_by(...)"
assert not self._distinct, 'Mutations are not allowed after calling distinct()' assert not self._distinct, "Mutations are not allowed after calling distinct()"
assert not self._final, 'Mutations are not allowed after calling final()' assert not self._final, "Mutations are not allowed after calling final()"
def aggregate(self, *args, **kwargs): def aggregate(self, *args, **kwargs):
""" """
@ -619,7 +627,7 @@ class AggregateQuerySet(QuerySet):
At least one calculated field is required. At least one calculated field is required.
""" """
super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database) super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database)
assert calculated_fields, 'No calculated fields specified for aggregation' assert calculated_fields, "No calculated fields specified for aggregation"
self._fields = grouping_fields self._fields = grouping_fields
self._grouping_fields = grouping_fields self._grouping_fields = grouping_fields
self._calculated_fields = calculated_fields self._calculated_fields = calculated_fields
@ -636,8 +644,9 @@ class AggregateQuerySet(QuerySet):
created with. created with.
""" """
for name in args: for name in args:
assert name in self._fields or name in self._calculated_fields, \ assert name in self._fields or name in self._calculated_fields, (
'Cannot group by `%s` since it is not included in the query' % name "Cannot group by `%s` since it is not included in the query" % name
)
qs = copy(self) qs = copy(self)
qs._grouping_fields = args qs._grouping_fields = args
return qs return qs
@ -652,22 +661,24 @@ class AggregateQuerySet(QuerySet):
""" """
This method is not supported on `AggregateQuerySet`. This method is not supported on `AggregateQuerySet`.
""" """
raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') raise NotImplementedError("Cannot re-aggregate an AggregateQuerySet")
def select_fields_as_sql(self): def select_fields_as_sql(self):
""" """
Returns the selected fields or expressions as a SQL string. Returns the selected fields or expressions as a SQL string.
""" """
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) return comma_join(
[str(f) for f in self._fields] + ["%s AS %s" % (v, k) for k, v in self._calculated_fields.items()]
)
def __iter__(self): def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model return self._database.select(self.as_sql()) # using an ad-hoc model
def count(self): def count(self):
""" """
Returns the number of rows after aggregation. Returns the number of rows after aggregation.
""" """
sql = u'SELECT count() FROM (%s)' % self.as_sql() sql = "SELECT count() FROM (%s)" % self.as_sql()
raw = self._database.raw(sql) raw = self._database.raw(sql)
return int(raw) if raw else 0 return int(raw) if raw else 0
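Aggregation sketch: the grouping fields plus the calculated expressions become the SELECT list, and count() wraps the aggregate in a subquery:

    qs = Person.objects_in(db).aggregate('first_name', num='count()')
    # SELECT first_name, count() AS num FROM `person` GROUP BY `first_name`  (roughly)
    qs.count()   # SELECT count() FROM (...)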
@ -682,7 +693,7 @@ class AggregateQuerySet(QuerySet):
return qs return qs
def _verify_mutation_allowed(self): def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet') raise AssertionError("Cannot mutate an AggregateQuerySet")
# Expose only relevant classes in import * # Expose only relevant classes in import *
View File
@ -2,10 +2,8 @@
This file contains read-only system models that can be read from the database This file contains read-only system models that can be read from the database
https://clickhouse.tech/docs/en/system_tables/ https://clickhouse.tech/docs/en/system_tables/
""" """
from __future__ import unicode_literals
from .database import Database from .database import Database
from .fields import * from .fields import DateTimeField, StringField, UInt8Field, UInt32Field, UInt64Field
from .models import Model from .models import Model
from .utils import comma_join from .utils import comma_join
@ -16,7 +14,8 @@ class SystemPart(Model):
This model operates only on the fields described in the reference; other fields are ignored. This model operates only on the fields described in the reference; other fields are ignored.
https://clickhouse.tech/docs/en/system_tables/system.parts/ https://clickhouse.tech/docs/en/system_tables/system.parts/
""" """
OPERATIONS = frozenset({'DETACH', 'DROP', 'ATTACH', 'FREEZE', 'FETCH'})
OPERATIONS = frozenset({"DETACH", "DROP", "ATTACH", "FREEZE", "FETCH"})
_readonly = True _readonly = True
_system = True _system = True
@ -51,12 +50,13 @@ class SystemPart(Model):
@classmethod @classmethod
def table_name(cls): def table_name(cls):
return 'parts' return "parts"
""" """
The following methods return SQL for operations that can be performed on partitions The following methods return SQL for operations that can be performed on partitions
https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts https://clickhouse.tech/docs/en/query_language/queries/#manipulations-with-partitions-and-parts
""" """
def _partition_operation_sql(self, operation, settings=None, from_part=None): def _partition_operation_sql(self, operation, settings=None, from_part=None):
""" """
Performs an operation on a partition Performs an operation on a partition
@ -83,7 +83,7 @@ class SystemPart(Model):
Returns: SQL Query Returns: SQL Query
""" """
return self._partition_operation_sql('DETACH', settings=settings) return self._partition_operation_sql("DETACH", settings=settings)
def drop(self, settings=None): def drop(self, settings=None):
""" """
@ -93,7 +93,7 @@ class SystemPart(Model):
Returns: SQL Query Returns: SQL Query
""" """
return self._partition_operation_sql('DROP', settings=settings) return self._partition_operation_sql("DROP", settings=settings)
def attach(self, settings=None): def attach(self, settings=None):
""" """
@ -103,7 +103,7 @@ class SystemPart(Model):
Returns: SQL Query Returns: SQL Query
""" """
return self._partition_operation_sql('ATTACH', settings=settings) return self._partition_operation_sql("ATTACH", settings=settings)
def freeze(self, settings=None): def freeze(self, settings=None):
""" """
@ -113,7 +113,7 @@ class SystemPart(Model):
Returns: SQL Query Returns: SQL Query
""" """
return self._partition_operation_sql('FREEZE', settings=settings) return self._partition_operation_sql("FREEZE", settings=settings)
def fetch(self, zookeeper_path, settings=None): def fetch(self, zookeeper_path, settings=None):
""" """
@ -124,7 +124,7 @@ class SystemPart(Model):
Returns: SQL Query Returns: SQL Query
""" """
return self._partition_operation_sql('FETCH', settings=settings, from_part=zookeeper_path) return self._partition_operation_sql("FETCH", settings=settings, from_part=zookeeper_path)
@classmethod @classmethod
def get(cls, database, conditions=""): def get(cls, database, conditions=""):
@ -140,9 +140,12 @@ class SystemPart(Model):
assert isinstance(conditions, str), "conditions must be a string" assert isinstance(conditions, str), "conditions must be a string"
if conditions: if conditions:
conditions += " AND" conditions += " AND"
field_names = ','.join(cls.fields()) field_names = ",".join(cls.fields())
return database.select("SELECT %s FROM `system`.%s WHERE %s database='%s'" % return database.select(
(field_names, cls.table_name(), conditions, database.db_name), model_class=cls) "SELECT %s FROM `system`.%s WHERE %s database='%s'"
% (field_names, cls.table_name(), conditions, database.db_name),
model_class=cls,
)
@classmethod @classmethod
def get_active(cls, database, conditions=""): def get_active(cls, database, conditions=""):
@ -155,8 +158,8 @@ class SystemPart(Model):
Returns: A list of SystemPart objects Returns: A list of SystemPart objects
""" """
if conditions: if conditions:
conditions += ' AND ' conditions += " AND "
conditions += 'active' conditions += "active"
return SystemPart.get(database, conditions=conditions) return SystemPart.get(database, conditions=conditions)
View File
@ -1,54 +1,48 @@
import codecs import codecs
import importlib
import pkgutil
import re import re
from datetime import date, datetime, tzinfo, timedelta from collections import namedtuple
from datetime import date, datetime, timedelta, tzinfo
from inspect import isclass
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Type, Union
Page = namedtuple("Page", "objects number_of_objects pages_total number page_size")
Page.__doc__ += "\nA simple data structure for paginated results."
SPECIAL_CHARS = { def escape(value: str, quote: bool = True) -> str:
"\b" : "\\b", """
"\f" : "\\f",
"\r" : "\\r",
"\n" : "\\n",
"\t" : "\\t",
"\0" : "\\0",
"\\" : "\\\\",
"'" : "\\'"
}
SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
def escape(value, quote=True):
'''
If the value is a string, escapes any special characters and optionally If the value is a string, escapes any special characters and optionally
surrounds it with single quotes. If the value is not a string (e.g. a number), surrounds it with single quotes. If the value is not a string (e.g. a number),
converts it to one. converts it to one.
''' """
def escape_one(match): value = codecs.escape_encode(value.encode("utf-8"))[0].decode("utf-8")
return SPECIAL_CHARS[match.group(0)] if quote:
value = "'" + value + "'"
if isinstance(value, str): return value
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
if quote:
value = "'" + value + "'"
return str(value)
def unescape(value): def unescape(value: str) -> Optional[str]:
return codecs.escape_decode(value)[0].decode('utf-8') if value == "\\N":
return None
return codecs.escape_decode(value)[0].decode("utf-8")
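Behavior sketch for the codecs-based helpers above (outputs are approximate):

    escape("Vera's")   # roughly "'Vera\\'s'" -- escaped, then single-quoted
    unescape("\\N")    # None -- ClickHouse's TSV marker for NULL
    unescape("a\\tb")  # "a\tb"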
def string_or_func(obj): def string_or_func(obj):
return obj.to_sql() if hasattr(obj, 'to_sql') else obj return obj.to_sql() if hasattr(obj, "to_sql") else obj
def arg_to_sql(arg): def arg_to_sql(arg: Any) -> str:
""" """
Converts a function argument to SQL string according to its type. Converts a function argument to SQL string according to its type.
Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans, Supports functions, model fields, strings, dates, datetimes, timedeltas, booleans,
None, numbers, timezones, arrays/iterables. None, numbers, timezones, arrays/iterables.
""" """
from infi.clickhouse_orm import Field, StringField, DateTimeField, DateField, F, QuerySet from clickhouse_orm import DateTimeField, F, Field, QuerySet, StringField
if isinstance(arg, F): if isinstance(arg, F):
return arg.to_sql() return arg.to_sql()
if isinstance(arg, Field): if isinstance(arg, Field):
@ -66,42 +60,42 @@ def arg_to_sql(arg):
if isinstance(arg, tzinfo): if isinstance(arg, tzinfo):
return StringField().to_db_string(arg.tzname(None)) return StringField().to_db_string(arg.tzname(None))
if arg is None: if arg is None:
return 'NULL' return "NULL"
if isinstance(arg, QuerySet): if isinstance(arg, QuerySet):
return "(%s)" % arg return "(%s)" % arg
if isinstance(arg, tuple): if isinstance(arg, tuple):
return '(' + comma_join(arg_to_sql(x) for x in arg) + ')' return "(" + comma_join(arg_to_sql(x) for x in arg) + ")"
if is_iterable(arg): if is_iterable(arg):
return '[' + comma_join(arg_to_sql(x) for x in arg) + ']' return "[" + comma_join(arg_to_sql(x) for x in arg) + "]"
return str(arg) return str(arg)
def parse_tsv(line): def parse_tsv(line: Union[bytes, str]) -> List[str]:
if isinstance(line, bytes): if isinstance(line, bytes):
line = line.decode() line = line.decode()
if line and line[-1] == '\n': if line and line[-1] == "\n":
line = line[:-1] line = line[:-1]
return [unescape(value) for value in line.split(str('\t'))] return [unescape(value) for value in line.split("\t")]
def parse_array(array_string): def parse_array(array_string: str) -> List[Any]:
""" """
Parse an array or tuple string as returned by clickhouse. For example: Parse an array or tuple string as returned by clickhouse. For example:
"['hello', 'world']" ==> ["hello", "world"] "['hello', 'world']" ==> ["hello", "world"]
"(1,2,3)" ==> [1, 2, 3] "(1,2,3)" ==> [1, 2, 3]
""" """
# Sanity check # Sanity check
if len(array_string) < 2 or array_string[0] not in '[(' or array_string[-1] not in '])': if len(array_string) < 2 or array_string[0] not in "[(" or array_string[-1] not in "])":
raise ValueError('Invalid array string: "%s"' % array_string) raise ValueError('Invalid array string: "%s"' % array_string)
# Drop opening brace # Drop opening brace
array_string = array_string[1:] array_string = array_string[1:]
# Go over the string, lopping off each value at the beginning until nothing is left # Go over the string, lopping off each value at the beginning until nothing is left
values = [] values = []
while True: while True:
if array_string in '])': if array_string in "])":
# End of array # End of array
return values return values
elif array_string[0] in ', ': elif array_string[0] in ", ":
# In between values # In between values
array_string = array_string[1:] array_string = array_string[1:]
elif array_string[0] == "'": elif array_string[0] == "'":
@ -110,37 +104,33 @@ def parse_array(array_string):
if match is None: if match is None:
raise ValueError('Missing closing quote: "%s"' % array_string) raise ValueError('Missing closing quote: "%s"' % array_string)
values.append(array_string[1 : match.start() + 1]) values.append(array_string[1 : match.start() + 1])
array_string = array_string[match.end():] array_string = array_string[match.end() :]
else: else:
# Start of non-quoted value, find its end # Start of non-quoted value, find its end
match = re.search(r",|\]", array_string) match = re.search(r",|\]", array_string)
values.append(array_string[0 : match.start()]) values.append(array_string[0 : match.start()])
array_string = array_string[match.end() - 1:] array_string = array_string[match.end() - 1 :]
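Usage sketch for parse_array(); note that unquoted values come back as strings rather than being converted:

    parse_array("['hello', 'world']")   # ['hello', 'world']
    parse_array("(1,2,3)")              # ['1', '2', '3']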
def import_submodules(package_name): def import_submodules(package_name: str) -> Dict[str, ModuleType]:
""" """
Import all submodules of a module. Import all submodules of a module.
""" """
import importlib, pkgutil
package = importlib.import_module(package_name) package = importlib.import_module(package_name)
return { return {
name: importlib.import_module(package_name + '.' + name) name: importlib.import_module(package_name + "." + name)
for _, name, _ in pkgutil.iter_modules(package.__path__) for _, name, _ in pkgutil.iter_modules(package.__path__)
} }
def comma_join(items, stringify=False): def comma_join(items: Iterable[str]) -> str:
""" """
Joins an iterable of strings with commas. Joins an iterable of strings with commas.
""" """
if stringify: return ", ".join(items)
return ', '.join(str(item) for item in items)
else:
return ', '.join(items)
def is_iterable(obj): def is_iterable(obj: Any) -> bool:
""" """
Checks if the given object is iterable. Checks if the given object is iterable.
""" """
@ -151,17 +141,24 @@ def is_iterable(obj):
return False return False
def get_subclass_names(locals, base_class): def get_subclass_names(locals: Dict[str, Any], base_class: Type):
from inspect import isclass
return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)] return [c.__name__ for c in locals.values() if isclass(c) and issubclass(c, base_class)]
class NoValue: class NoValue:
''' """
A sentinel for fields with an expression for a default value, A sentinel for fields with an expression for a default value,
that were not assigned a value yet. that were not assigned a value yet.
''' """
def __repr__(self): def __repr__(self):
return 'NO_VALUE' return "NO_VALUE"
def __copy__(self):
return self
def __deepcopy__(self, memo):
return self
NO_VALUE = NoValue() NO_VALUE = NoValue()
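The new __copy__/__deepcopy__ hooks keep the sentinel a singleton, so identity checks against NO_VALUE still hold after a queryset is copied:

    from copy import copy, deepcopy
    assert copy(NO_VALUE) is NO_VALUE
    assert deepcopy(NO_VALUE) is NO_VALUE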
View File
@ -1,8 +1,8 @@
Class Reference Class Reference
=============== ===============
infi.clickhouse_orm.database clickhouse_orm.database
---------------------------- -----------------------
### Database ### Database
@ -152,8 +152,8 @@ Extends Exception
Raised when a database operation fails. Raised when a database operation fails.
infi.clickhouse_orm.models clickhouse_orm.models
-------------------------- ---------------------
### Model ### Model
@ -811,8 +811,8 @@ separated by non-alphanumeric characters.
- `random_seed` — The seed for Bloom filter hash functions. - `random_seed` — The seed for Bloom filter hash functions.
infi.clickhouse_orm.fields clickhouse_orm.fields
-------------------------- ---------------------
### ArrayField ### ArrayField
@ -1046,8 +1046,8 @@ Extends Field
#### UUIDField(default=None, alias=None, materialized=None, readonly=None, codec=None) #### UUIDField(default=None, alias=None, materialized=None, readonly=None, codec=None)
infi.clickhouse_orm.engines clickhouse_orm.engines
--------------------------- ----------------------
### Engine ### Engine
@ -1140,8 +1140,8 @@ Extends MergeTree
#### ReplacingMergeTree(date_col=None, order_by=(), ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None, primary_key=None) #### ReplacingMergeTree(date_col=None, order_by=(), ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None, primary_key=None)
infi.clickhouse_orm.query clickhouse_orm.query
------------------------- --------------------
### QuerySet ### QuerySet
@ -1443,8 +1443,8 @@ https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
#### to_sql(model_cls) #### to_sql(model_cls)
infi.clickhouse_orm.funcs clickhouse_orm.funcs
------------------------- --------------------
### F ### F
@ -2012,7 +2012,7 @@ Initializer.
#### floor(n=None) #### floor(n=None)
#### formatDateTime(format, timezone="") #### formatDateTime(format, timezone=NO_VALUE)
#### gcd(b) #### gcd(b)
@ -2804,13 +2804,13 @@ Initializer.
#### toDateTimeOrZero() #### toDateTimeOrZero()
#### toDayOfMonth() #### toDayOfMonth(timezone=NO_VALUE)
#### toDayOfWeek() #### toDayOfWeek(timezone=NO_VALUE)
#### toDayOfYear() #### toDayOfYear(timezone=NO_VALUE)
#### toDecimal128(**kwargs) #### toDecimal128(**kwargs)
@ -2861,7 +2861,7 @@ Initializer.
#### toFloat64OrZero() #### toFloat64OrZero()
#### toHour() #### toHour(timezone=NO_VALUE)
#### toIPv4() #### toIPv4()
@ -2870,10 +2870,10 @@ Initializer.
#### toIPv6() #### toIPv6()
#### toISOWeek(timezone="") #### toISOWeek(timezone=NO_VALUE)
#### toISOYear(timezone="") #### toISOYear(timezone=NO_VALUE)
#### toInt16(**kwargs) #### toInt16(**kwargs)
@ -2936,73 +2936,73 @@ Initializer.
#### toIntervalYear() #### toIntervalYear()
#### toMinute() #### toMinute(timezone=NO_VALUE)
#### toMonday() #### toMonday(timezone=NO_VALUE)
#### toMonth() #### toMonth(timezone=NO_VALUE)
#### toQuarter(timezone="") #### toQuarter(timezone=NO_VALUE)
#### toRelativeDayNum(timezone="") #### toRelativeDayNum(timezone=NO_VALUE)
#### toRelativeHourNum(timezone="") #### toRelativeHourNum(timezone=NO_VALUE)
#### toRelativeMinuteNum(timezone="") #### toRelativeMinuteNum(timezone=NO_VALUE)
#### toRelativeMonthNum(timezone="") #### toRelativeMonthNum(timezone=NO_VALUE)
#### toRelativeSecondNum(timezone="") #### toRelativeSecondNum(timezone=NO_VALUE)
#### toRelativeWeekNum(timezone="") #### toRelativeWeekNum(timezone=NO_VALUE)
#### toRelativeYearNum(timezone="") #### toRelativeYearNum(timezone=NO_VALUE)
#### toSecond() #### toSecond(timezone=NO_VALUE)
#### toStartOfDay() #### toStartOfDay(timezone=NO_VALUE)
#### toStartOfFifteenMinutes() #### toStartOfFifteenMinutes(timezone=NO_VALUE)
#### toStartOfFiveMinute() #### toStartOfFiveMinute(timezone=NO_VALUE)
#### toStartOfHour() #### toStartOfHour(timezone=NO_VALUE)
#### toStartOfISOYear() #### toStartOfISOYear(timezone=NO_VALUE)
#### toStartOfMinute() #### toStartOfMinute(timezone=NO_VALUE)
#### toStartOfMonth() #### toStartOfMonth(timezone=NO_VALUE)
#### toStartOfQuarter() #### toStartOfQuarter(timezone=NO_VALUE)
#### toStartOfTenMinutes() #### toStartOfTenMinutes(timezone=NO_VALUE)
#### toStartOfWeek(mode=0) #### toStartOfWeek(timezone=NO_VALUE)
#### toStartOfYear() #### toStartOfYear(timezone=NO_VALUE)
#### toString() #### toString()
@ -3011,7 +3011,7 @@ Initializer.
#### toStringCutToZero() #### toStringCutToZero()
#### toTime(timezone="") #### toTime(timezone=NO_VALUE)
#### toTimeZone(timezone) #### toTimeZone(timezone)
@ -3056,22 +3056,22 @@ Initializer.
#### toUUID() #### toUUID()
#### toUnixTimestamp(timezone="") #### toUnixTimestamp(timezone=NO_VALUE)
#### toWeek(mode=0, timezone="") #### toWeek(mode=0, timezone=NO_VALUE)
#### toYYYYMM(timezone="") #### toYYYYMM(timezone=NO_VALUE)
#### toYYYYMMDD(timezone="") #### toYYYYMMDD(timezone=NO_VALUE)
#### toYYYYMMDDhhmmss(timezone="") #### toYYYYMMDDhhmmss(timezone=NO_VALUE)
#### toYear() #### toYear(timezone=NO_VALUE)
#### to_sql(*args) #### to_sql(*args)
@ -3144,3 +3144,308 @@ For other functions:
#### uniqExact(**kwargs) #### uniqExact(**kwargs)
#### uniqExactIf(*args)
#### uniqExactOrDefault()
#### uniqExactOrDefaultIf(*args)
#### uniqExactOrNull()
#### uniqExactOrNullIf(*args)
#### uniqHLL12(**kwargs)
#### uniqHLL12If(*args)
#### uniqHLL12OrDefault()
#### uniqHLL12OrDefaultIf(*args)
#### uniqHLL12OrNull()
#### uniqHLL12OrNullIf(*args)
#### uniqIf(*args)
#### uniqOrDefault()
#### uniqOrDefaultIf(*args)
#### uniqOrNull()
#### uniqOrNullIf(*args)
#### upper(**kwargs)
#### upperUTF8()
#### varPop(**kwargs)
#### varPopIf(cond)
#### varPopOrDefault()
#### varPopOrDefaultIf(cond)
#### varPopOrNull()
#### varPopOrNullIf(cond)
#### varSamp(**kwargs)
#### varSampIf(cond)
#### varSampOrDefault()
#### varSampOrDefaultIf(cond)
#### varSampOrNull()
#### varSampOrNullIf(cond)
#### xxHash32()
#### xxHash64()
#### yesterday()
clickhouse_orm.system_models
----------------------------
### SystemPart
Extends Model
Contains information about parts of a table in the MergeTree family.
This model operates only on the fields described in the reference; other fields are ignored.
https://clickhouse.tech/docs/en/system_tables/system.parts/
#### SystemPart(**kwargs)
Creates a model instance, using keyword arguments as field values.
Since values are immediately converted to their Pythonic type,
invalid values will cause a `ValueError` to be raised.
Unrecognized field names will cause an `AttributeError`.
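A minimal sketch, using a few illustrative `system.parts` field names:

```python
part = SystemPart(database="my_test_db", table="event", partition="202110")
# SystemPart(no_such_field=1)       # would raise AttributeError
# SystemPart(marks="not-a-number")  # would raise ValueError
```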
#### attach(settings=None)
Add a new part or partition from the 'detached' directory to the table.
- `settings`: Settings for executing the request to ClickHouse via the `db.raw()` method
Returns: SQL Query
#### SystemPart.create_table_sql(db)
Returns the SQL statement for creating a table for this model.
#### detach(settings=None)
Move a partition to the 'detached' directory and forget it.
- `settings`: Settings for executing the request to ClickHouse via the `db.raw()` method
Returns: SQL Query
#### drop(settings=None)
Delete a partition.
- `settings`: Settings for executing the request to ClickHouse via the `db.raw()` method
Returns: SQL Query
#### SystemPart.drop_table_sql(db)
Returns the SQL command for deleting this model's table.
#### fetch(zookeeper_path, settings=None)
Download a partition from another server.
- `zookeeper_path`: Path in ZooKeeper to fetch from
- `settings`: Settings for executing the request to ClickHouse via the `db.raw()` method
Returns: SQL Query
#### SystemPart.fields(writable=False)
Returns an `OrderedDict` of the model's fields (from name to `Field` instance).
If `writable` is true, only writable fields are included.
Callers should not modify the dictionary.
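An illustrative iteration over the result (the names printed depend on the model):

```python
for name, field in SystemPart.fields(writable=True).items():
    print(name, type(field).__name__)
```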
#### freeze(settings=None)
Create a backup of a partition.
- `settings`: Settings for executing the request to ClickHouse via the `db.raw()` method
Returns: SQL Query
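A hedged sketch combining the partition operations above with `get_active`, mirroring the usage example in the system models guide:

```python
from clickhouse_orm import Database, SystemPart

db = Database("my_test_db")
for part in SystemPart.get_active(db):
    part.freeze()    # back up the partition
    # part.detach()  # or move it to the 'detached' directory
```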
#### SystemPart.from_tsv(line, field_names, timezone_in_use=UTC, database=None)
Create a model instance from a tab-separated line. The line may or may not include a newline.
The `field_names` list must match the fields defined in the model, but does not have to include all of them.
- `line`: the TSV-formatted data.
- `field_names`: names of the model fields in the data.
- `timezone_in_use`: the timezone to use when parsing dates and datetimes. Some fields use their own timezones.
- `database`: if given, sets the database that this instance belongs to.
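A minimal sketch, assuming the `Person` model from the tutorial:

```python
line = "Aline\tCrane\t1988-05-01\t1.62"
person = Person.from_tsv(line, ("first_name", "last_name", "birthday", "height"))
```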
#### SystemPart.get(database, conditions="")
Gets all data from the system.parts table.
- `database`: A database object to fetch data from.
- `conditions`: WHERE clause conditions. The database condition is added automatically.
Returns: A list of SystemPart objects
#### SystemPart.get_active(database, conditions="")
Gets active data from the system.parts table.
- `database`: A database object to fetch data from.
- `conditions`: WHERE clause conditions. The database and active conditions are added automatically.
Returns: A list of SystemPart objects
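For example (the condition text is illustrative):

```python
parts = SystemPart.get_active(db, conditions="table = 'event'")
```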
#### get_database()
Gets the `Database` that this model instance belongs to.
Returns `None` unless the instance was read from the database or written to it.
#### get_field(name)
Gets a `Field` instance given its name, or `None` if not found.
#### SystemPart.has_funcs_as_defaults()
Return True if some of the model's fields use a function expression
as a default value. This requires special handling when inserting instances.
#### SystemPart.is_read_only()
Returns true if the model is marked as read only.
#### SystemPart.is_system_model()
Returns true if the model represents a system table.
#### SystemPart.objects_in(database)
Returns a `QuerySet` for selecting instances of this model class.
#### set_database(db)
Sets the `Database` that this model instance belongs to.
This is done automatically when the instance is read from the database or written to it.
#### SystemPart.table_name()
#### to_db_string()
Returns the instance as a bytestring ready to be inserted into the database.
#### to_dict(include_readonly=True, field_names=None)
Returns the instance's column values as a dict.
- `include_readonly`: if false, returns only fields that can be inserted into database.
- `field_names`: an iterable of field names to return (optional)
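An illustrative call on a populated instance:

```python
row = person.to_dict(include_readonly=False, field_names=("first_name", "height"))
```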
#### to_tskv(include_readonly=True)
Returns the instance's column keys and values as a tab-separated line. A newline is not included.
Fields that were not assigned a value are omitted.
- `include_readonly`: if false, returns only fields that can be inserted into database.
#### to_tsv(include_readonly=True)
Returns the instance's column values as a tab-separated line. A newline is not included.
- `include_readonly`: if false, returns only fields that can be inserted into database.
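A short sketch of both serializers on an assumed instance:

```python
tsv_line = person.to_tsv(include_readonly=False)
tskv_line = person.to_tskv(include_readonly=False)
```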
@ -1,7 +1,7 @@
Contributing Contributing
============ ============
This project is hosted on GitHub - [https://github.com/Infinidat/infi.clickhouse_orm/](https://github.com/Infinidat/infi.clickhouse_orm/). This project is hosted on GitHub - [https://github.com/Infinidat/clickhouse_orm/](https://github.com/Infinidat/clickhouse_orm/).
Please open an issue there if you encounter a bug or want to request a feature. Please open an issue there if you encounter a bug or want to request a feature.
Pull requests are also welcome. Pull requests are also welcome.
@ -12,7 +12,7 @@ Building
After cloning the project, run the following commands: After cloning the project, run the following commands:
easy_install -U infi.projector easy_install -U infi.projector
cd infi.clickhouse_orm cd clickhouse_orm
projector devenv build projector devenv build
A `setup.py` file will be generated, which you can use to install the development version of the package: A `setup.py` file will be generated, which you can use to install the development version of the package:
@ -28,7 +28,7 @@ To run the tests, ensure that the ClickHouse server is running on <http://localh
To see test coverage information run: To see test coverage information run:
bin/nosetests --with-coverage --cover-package=infi.clickhouse_orm bin/nosetests --with-coverage --cover-package=clickhouse_orm
To test with tox, ensure that the setup.py is present (otherwise run `bin/buildout buildout:develop= setup.py`) and run: To test with tox, ensure that the setup.py is present (otherwise run `bin/buildout buildout:develop= setup.py`) and run:
@ -13,7 +13,7 @@ Using Expressions
Expressions usually include ClickHouse database functions, which are made available by the `F` class. Here's a simple function: Expressions usually include ClickHouse database functions, which are made available by the `F` class. Here's a simple function:
```python ```python
from infi.clickhouse_orm import F from clickhouse_orm import F
expr = F.today() expr = F.today()
``` ```
@ -25,7 +25,7 @@ class Event(Model):
engine = Memory() engine = Memory()
... ...
``` ```
When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` instead. For example: When creating a model instance, any fields you do not specify get their default value. Fields that use a default expression are assigned a sentinel value of `clickhouse_orm.utils.NO_VALUE` instead. For example:
```python ```python
>>> event = Event() >>> event = Event()
>>> print(event.to_dict()) >>> print(event.to_dict())
@ -63,7 +63,7 @@ db.select('SELECT created, created_date, username, name FROM $db.event', model_c
# created_date and username will contain a default value # created_date and username will contain a default value
db.select('SELECT * FROM $db.event', model_class=Event) db.select('SELECT * FROM $db.event', model_class=Event)
``` ```
When creating a model instance, any alias or materialized fields are assigned a sentinel value of `infi.clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database. When creating a model instance, any alias or materialized fields are assigned a sentinel value of `clickhouse_orm.utils.NO_VALUE` since their real values can only be known after insertion to the database.
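A hedged sketch of that sentinel behavior, assuming `created_date` is one of the alias/materialized fields in this example:

```python
from clickhouse_orm.utils import NO_VALUE

event = Event()
assert event.created_date is NO_VALUE  # the real value is known only after insertion
```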
## codec ## codec
@ -166,7 +166,7 @@ For example, we can create a BooleanField which will hold `True` and `False` val
Here's the full implementation: Here's the full implementation:
```python ```python
from infi.clickhouse_orm import Field from clickhouse_orm import Field
class BooleanField(Field): class BooleanField(Field):
@ -7,24 +7,24 @@ The ORM supports different styles of importing and referring to its classes, so
Importing Everything Importing Everything
-------------------- --------------------
It is safe to use `import *` from `infi.clickhouse_orm` or its submodules. Only classes that are needed by users of the ORM will get imported, and nothing else: It is safe to use `import *` from `clickhouse_orm` or its submodules. Only classes that are needed by users of the ORM will get imported, and nothing else:
```python ```python
from infi.clickhouse_orm import * from clickhouse_orm import *
``` ```
This is exactly equivalent to the following import statements: This is exactly equivalent to the following import statements:
```python ```python
from infi.clickhouse_orm.database import * from clickhouse_orm.database import *
from infi.clickhouse_orm.engines import * from clickhouse_orm.engines import *
from infi.clickhouse_orm.fields import * from clickhouse_orm.fields import *
from infi.clickhouse_orm.funcs import * from clickhouse_orm.funcs import *
from infi.clickhouse_orm.migrations import * from clickhouse_orm.migrations import *
from infi.clickhouse_orm.models import * from clickhouse_orm.models import *
from infi.clickhouse_orm.query import * from clickhouse_orm.query import *
from infi.clickhouse_orm.system_models import * from clickhouse_orm.system_models import *
``` ```
By importing everything, all of the ORM's public classes can be used directly. For example: By importing everything, all of the ORM's public classes can be used directly. For example:
```python ```python
from infi.clickhouse_orm import * from clickhouse_orm import *
class Event(Model): class Event(Model):
@ -40,8 +40,8 @@ Importing Everything into a Namespace
To prevent potential name clashes and to make the code more readable, you can import the ORM's classes into a namespace of your choosing, e.g. `orm`. For brevity, it is recommended to import the `F` class explicitly: To prevent potential name clashes and to make the code more readable, you can import the ORM's classes into a namespace of your choosing, e.g. `orm`. For brevity, it is recommended to import the `F` class explicitly:
```python ```python
import infi.clickhouse_orm as orm import clickhouse_orm as orm
from infi.clickhouse_orm import F from clickhouse_orm import F
class Event(orm.Model): class Event(orm.Model):
@ -57,7 +57,7 @@ Importing Specific Submodules
It is possible to import only the submodules you need, and use their names to qualify the ORM's class names. This option is more verbose, but makes it clear where each class comes from. For example: It is possible to import only the submodules you need, and use their names to qualify the ORM's class names. This option is more verbose, but makes it clear where each class comes from. For example:
```python ```python
from infi.clickhouse_orm import models, fields, engines, F from clickhouse_orm import models, fields, engines, F
class Event(models.Model): class Event(models.Model):
@ -71,9 +71,9 @@ class Event(models.Model):
Importing Specific Classes Importing Specific Classes
-------------------------- --------------------------
If you prefer, you can import only the specific ORM classes that you need directly from `infi.clickhouse_orm`: If you prefer, you can import only the specific ORM classes that you need directly from `clickhouse_orm`:
```python ```python
from infi.clickhouse_orm import Model, StringField, UInt32Field, DateTimeField, F, Memory from clickhouse_orm import Model, StringField, UInt32Field, DateTimeField, F, Memory
class Event(Model): class Event(Model):
@ -8,9 +8,9 @@ Version 1.x supports Python 2.7 and 3.5+. Version 2.x dropped support for Python
Installation Installation
------------ ------------
To install infi.clickhouse_orm: To install clickhouse_orm:
pip install infi.clickhouse_orm pip install clickhouse_orm
--- ---
@ -10,7 +10,7 @@ Defining Models
Models are defined in a way reminiscent of Django's ORM, by subclassing `Model`: Models are defined in a way reminiscent of Django's ORM, by subclassing `Model`:
```python ```python
from infi.clickhouse_orm import Model, StringField, DateField, Float32Field, MergeTree from clickhouse_orm import Model, StringField, DateField, Float32Field, MergeTree
class Person(Model): class Person(Model):
@ -133,7 +133,7 @@ Inserting to the Database
To write your instances to ClickHouse, you need a `Database` instance: To write your instances to ClickHouse, you need a `Database` instance:
from infi.clickhouse_orm import Database from clickhouse_orm import Database
db = Database('my_test_db') db = Database('my_test_db')
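Continuing the sketch (field values are illustrative):

```python
db.create_table(Person)
db.insert([
    Person(first_name="Aline", last_name="Crane", birthday="1988-05-01", height=1.62),
])
```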
@ -1,7 +1,7 @@
Class Reference Class Reference
=============== ===============
infi.clickhouse_orm.database clickhouse_orm.database
---------------------------- ----------------------------
### Database ### Database
@ -104,7 +104,7 @@ Extends Exception
Raised when a database operation fails. Raised when a database operation fails.
infi.clickhouse_orm.models clickhouse_orm.models
-------------------------- --------------------------
### Model ### Model
@ -263,7 +263,7 @@ Returns the instance's column values as a tab-separated line. A newline is not i
- `include_readonly`: if false, returns only fields that can be inserted into database. - `include_readonly`: if false, returns only fields that can be inserted into database.
infi.clickhouse_orm.fields clickhouse_orm.fields
-------------------------- --------------------------
### Field ### Field
@ -419,7 +419,7 @@ Extends BaseEnumField
#### Enum16Field(enum_cls, default=None, alias=None, materialized=None) #### Enum16Field(enum_cls, default=None, alias=None, materialized=None)
infi.clickhouse_orm.engines clickhouse_orm.engines
--------------------------- ---------------------------
### Engine ### Engine
@ -474,7 +474,7 @@ Extends MergeTree
#### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None) #### ReplacingMergeTree(date_col, key_cols, ver_col=None, sampling_expr=None, index_granularity=8192, replica_table_path=None, replica_name=None)
infi.clickhouse_orm.query clickhouse_orm.query
------------------------- -------------------------
### QuerySet ### QuerySet
@ -22,7 +22,7 @@ To write migrations, create a Python package. Then create a python file for the
Each migration file is expected to contain a list of `operations`, for example: Each migration file is expected to contain a list of `operations`, for example:
from infi.clickhouse_orm import migrations from clickhouse_orm import migrations
from analytics import models from analytics import models
operations = [ operations = [
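As a hedged sketch, a complete migrations file might look like this, assuming an `Event` model and using `CreateTable`, one of the operations provided by the `migrations` module:

```python
from clickhouse_orm import migrations
from analytics import models

operations = [
    migrations.CreateTable(models.Event),
]
```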
@ -30,7 +30,7 @@ A partition in a table is data for a single calendar month. Table "system.parts"
Usage example: Usage example:
from infi.clickhouse_orm import Database, SystemPart from clickhouse_orm import Database, SystemPart
db = Database('my_test_db', db_url='http://192.168.1.1:8050', username='scott', password='tiger') db = Database('my_test_db', db_url='http://192.168.1.1:8050', username='scott', password='tiger')
partitions = SystemPart.get_active(db, conditions='') # Getting all active partitions of the database partitions = SystemPart.get_active(db, conditions='') # Getting all active partitions of the database
if len(partitions) > 0: if len(partitions) > 0:
@ -78,17 +78,17 @@
* [Tests](contributing.md#tests) * [Tests](contributing.md#tests)
* [Class Reference](class_reference.md#class-reference) * [Class Reference](class_reference.md#class-reference)
* [infi.clickhouse_orm.database](class_reference.md#inficlickhouse_ormdatabase) * [clickhouse_orm.database](class_reference.md#clickhouse_ormdatabase)
* [Database](class_reference.md#database) * [Database](class_reference.md#database)
* [DatabaseException](class_reference.md#databaseexception) * [DatabaseException](class_reference.md#databaseexception)
* [infi.clickhouse_orm.models](class_reference.md#inficlickhouse_ormmodels) * [clickhouse_orm.models](class_reference.md#clickhouse_ormmodels)
* [Model](class_reference.md#model) * [Model](class_reference.md#model)
* [BufferModel](class_reference.md#buffermodel) * [BufferModel](class_reference.md#buffermodel)
* [MergeModel](class_reference.md#mergemodel) * [MergeModel](class_reference.md#mergemodel)
* [DistributedModel](class_reference.md#distributedmodel) * [DistributedModel](class_reference.md#distributedmodel)
* [Constraint](class_reference.md#constraint) * [Constraint](class_reference.md#constraint)
* [Index](class_reference.md#index) * [Index](class_reference.md#index)
* [infi.clickhouse_orm.fields](class_reference.md#inficlickhouse_ormfields) * [clickhouse_orm.fields](class_reference.md#clickhouse_ormfields)
* [ArrayField](class_reference.md#arrayfield) * [ArrayField](class_reference.md#arrayfield)
* [BaseEnumField](class_reference.md#baseenumfield) * [BaseEnumField](class_reference.md#baseenumfield)
* [BaseFloatField](class_reference.md#basefloatfield) * [BaseFloatField](class_reference.md#basefloatfield)
@ -120,7 +120,7 @@
* [UInt64Field](class_reference.md#uint64field) * [UInt64Field](class_reference.md#uint64field)
* [UInt8Field](class_reference.md#uint8field) * [UInt8Field](class_reference.md#uint8field)
* [UUIDField](class_reference.md#uuidfield) * [UUIDField](class_reference.md#uuidfield)
* [infi.clickhouse_orm.engines](class_reference.md#inficlickhouse_ormengines) * [clickhouse_orm.engines](class_reference.md#clickhouse_ormengines)
* [Engine](class_reference.md#engine) * [Engine](class_reference.md#engine)
* [TinyLog](class_reference.md#tinylog) * [TinyLog](class_reference.md#tinylog)
* [Log](class_reference.md#log) * [Log](class_reference.md#log)
@ -132,10 +132,12 @@
* [CollapsingMergeTree](class_reference.md#collapsingmergetree) * [CollapsingMergeTree](class_reference.md#collapsingmergetree)
* [SummingMergeTree](class_reference.md#summingmergetree) * [SummingMergeTree](class_reference.md#summingmergetree)
* [ReplacingMergeTree](class_reference.md#replacingmergetree) * [ReplacingMergeTree](class_reference.md#replacingmergetree)
* [infi.clickhouse_orm.query](class_reference.md#inficlickhouse_ormquery) * [clickhouse_orm.query](class_reference.md#clickhouse_ormquery)
* [QuerySet](class_reference.md#queryset) * [QuerySet](class_reference.md#queryset)
* [AggregateQuerySet](class_reference.md#aggregatequeryset) * [AggregateQuerySet](class_reference.md#aggregatequeryset)
* [Q](class_reference.md#q) * [Q](class_reference.md#q)
* [infi.clickhouse_orm.funcs](class_reference.md#inficlickhouse_ormfuncs) * [clickhouse_orm.funcs](class_reference.md#clickhouse_ormfuncs)
* [F](class_reference.md#f) * [F](class_reference.md#f)
* [clickhouse_orm.system_models](class_reference.md#clickhouse_ormsystem_models)
* [SystemPart](class_reference.md#systempart)
@ -50,9 +50,9 @@ for row in QueryLog.objects_in(db).filter(QueryLog.query_duration_ms > 10000):
## Convenient ways to import ORM classes ## Convenient ways to import ORM classes
You can now import all ORM classes directly from `infi.clickhouse_orm`, without worrying about sub-modules. For example: You can now import all ORM classes directly from `clickhouse_orm`, without worrying about sub-modules. For example:
```python ```python
from infi.clickhouse_orm import Database, Model, StringField, DateTimeField, MergeTree from clickhouse_orm import Database, Model, StringField, DateTimeField, MergeTree
``` ```
See [Importing ORM Classes](importing_orm_classes.md). See [Importing ORM Classes](importing_orm_classes.md).
@ -1,20 +1,25 @@
import psutil, time, datetime import datetime
from infi.clickhouse_orm import Database import time
import psutil
from models import CPUStats from models import CPUStats
from clickhouse_orm import Database
db = Database('demo') db = Database("demo")
db.create_table(CPUStats) db.create_table(CPUStats)
psutil.cpu_percent(percpu=True) # first sample should be discarded psutil.cpu_percent(percpu=True) # first sample should be discarded
while True: while True:
time.sleep(1) time.sleep(1)
stats = psutil.cpu_percent(percpu=True) stats = psutil.cpu_percent(percpu=True)
timestamp = datetime.datetime.now() timestamp = datetime.datetime.now()
print(timestamp) print(timestamp)
db.insert([ db.insert(
CPUStats(timestamp=timestamp, cpu_id=cpu_id, cpu_percent=cpu_percent) [
for cpu_id, cpu_percent in enumerate(stats) CPUStats(timestamp=timestamp, cpu_id=cpu_id, cpu_percent=cpu_percent)
]) for cpu_id, cpu_percent in enumerate(stats)
]
)
@ -1,4 +1,4 @@
from infi.clickhouse_orm import Model, DateTimeField, UInt16Field, Float32Field, Memory from clickhouse_orm import DateTimeField, Float32Field, Memory, Model, UInt16Field
class CPUStats(Model): class CPUStats(Model):
@ -8,4 +8,3 @@ class CPUStats(Model):
cpu_percent = Float32Field() cpu_percent = Float32Field()
engine = Memory() engine = Memory()
@ -1,2 +1,2 @@
infi.clickhouse_orm clickhouse_orm
psutil psutil
@ -1,13 +1,13 @@
from infi.clickhouse_orm import Database, F
from models import CPUStats from models import CPUStats
from clickhouse_orm import Database, F
db = Database('demo') db = Database("demo")
queryset = CPUStats.objects_in(db) queryset = CPUStats.objects_in(db)
total = queryset.filter(CPUStats.cpu_id == 1).count() total = queryset.filter(CPUStats.cpu_id == 1).count()
busy = queryset.filter(CPUStats.cpu_id == 1, CPUStats.cpu_percent > 95).count() busy = queryset.filter(CPUStats.cpu_id == 1, CPUStats.cpu_percent > 95).count()
print('CPU 1 was busy {:.2f}% of the time'.format(busy * 100.0 / total)) print("CPU 1 was busy {:.2f}% of the time".format(busy * 100.0 / total))
# Calculate the average usage per CPU # Calculate the average usage per CPU
for row in queryset.aggregate(CPUStats.cpu_id, average=F.avg(CPUStats.cpu_percent)): for row in queryset.aggregate(CPUStats.cpu_id, average=F.avg(CPUStats.cpu_percent)):
print('CPU {row.cpu_id}: {row.average:.2f}%'.format(row=row)) print("CPU {row.cpu_id}: {row.average:.2f}%".format(row=row))
@ -1,62 +1,73 @@
import pygal import pygal
from pygal.style import RotateStyle
from jinja2.filters import do_filesizeformat from jinja2.filters import do_filesizeformat
from pygal.style import RotateStyle
# Formatting functions # Formatting functions
number_formatter = lambda v: '{:,}'.format(v) def number_formatter(v):
bytes_formatter = lambda v: do_filesizeformat(v, True) return "{:,}".format(v)
def bytes_formatter(v):
    return do_filesizeformat(v, True)
def tables_piechart(db, by_field, value_formatter): def tables_piechart(db, by_field, value_formatter):
''' """
Generate a pie chart of the top n tables in the database. Generate a pie chart of the top n tables in the database.
`db` - the database instance `db` - the database instance
`by_field` - the field name to sort by `by_field` - the field name to sort by
`value_formatter` - a function to use for formatting the numeric values `value_formatter` - a function to use for formatting the numeric values
''' """
Tables = db.get_model_for_table('tables', system_table=True) Tables = db.get_model_for_table("tables", system_table=True)
qs = Tables.objects_in(db).filter(database=db.db_name, is_temporary=False).exclude(engine='Buffer') qs = Tables.objects_in(db).filter(database=db.db_name, is_temporary=False).exclude(engine="Buffer")
tuples = [(getattr(table, by_field), table.name) for table in qs] tuples = [(getattr(table, by_field), table.name) for table in qs]
return _generate_piechart(tuples, value_formatter) return _generate_piechart(tuples, value_formatter)
def columns_piechart(db, tbl_name, by_field, value_formatter): def columns_piechart(db, tbl_name, by_field, value_formatter):
''' """
Generate a pie chart of the top n columns in the table. Generate a pie chart of the top n columns in the table.
`db` - the database instance `db` - the database instance
`tbl_name` - the table name `tbl_name` - the table name
`by_field` - the field name to sort by `by_field` - the field name to sort by
`value_formatter` - a function to use for formatting the numeric values `value_formatter` - a function to use for formatting the numeric values
''' """
ColumnsTable = db.get_model_for_table('columns', system_table=True) ColumnsTable = db.get_model_for_table("columns", system_table=True)
qs = ColumnsTable.objects_in(db).filter(database=db.db_name, table=tbl_name) qs = ColumnsTable.objects_in(db).filter(database=db.db_name, table=tbl_name)
tuples = [(getattr(col, by_field), col.name) for col in qs] tuples = [(getattr(col, by_field), col.name) for col in qs]
return _generate_piechart(tuples, value_formatter) return _generate_piechart(tuples, value_formatter)
def _get_top_tuples(tuples, n=15): def _get_top_tuples(tuples, n=15):
''' """
Given a list of tuples (value, name), this function sorts Given a list of tuples (value, name), this function sorts
the list and returns only the top n results. All other tuples the list and returns only the top n results. All other tuples
are aggregated to a single "others" tuple. are aggregated to a single "others" tuple.
''' """
non_zero_tuples = [t for t in tuples if t[0]] non_zero_tuples = [t for t in tuples if t[0]]
sorted_tuples = sorted(non_zero_tuples, reverse=True) sorted_tuples = sorted(non_zero_tuples, reverse=True)
if len(sorted_tuples) > n: if len(sorted_tuples) > n:
others = (sum(t[0] for t in sorted_tuples[n:]), 'others') others = (sum(t[0] for t in sorted_tuples[n:]), "others")
sorted_tuples = sorted_tuples[:n] + [others] sorted_tuples = sorted_tuples[:n] + [others]
return sorted_tuples return sorted_tuples
def _generate_piechart(tuples, value_formatter): def _generate_piechart(tuples, value_formatter):
''' """
Generates a pie chart. Generates a pie chart.
`tuples` - a list of (value, name) tuples to include in the chart `tuples` - a list of (value, name) tuples to include in the chart
`value_formatter` - a function to use for formatting the values `value_formatter` - a function to use for formatting the values
''' """
style = RotateStyle('#9e6ffe', background='white', legend_font_family='Roboto', legend_font_size=18, tooltip_font_family='Roboto', tooltip_font_size=24) style = RotateStyle(
chart = pygal.Pie(style=style, margin=0, title=' ', value_formatter=value_formatter, truncate_legend=-1) "#9e6ffe",
background="white",
legend_font_family="Roboto",
legend_font_size=18,
tooltip_font_family="Roboto",
tooltip_font_size=24,
)
chart = pygal.Pie(style=style, margin=0, title=" ", value_formatter=value_formatter, truncate_legend=-1)
for t in _get_top_tuples(tuples): for t in _get_top_tuples(tuples):
chart.add(t[1], t[0]) chart.add(t[1], t[0])
return chart.render(is_unicode=True, disable_xml_declaration=True) return chart.render(is_unicode=True, disable_xml_declaration=True)
@ -3,7 +3,7 @@ chardet==3.0.4
click==7.1.2 click==7.1.2
Flask==1.1.2 Flask==1.1.2
idna==2.9 idna==2.9
infi.clickhouse-orm==2.0.1 clickhouse-orm==2.0.1
iso8601==0.1.12 iso8601==0.1.12
itsdangerous==1.1.0 itsdangerous==1.1.0
Jinja2==2.11.2 Jinja2==2.11.2
@ -1,87 +1,93 @@
from infi.clickhouse_orm import Database, F
from charts import tables_piechart, columns_piechart, number_formatter, bytes_formatter
from flask import Flask
from flask import render_template
import sys import sys
from charts import bytes_formatter, columns_piechart, number_formatter, tables_piechart
from flask import Flask, render_template
from clickhouse_orm import Database, F
app = Flask(__name__) app = Flask(__name__)
@app.route('/') @app.route("/")
def homepage_view(): def homepage_view():
''' """
Root view that lists all databases. Root view that lists all databases.
''' """
db = _get_db('system') db = _get_db("system")
# Get all databases in the system.databases table # Get all databases in the system.databases table
DatabasesTable = db.get_model_for_table('databases', system_table=True) DatabasesTable = db.get_model_for_table("databases", system_table=True)
databases = DatabasesTable.objects_in(db).exclude(name='system') databases = DatabasesTable.objects_in(db).exclude(name="system")
databases = databases.order_by(F.lower(DatabasesTable.name)) databases = databases.order_by(F.lower(DatabasesTable.name))
# Generate the page # Generate the page
return render_template('homepage.html', db=db, databases=databases) return render_template("homepage.html", db=db, databases=databases)
@app.route('/<db_name>/') @app.route("/<db_name>/")
def database_view(db_name): def database_view(db_name):
''' """
A view that displays information about a single database. A view that displays information about a single database.
''' """
db = _get_db(db_name) db = _get_db(db_name)
# Get all the tables in the database, by aggregating information from system.columns # Get all the tables in the database, by aggregating information from system.columns
ColumnsTable = db.get_model_for_table('columns', system_table=True) ColumnsTable = db.get_model_for_table("columns", system_table=True)
tables = ColumnsTable.objects_in(db).filter(database=db_name).aggregate( tables = (
ColumnsTable.table, ColumnsTable.objects_in(db)
compressed_size=F.sum(ColumnsTable.data_compressed_bytes), .filter(database=db_name)
uncompressed_size=F.sum(ColumnsTable.data_uncompressed_bytes), .aggregate(
ratio=F.sum(ColumnsTable.data_uncompressed_bytes) / F.sum(ColumnsTable.data_compressed_bytes) ColumnsTable.table,
compressed_size=F.sum(ColumnsTable.data_compressed_bytes),
uncompressed_size=F.sum(ColumnsTable.data_uncompressed_bytes),
ratio=F.sum(ColumnsTable.data_uncompressed_bytes) / F.sum(ColumnsTable.data_compressed_bytes),
)
) )
tables = tables.order_by(F.lower(ColumnsTable.table)) tables = tables.order_by(F.lower(ColumnsTable.table))
# Generate the page # Generate the page
return render_template('database.html', return render_template(
"database.html",
db=db, db=db,
tables=tables, tables=tables,
tables_piechart_by_rows=tables_piechart(db, 'total_rows', value_formatter=number_formatter), tables_piechart_by_rows=tables_piechart(db, "total_rows", value_formatter=number_formatter),
tables_piechart_by_size=tables_piechart(db, 'total_bytes', value_formatter=bytes_formatter), tables_piechart_by_size=tables_piechart(db, "total_bytes", value_formatter=bytes_formatter),
) )
@app.route('/<db_name>/<tbl_name>/') @app.route("/<db_name>/<tbl_name>/")
def table_view(db_name, tbl_name): def table_view(db_name, tbl_name):
''' """
A view that displays information about a single table. A view that displays information about a single table.
''' """
db = _get_db(db_name) db = _get_db(db_name)
# Get table information from system.tables # Get table information from system.tables
TablesTable = db.get_model_for_table('tables', system_table=True) TablesTable = db.get_model_for_table("tables", system_table=True)
tbl_info = TablesTable.objects_in(db).filter(database=db_name, name=tbl_name)[0] tbl_info = TablesTable.objects_in(db).filter(database=db_name, name=tbl_name)[0]
# Get the SQL used for creating the table # Get the SQL used for creating the table
create_table_sql = db.raw('SHOW CREATE TABLE %s FORMAT TabSeparatedRaw' % tbl_name) create_table_sql = db.raw("SHOW CREATE TABLE %s FORMAT TabSeparatedRaw" % tbl_name)
# Get all columns in the table from system.columns # Get all columns in the table from system.columns
ColumnsTable = db.get_model_for_table('columns', system_table=True) ColumnsTable = db.get_model_for_table("columns", system_table=True)
columns = ColumnsTable.objects_in(db).filter(database=db_name, table=tbl_name) columns = ColumnsTable.objects_in(db).filter(database=db_name, table=tbl_name)
# Generate the page # Generate the page
return render_template('table.html', return render_template(
"table.html",
db=db, db=db,
tbl_name=tbl_name, tbl_name=tbl_name,
tbl_info=tbl_info, tbl_info=tbl_info,
create_table_sql=create_table_sql, create_table_sql=create_table_sql,
columns=columns, columns=columns,
piechart=columns_piechart(db, tbl_name, 'data_compressed_bytes', value_formatter=bytes_formatter), piechart=columns_piechart(db, tbl_name, "data_compressed_bytes", value_formatter=bytes_formatter),
) )
def _get_db(db_name): def _get_db(db_name):
''' """
Returns a Database instance using connection information Returns a Database instance using connection information
from the command line arguments (optional). from the command line arguments (optional).
''' """
db_url = sys.argv[1] if len(sys.argv) > 1 else 'http://localhost:8123/' db_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8123/"
username = sys.argv[2] if len(sys.argv) > 2 else None username = sys.argv[2] if len(sys.argv) > 2 else None
password = sys.argv[3] if len(sys.argv) > 3 else None password = sys.argv[3] if len(sys.argv) > 3 else None
return Database(db_name, db_url, username, password, readonly=True) return Database(db_name, db_url, username, password, readonly=True)
if __name__ == '__main__': if __name__ == "__main__":
_get_db('system') # fail early on db connection problems _get_db("system") # fail early on db connection problems
app.run(debug=True) app.run(debug=True)
@ -1,27 +1,28 @@
import requests
import os import os
import requests
def download_ebook(id): def download_ebook(id):
print(id, end=' ') print(id, end=" ")
# Download the ebook's text # Download the ebook's text
r = requests.get('https://www.gutenberg.org/files/{id}/{id}-0.txt'.format(id=id)) r = requests.get("https://www.gutenberg.org/files/{id}/{id}-0.txt".format(id=id))
if r.status_code == 404: if r.status_code == 404:
print('NOT FOUND, SKIPPING') print("NOT FOUND, SKIPPING")
return return
r.raise_for_status() r.raise_for_status()
# Find the ebook's title # Find the ebook's title
text = r.content.decode('utf-8') text = r.content.decode("utf-8")
for line in text.splitlines(): for line in text.splitlines():
if line.startswith('Title:'): if line.startswith("Title:"):
title = line[6:].strip() title = line[6:].strip()
print(title) print(title)
# Save the ebook # Save the ebook
with open('ebooks/{}.txt'.format(title), 'wb') as f: with open("ebooks/{}.txt".format(title), "wb") as f:
f.write(r.content) f.write(r.content)
if __name__ == "__main__": if __name__ == "__main__":
os.makedirs('ebooks', exist_ok=True) os.makedirs("ebooks", exist_ok=True)
for i in [1342, 11, 84, 2701, 25525, 1661, 98, 74, 43, 215, 1400, 76]: for i in [1342, 11, 84, 2701, 25525, 1661, 98, 74, 43, 215, 1400, 76]:
download_ebook(i) download_ebook(i)
@ -1,61 +1,64 @@
import sys import sys
import nltk
from nltk.stem.porter import PorterStemmer
from glob import glob from glob import glob
from infi.clickhouse_orm import Database
import nltk
from models import Fragment from models import Fragment
from nltk.stem.porter import PorterStemmer
from clickhouse_orm import Database
def trim_punctuation(word): def trim_punctuation(word):
''' """
Trim punctuation characters from the beginning and end of the word Trim punctuation characters from the beginning and end of the word
''' """
start = end = len(word) start = end = len(word)
for i in range(len(word)): for i in range(len(word)):
if word[i].isalnum(): if word[i].isalnum():
start = min(start, i) start = min(start, i)
end = i + 1 end = i + 1
return word[start : end] return word[start:end]
def parse_file(filename): def parse_file(filename):
''' """
Parses a text file at the given path. Parses a text file at the given path.
Returns a generator of tuples (original_word, stemmed_word) Returns a generator of tuples (original_word, stemmed_word)
The original_word may include punctuation characters. The original_word may include punctuation characters.
''' """
stemmer = PorterStemmer() stemmer = PorterStemmer()
with open(filename, 'r', encoding='utf-8') as f: with open(filename, "r", encoding="utf-8") as f:
for line in f: for line in f:
for word in line.split(): for word in line.split():
yield (word, stemmer.stem(trim_punctuation(word))) yield (word, stemmer.stem(trim_punctuation(word)))
def get_fragments(filename): def get_fragments(filename):
''' """
Converts a text file at the given path to a generator Converts a text file at the given path to a generator
of Fragment instances. of Fragment instances.
''' """
from os import path from os import path
document = path.splitext(path.basename(filename))[0] document = path.splitext(path.basename(filename))[0]
idx = 0 idx = 0
for word, stem in parse_file(filename): for word, stem in parse_file(filename):
idx += 1 idx += 1
yield Fragment(document=document, idx=idx, word=word, stem=stem) yield Fragment(document=document, idx=idx, word=word, stem=stem)
print('{} - {} words'.format(filename, idx)) print("{} - {} words".format(filename, idx))
if __name__ == '__main__': if __name__ == "__main__":
# Load NLTK data if necessary # Load NLTK data if necessary
nltk.download('punkt') nltk.download("punkt")
nltk.download('wordnet') nltk.download("wordnet")
# Initialize database # Initialize database
db = Database('default') db = Database("default")
db.create_table(Fragment) db.create_table(Fragment)
# Load files from the command line or everything under ebooks/ # Load files from the command line or everything under ebooks/
filenames = sys.argv[1:] or glob('ebooks/*.txt') filenames = sys.argv[1:] or glob("ebooks/*.txt")
for filename in filenames: for filename in filenames:
db.insert(get_fragments(filename), batch_size=100000) db.insert(get_fragments(filename), batch_size=100000)
@ -1,16 +1,18 @@
from infi.clickhouse_orm import * from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import LowCardinalityField, StringField, UInt64Field
from clickhouse_orm.models import Index, Model
class Fragment(Model): class Fragment(Model):
language = LowCardinalityField(StringField(), default='EN') language = LowCardinalityField(StringField(), default="EN")
document = LowCardinalityField(StringField()) document = LowCardinalityField(StringField())
idx = UInt64Field() idx = UInt64Field()
word = StringField() word = StringField()
stem = StringField() stem = StringField()
# An index for faster search by document and fragment idx # An index for faster search by document and fragment idx
index = Index((document, idx), type=Index.minmax(), granularity=1) index = Index((document, idx), type=Index.minmax(), granularity=1)
# The primary key allows efficient lookup of stems # The primary key allows efficient lookup of stems
engine = MergeTree(order_by=(stem, document, idx), partition_key=('language',)) engine = MergeTree(order_by=(stem, document, idx), partition_key=("language",))
@ -1,4 +1,4 @@
infi.clickhouse_orm clickhouse_orm
nltk nltk
requests requests
colorama colorama
@ -1,19 +1,20 @@
import sys import sys
from colorama import init, Fore, Back, Style
from nltk.stem.porter import PorterStemmer
from infi.clickhouse_orm import Database, F
from models import Fragment
from load import trim_punctuation
from colorama import Fore, Style, init
from load import trim_punctuation
from models import Fragment
from nltk.stem.porter import PorterStemmer
from clickhouse_orm import Database, F
# The wildcard character # The wildcard character
WILDCARD = '*' WILDCARD = "*"
def prepare_search_terms(text): def prepare_search_terms(text):
''' """
Convert the text to search into a list of stemmed words. Convert the text to search into a list of stemmed words.
''' """
stemmer = PorterStemmer() stemmer = PorterStemmer()
stems = [] stems = []
for word in text.split(): for word in text.split():
@ -25,10 +26,10 @@ def prepare_search_terms(text):
def build_query(db, stems): def build_query(db, stems):
''' """
Returns a queryset instance for finding sequences of Fragment instances Returns a queryset instance for finding sequences of Fragment instances
that match the list of stemmed words. that match the list of stemmed words.
''' """
# Start by searching for the first stemmed word # Start by searching for the first stemmed word
all_fragments = Fragment.objects_in(db) all_fragments = Fragment.objects_in(db)
query = all_fragments.filter(stem=stems[0]).only(Fragment.document, Fragment.idx) query = all_fragments.filter(stem=stems[0]).only(Fragment.document, Fragment.idx)
@ -47,44 +48,44 @@ def build_query(db, stems):
def get_matching_text(db, document, from_idx, to_idx, extra=5): def get_matching_text(db, document, from_idx, to_idx, extra=5):
''' """
Reconstructs the document text between the given indexes (inclusive), Reconstructs the document text between the given indexes (inclusive),
plus `extra` words before and after the match. The words that are plus `extra` words before and after the match. The words that are
included in the given range are highlighted in green. included in the given range are highlighted in green.
''' """
text = [] text = []
conds = (Fragment.document == document) & (Fragment.idx >= from_idx - extra) & (Fragment.idx <= to_idx + extra) conds = (Fragment.document == document) & (Fragment.idx >= from_idx - extra) & (Fragment.idx <= to_idx + extra)
for fragment in Fragment.objects_in(db).filter(conds).order_by('document', 'idx'): for fragment in Fragment.objects_in(db).filter(conds).order_by("document", "idx"):
word = fragment.word word = fragment.word
if fragment.idx == from_idx: if fragment.idx == from_idx:
word = Fore.GREEN + word word = Fore.GREEN + word
if fragment.idx == to_idx: if fragment.idx == to_idx:
word = word + Style.RESET_ALL word = word + Style.RESET_ALL
text.append(word) text.append(word)
return ' '.join(text) return " ".join(text)
def find(db, text): def find(db, text):
''' """
Performs the search for the given text, and prints out the matches. Performs the search for the given text, and prints out the matches.
''' """
stems = prepare_search_terms(text) stems = prepare_search_terms(text)
query = build_query(db, stems) query = build_query(db, stems)
print('\n' + Fore.MAGENTA + str(query) + Style.RESET_ALL + '\n') print("\n" + Fore.MAGENTA + str(query) + Style.RESET_ALL + "\n")
for match in query: for match in query:
text = get_matching_text(db, match.document, match.idx, match.idx + len(stems) - 1) text = get_matching_text(db, match.document, match.idx, match.idx + len(stems) - 1)
print(Fore.CYAN + match.document + ':' + Style.RESET_ALL, text) print(Fore.CYAN + match.document + ":" + Style.RESET_ALL, text)
if __name__ == '__main__': if __name__ == "__main__":
# Initialize colored output # Initialize colored output
init() init()
# Initialize database # Initialize database
db = Database('default') db = Database("default")
# Search # Search
text = ' '.join(sys.argv[1:]) text = " ".join(sys.argv[1:])
if text: if text:
find(db, text) find(db, text)
pyproject.toml Normal file
@ -0,0 +1,52 @@
[tool.black]
line-length = 120
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
[tool.poetry]
name = "clickhouse_orm"
version = "2.2.2"
description = "A simple ORM for working with the ClickHouse database. Maintenance fork of infi.clickhouse_orm."
authors = ["olliemath <oliver.margetts@gmail.com>"]
license = "BSD"
homepage = "https://github.com/SuadeLabs/clickhouse_orm"
repository = "https://github.com/SuadeLabs/clickhouse_orm"
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Database"
]
[tool.poetry.dependencies]
python = ">=3.6.2,<4"
requests = "*"
pytz = "*"
iso8601 = "*"
[tool.poetry.dev-dependencies]
flake8 = "^3.9.2"
flake8-bugbear = "^21.4.3"
pep8-naming = "^0.12.0"
pytest = "^6.2.4"
flake8-isort = "^4.0.0"
black = {version = "^21.7b0", markers = "platform_python_implementation == 'CPython'"}
isort = "^5.9.2"
freezegun = "^1.1.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
@ -1,4 +1,4 @@
#!/bin/bash
mkdir -p ../htmldocs mkdir -p ../htmldocs
find ./ -iname "*.md" -type f -exec sh -c 'echo "Converting ${0}"; pandoc "${0}" -s -o "../htmldocs/${0%.md}.html"' {} \; find ./ -iname "*.md" -type f -exec sh -c 'echo "Converting ${0}"; pandoc "${0}" -s -o "../htmldocs/${0%.md}.html"' {} \;
@ -1,5 +1,6 @@
#!/bin/bash
# Class reference # Class reference
../bin/python ../scripts/generate_ref.py > class_reference.md poetry run python ../scripts/generate_ref.py > class_reference.md
# Table of contents # Table of contents
../scripts/generate_toc.sh ../scripts/generate_toc.sh
@ -1,11 +1,11 @@
import inspect import inspect
from collections import namedtuple from collections import namedtuple
DefaultArgSpec = namedtuple('DefaultArgSpec', 'has_default default_value') DefaultArgSpec = namedtuple("DefaultArgSpec", "has_default default_value")
def _get_default_arg(args, defaults, arg_index): def _get_default_arg(args, defaults, arg_index):
""" Method that determines if an argument has default value or not, """Method that determines if an argument has default value or not,
and if yes what is the default value for the argument and if yes what is the default value for the argument
:param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg'] :param args: array of arguments, eg: ['first_arg', 'second_arg', 'third_arg']
@ -25,12 +25,13 @@ def _get_default_arg(args, defaults, arg_index):
return DefaultArgSpec(False, None) return DefaultArgSpec(False, None)
else: else:
value = defaults[arg_index - args_with_no_defaults] value = defaults[arg_index - args_with_no_defaults]
if (type(value) is str): if type(value) is str:
value = '"%s"' % value value = '"%s"' % value
return DefaultArgSpec(True, value) return DefaultArgSpec(True, value)
def get_method_sig(method): def get_method_sig(method):
""" Given a function, it returns a string that pretty much looks how the """Given a function, it returns a string that pretty much looks how the
function signature would be written in python. function signature would be written in python.
:param method: a python method :param method: a python method
@ -42,31 +43,37 @@ def get_method_sig(method):
# list of defaults are returned in separate array. # list of defaults are returned in separate array.
# eg: ArgSpec(args=['first_arg', 'second_arg', 'third_arg'], # eg: ArgSpec(args=['first_arg', 'second_arg', 'third_arg'],
# varargs=None, keywords=None, defaults=(42, 'something')) # varargs=None, keywords=None, defaults=(42, 'something'))
argspec = inspect.getargspec(method) argspec = inspect.getfullargspec(method)
arg_index=0
args = [] args = []
# Use the args and defaults array returned by argspec and find out # Use the args and defaults array returned by argspec and find out
# which arguments has default # which arguments has default
for arg in argspec.args: for idx, arg in enumerate(argspec.args):
default_arg = _get_default_arg(argspec.args, argspec.defaults, arg_index) default_arg = _get_default_arg(argspec.args, argspec.defaults, idx)
if default_arg.has_default:
val = default_arg.default_value
args.append("%s=%s" % (arg, val))
else:
args.append(arg)
for idx, arg in enumerate(argspec.kwonlyargs):
default_arg = _get_default_arg(argspec.kwonlyargs, argspec.kwonlydefaults, idx)
if default_arg.has_default: if default_arg.has_default:
val = default_arg.default_value val = default_arg.default_value
args.append("%s=%s" % (arg, val)) args.append("%s=%s" % (arg, val))
else: else:
args.append(arg) args.append(arg)
arg_index += 1
if argspec.varargs: if argspec.varargs:
args.append('*' + argspec.varargs) args.append("*" + argspec.varargs)
if argspec.keywords: if argspec.varkw:
args.append('**' + argspec.keywords) args.append("**" + argspec.varkw)
return "%s(%s)" % (method.__name__, ", ".join(args[1:])) return "%s(%s)" % (method.__name__, ", ".join(args[1:]))
def docstring(obj): def docstring(obj):
doc = (obj.__doc__ or '').rstrip() doc = (obj.__doc__ or "").rstrip()
if doc: if doc:
lines = doc.split('\n') lines = doc.split("\n")
# Find the length of the whitespace prefix common to all non-empty lines # Find the length of the whitespace prefix common to all non-empty lines
indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip()) indentation = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
# Output the lines without the indentation # Output the lines without the indentation
@ -76,30 +83,30 @@ def docstring(obj):
def class_doc(cls, list_methods=True): def class_doc(cls, list_methods=True):
bases = ', '.join([b.__name__ for b in cls.__bases__]) bases = ", ".join([b.__name__ for b in cls.__bases__])
print('###', cls.__name__) print("###", cls.__name__)
print() print()
if bases != 'object': if bases != "object":
print('Extends', bases) print("Extends", bases)
print() print()
docstring(cls) docstring(cls)
for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)): for name, method in inspect.getmembers(cls, lambda m: inspect.ismethod(m) or inspect.isfunction(m)):
if name == '__init__': if name == "__init__":
# Initializer # Initializer
print('####', get_method_sig(method).replace(name, cls.__name__)) print("####", get_method_sig(method).replace(name, cls.__name__))
elif name[0] == '_': elif name[0] == "_":
# Private method # Private method
continue continue
elif hasattr(method, '__self__') and method.__self__ == cls: elif hasattr(method, "__self__") and method.__self__ == cls:
# Class method # Class method
if not list_methods: if not list_methods:
continue continue
print('#### %s.%s' % (cls.__name__, get_method_sig(method))) print("#### %s.%s" % (cls.__name__, get_method_sig(method)))
else: else:
# Regular method # Regular method
if not list_methods: if not list_methods:
continue continue
print('####', get_method_sig(method)) print("####", get_method_sig(method))
print() print()
docstring(method) docstring(method)
print() print()
@ -108,7 +115,7 @@ def class_doc(cls, list_methods=True):
def module_doc(classes, list_methods=True): def module_doc(classes, list_methods=True):
mdl = classes[0].__module__ mdl = classes[0].__module__
print(mdl) print(mdl)
print('-' * len(mdl)) print("-" * len(mdl))
print() print()
for cls in classes: for cls in classes:
class_doc(cls, list_methods) class_doc(cls, list_methods)
@ -118,21 +125,17 @@ def all_subclasses(cls):
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)] return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in all_subclasses(s)]
if __name__ == '__main__': if __name__ == "__main__":
from infi.clickhouse_orm import database from clickhouse_orm import database, engines, fields, funcs, models, query, system_models
from infi.clickhouse_orm import fields
from infi.clickhouse_orm import engines
from infi.clickhouse_orm import models
from infi.clickhouse_orm import query
from infi.clickhouse_orm import funcs
from infi.clickhouse_orm import system_models
print('Class Reference') print("Class Reference")
print('===============') print("===============")
print() print()
module_doc([database.Database, database.DatabaseException]) module_doc([database.Database, database.DatabaseException])
module_doc([models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index]) module_doc(
[models.Model, models.BufferModel, models.MergeModel, models.DistributedModel, models.Constraint, models.Index]
)
module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False) module_doc(sorted([fields.Field] + all_subclasses(fields.Field), key=lambda x: x.__name__), False)
module_doc([engines.Engine] + all_subclasses(engines.Engine), False) module_doc([engines.Engine] + all_subclasses(engines.Engine), False)
module_doc([query.QuerySet, query.AggregateQuerySet, query.Q]) module_doc([query.QuerySet, query.AggregateQuerySet, query.Q])
@ -1,7 +1,7 @@
#!/bin/bash
generate_one() { generate_one() {
# Converts Markdown to HTML using Pandoc, and then extracts the header tags # Converts Markdown to HTML using Pandoc, and then extracts the header tags
pandoc "$1" | python "../scripts/html_to_markdown_toc.py" "$1" >> toc.md pandoc "$1" | poetry run python "../scripts/html_to_markdown_toc.py" "$1" >> toc.md
} }
printf "# Table of Contents\n\n" > toc.md printf "# Table of Contents\n\n" > toc.md
@ -1,14 +1,13 @@
from html.parser import HTMLParser
import sys import sys
from html.parser import HTMLParser
HEADER_TAGS = ("h1", "h2", "h3")
HEADER_TAGS = ('h1', 'h2', 'h3')
class HeadersToMarkdownParser(HTMLParser): class HeadersToMarkdownParser(HTMLParser):
inside = None inside = None
text = '' text = ""
def handle_starttag(self, tag, attrs): def handle_starttag(self, tag, attrs):
if tag.lower() in HEADER_TAGS: if tag.lower() in HEADER_TAGS:
@ -16,11 +15,11 @@ class HeadersToMarkdownParser(HTMLParser):
def handle_endtag(self, tag): def handle_endtag(self, tag):
if tag.lower() in HEADER_TAGS: if tag.lower() in HEADER_TAGS:
indent = ' ' * int(self.inside[1]) indent = " " * int(self.inside[1])
fragment = self.text.lower().replace(' ', '-').replace('.', '') fragment = self.text.lower().replace(" ", "-").replace(".", "")
print('%s* [%s](%s#%s)' % (indent, self.text, sys.argv[1], fragment)) print("%s* [%s](%s#%s)" % (indent, self.text, sys.argv[1], fragment))
self.inside = None self.inside = None
self.text = '' self.text = ""
def handle_data(self, data): def handle_data(self, data):
if self.inside: if self.inside:
@ -28,4 +27,4 @@ class HeadersToMarkdownParser(HTMLParser):
HeadersToMarkdownParser().feed(sys.stdin.read()) HeadersToMarkdownParser().feed(sys.stdin.read())
print('') print("")
@ -1,11 +0,0 @@
#!/bin/bash
cd /tmp
rm -rf /tmp/orm_env*
virtualenv -p python3 /tmp/orm_env
cd /tmp/orm_env
source bin/activate
pip install infi.projector
git clone https://github.com/Infinidat/infi.clickhouse_orm.git
cd infi.clickhouse_orm
projector devenv build
bin/nosetests
setup.cfg Normal file
@ -0,0 +1,19 @@
[flake8]
max-line-length = 120
select =
# pycodestyle
E, W
# pyflakes
F
# flake8-bugbear
B, B9
# pydocstyle
D
# isort
I
ignore =
E203 # Whitespace after ':'
W503 # Operator after new line
B950 # We use E501
exclude =
tests/sample_migrations
@ -1,50 +0,0 @@
SETUP_INFO = dict(
name = '${project:name}',
version = '${infi.recipe.template.version:version}',
author = '${infi.recipe.template.version:author}',
author_email = '${infi.recipe.template.version:author_email}',
url = ${infi.recipe.template.version:homepage},
license = 'BSD',
description = """${project:description}""",
# http://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.4",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Database"
],
install_requires = ${project:install_requires},
namespace_packages = ${project:namespace_packages},
package_dir = {'': 'src'},
package_data = {'': ${project:package_data}},
include_package_data = True,
zip_safe = False,
entry_points = dict(
console_scripts = ${project:console_scripts},
gui_scripts = ${project:gui_scripts},
),
)
if SETUP_INFO['url'] is None:
_ = SETUP_INFO.pop('url')
def setup():
from setuptools import setup as _setup
from setuptools import find_packages
SETUP_INFO['packages'] = find_packages('src')
_setup(**SETUP_INFO)
if __name__ == '__main__':
setup()
@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)
@ -1,13 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)
from infi.clickhouse_orm.database import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.funcs import *
from infi.clickhouse_orm.migrations import *
from infi.clickhouse_orm.models import *
from infi.clickhouse_orm.query import *
from infi.clickhouse_orm.system_models import *
from inspect import isclass
__all__ = [c.__name__ for c in locals().values() if isclass(c)]
@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)
@@ -1,19 +1,18 @@
# -*- coding: utf-8 -*-
+import logging
import unittest

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import DateField, Float32Field, LowCardinalityField, NullableField, StringField, UInt32Field
+from clickhouse_orm.models import Model

-import logging
logging.getLogger("requests").setLevel(logging.WARNING)


class TestCaseWithData(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(Person)

    def tearDown(self):
@@ -35,7 +34,6 @@ class TestCaseWithData(unittest.TestCase):
            yield Person(**entry)


class Person(Model):
    first_name = StringField()
@@ -44,16 +42,12 @@ class Person(Model):
    height = Float32Field()
    passport = NullableField(UInt32Field())

-    engine = MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
+    engine = MergeTree("birthday", ("first_name", "last_name", "birthday"))


data = [
-    {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63",
-     "passport": 35052255},
-    {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74",
-     "passport": 36052255},
+    {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", "passport": 35052255},
+    {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", "passport": 36052255},
    {"first_name": "Adena", "last_name": "Norman", "birthday": "1979-05-14", "height": "1.66"},
    {"first_name": "Aline", "last_name": "Crane", "birthday": "1988-05-01", "height": "1.62"},
    {"first_name": "Althea", "last_name": "Barrett", "birthday": "2004-07-28", "height": "1.71"},
@@ -151,5 +145,5 @@ data = [
    {"first_name": "Whitney", "last_name": "Durham", "birthday": "1977-09-15", "height": "1.72"},
    {"first_name": "Whitney", "last_name": "Scott", "birthday": "1971-07-04", "height": "1.70"},
    {"first_name": "Wynter", "last_name": "Garcia", "birthday": "1975-01-10", "height": "1.69"},
-    {"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"}
+    {"first_name": "Yolanda", "last_name": "Duke", "birthday": "1997-02-25", "height": "1.74"},
]
0
tests/fields/__init__.py Normal file
@@ -1,32 +1,29 @@
import unittest
from datetime import date

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model, NO_VALUE
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
-from infi.clickhouse_orm.funcs import F
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import DateField, Int32Field, StringField
+from clickhouse_orm.funcs import F
+from clickhouse_orm.models import NO_VALUE, Model


class AliasFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithAliasFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
-        instance = ModelWithAliasFields(
-            date_field='2016-08-30',
-            int_field=-10,
-            str_field='TEST'
-        )
+        instance = ModelWithAliasFields(date_field="2016-08-30", int_field=-10, str_field="TEST")
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
-        query = 'SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func' \
-                ' FROM $db.%s ORDER BY alias_date' % ModelWithAliasFields.table_name()
+        query = (
+            "SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func"
+            " FROM $db.%s ORDER BY alias_date" % ModelWithAliasFields.table_name()
+        )
        for model_cls in (ModelWithAliasFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
@@ -41,7 +38,7 @@ class AliasFieldsTest(unittest.TestCase):
    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithAliasFields()
-        for value in ('x', [date.today()], ['aaa'], [None]):
+        for value in ("x", [date.today()], ["aaa"], [None]):
            with self.assertRaises(ValueError):
                instance.alias_date = value
@@ -51,10 +48,10 @@ class AliasFieldsTest(unittest.TestCase):
    def test_duplicate_default(self):
        with self.assertRaises(AssertionError):
-            StringField(alias='str_field', default='with default')
+            StringField(alias="str_field", default="with default")
        with self.assertRaises(AssertionError):
-            StringField(alias='str_field', materialized='str_field')
+            StringField(alias="str_field", materialized="str_field")

    def test_default_value(self):
        instance = ModelWithAliasFields()
@@ -62,7 +59,7 @@ class AliasFieldsTest(unittest.TestCase):
        # Check that NO_VALUE can be assigned to a field
        instance.str_field = NO_VALUE
        # Check that NO_VALUE can be assigned when creating a new instance
-        instance2 = ModelWithAliasFields(**instance.to_dict())
+        ModelWithAliasFields(**instance.to_dict())


class ModelWithAliasFields(Model):
@@ -70,9 +67,9 @@ class ModelWithAliasFields(Model):
    date_field = DateField()
    str_field = StringField()

-    alias_str = StringField(alias=u'str_field')
-    alias_int = Int32Field(alias='int_field')
-    alias_date = DateField(alias='date_field')
+    alias_str = StringField(alias="str_field")
+    alias_int = Int32Field(alias="int_field")
+    alias_date = DateField(alias="date_field")
    alias_func = Int32Field(alias=F.toYYYYMM(date_field))

-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))
@@ -1,16 +1,15 @@
import unittest
from datetime import date

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import ArrayField, DateField, Int32Field, StringField
+from clickhouse_orm.models import Model


class ArrayFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithArrays)

    def tearDown(self):
@@ -18,12 +17,12 @@ class ArrayFieldsTest(unittest.TestCase):
    def test_insert_and_select(self):
        instance = ModelWithArrays(
-            date_field='2016-08-30',
-            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
-            arr_date=['2010-01-01'],
+            date_field="2016-08-30",
+            arr_str=["goodbye,", "cruel", "world", "special chars: ,\"\\'` \n\t\\[]"],
+            arr_date=["2010-01-01"],
        )
        self.database.insert([instance])
-        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
+        query = "SELECT * from $db.modelwitharrays ORDER BY date_field"
        for model_cls in (ModelWithArrays, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
@@ -32,32 +31,25 @@ class ArrayFieldsTest(unittest.TestCase):
        self.assertEqual(results[0].arr_date, instance.arr_date)

    def test_conversion(self):
-        instance = ModelWithArrays(
-            arr_int=('1', '2', '3'),
-            arr_date=['2010-01-01']
-        )
+        instance = ModelWithArrays(arr_int=("1", "2", "3"), arr_date=["2010-01-01"])
        self.assertEqual(instance.arr_str, [])
        self.assertEqual(instance.arr_int, [1, 2, 3])
        self.assertEqual(instance.arr_date, [date(2010, 1, 1)])

    def test_assignment_error(self):
        instance = ModelWithArrays()
-        for value in (7, 'x', [date.today()], ['aaa'], [None]):
+        for value in (7, "x", [date.today()], ["aaa"], [None]):
            with self.assertRaises(ValueError):
                instance.arr_int = value

    def test_parse_array(self):
-        from infi.clickhouse_orm.utils import parse_array, unescape
+        from clickhouse_orm.utils import parse_array, unescape

        self.assertEqual(parse_array("[]"), [])
        self.assertEqual(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
        self.assertEqual(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
        self.assertEqual(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
-        for s in ("",
-                  "[",
-                  "]",
-                  "[1, 2",
-                  "3, 4]",
-                  "['aaa', 'aaa]"):
+        for s in ("", "[", "]", "[1, 2", "3, 4]", "['aaa', 'aaa]"):
            with self.assertRaises(ValueError):
                parse_array(s)
@@ -74,4 +66,4 @@ class ModelWithArrays(Model):
    arr_int = ArrayField(Int32Field())
    arr_date = ArrayField(DateField())

-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))
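
A side note on the parse_array assertions above: the utility only tokenizes the array literal, while type conversion happens in the owning field, which is why test_conversion gets [1, 2, 3] back from string inputs. A minimal sketch of that split:

from clickhouse_orm.utils import parse_array

# parse_array splits the literal into string tokens only...
assert parse_array("[1, 2, 3]") == ["1", "2", "3"]
# ...and ArrayField(Int32Field()) converts each token to int on assignment.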
@@ -0,0 +1,174 @@
import datetime
import unittest
import pytest
import pytz
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import (
ArrayField,
DateField,
DateTimeField,
FixedStringField,
Float32Field,
Int64Field,
NullableField,
StringField,
UInt64Field,
)
from clickhouse_orm.models import NO_VALUE, Model
from clickhouse_orm.utils import parse_tsv
class CompressedFieldsTestCase(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
self.database.create_table(CompressedModel)
def tearDown(self):
self.database.drop_database()
def test_defaults(self):
# Check that all fields have their explicit or implicit defaults
instance = CompressedModel()
self.database.insert([instance])
self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
self.assertEqual(instance.string_field, "dozo")
self.assertEqual(instance.int64_field, 42)
self.assertEqual(instance.float_field, 0)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, [])
def test_assignment(self):
# Check that all fields are assigned during construction
kwargs = dict(
uint64_field=217,
date_field=datetime.date(1973, 12, 6),
datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
string_field="aloha",
int64_field=-50,
float_field=3.14,
nullable_field=-2.718281,
array_field=["123456789123456", "", "a"],
)
instance = CompressedModel(**kwargs)
self.database.insert([instance])
for name in kwargs:
self.assertEqual(kwargs[name], getattr(instance, name))
def test_string_conversion(self):
# Check field conversion from string during construction
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
nullable_field=None,
array_field="[a,b,c]",
)
self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
self.assertEqual(instance.int64_field, 100)
self.assertEqual(instance.float_field, 7)
self.assertEqual(instance.nullable_field, None)
self.assertEqual(instance.array_field, ["a", "b", "c"])
# Check field conversion from string during assignment
instance.int64_field = "99"
self.assertEqual(instance.int64_field, 99)
def test_to_dict(self):
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
array_field="[a,b,c]",
)
self.assertDictEqual(
instance.to_dict(),
{
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"mat_field": NO_VALUE,
"string_field": "dozo",
"nullable_field": None,
"uint64_field": 0,
"array_field": ["a", "b", "c"],
},
)
self.assertDictEqual(
instance.to_dict(include_readonly=False),
{
"date_field": datetime.date(1973, 12, 6),
"int64_field": 100,
"float_field": 7.0,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
"string_field": "dozo",
"nullable_field": None,
"uint64_field": 0,
"array_field": ["a", "b", "c"],
},
)
self.assertDictEqual(
instance.to_dict(
include_readonly=False,
field_names=("int64_field", "mat_field", "datetime_field"),
),
{
"int64_field": 100,
"datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
},
)
def test_confirm_compression_codec(self):
if self.database.server_version < (19, 17):
raise unittest.SkipTest("ClickHouse version too old")
instance = CompressedModel(
date_field="1973-12-06",
int64_field="100",
float_field="7",
array_field="[a,b,c]",
)
self.database.insert([instance])
r = self.database.raw(
"select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(
instance.table_name(), self.database.db_name
)
)
lines = r.splitlines()
parse_tsv(lines[0])
parse_tsv(lines[1])
data = [tuple(parse_tsv(line)) for line in lines[2:]]
self.assertListEqual(
data,
[
("uint64_field", "CODEC(ZSTD(10))"),
("datetime_field", "CODEC(Delta(4), ZSTD(1))"),
("date_field", "CODEC(Delta(4), ZSTD(22))"),
("int64_field", "CODEC(LZ4)"),
("string_field", "CODEC(LZ4HC(10))"),
("nullable_field", "CODEC(ZSTD(1))"),
("array_field", "CODEC(Delta(2), LZ4HC(0))"),
("float_field", "CODEC(NONE)"),
("mat_field", "CODEC(ZSTD(4))"),
],
)
def test_alias_field(self):
with pytest.raises(AssertionError):
Float32Field(alias="something", codec="ZSTD(4)")
class CompressedModel(Model):
uint64_field = UInt64Field(codec="ZSTD(10)")
datetime_field = DateTimeField(codec="Delta,ZSTD")
date_field = DateField(codec="Delta(4),ZSTD(22)")
int64_field = Int64Field(default=42, codec="LZ4")
string_field = StringField(default="dozo", codec="LZ4HC(10)")
nullable_field = NullableField(Float32Field(), codec="ZSTD")
array_field = ArrayField(FixedStringField(length=15), codec="Delta(2),LZ4HC")
float_field = Float32Field(codec="NONE")
mat_field = Float32Field(materialized="float_field", codec="ZSTD(4)")
engine = MergeTree("datetime_field", ("uint64_field", "datetime_field"))
@@ -1,14 +1,14 @@
import unittest

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.fields import Field, Int16Field
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.engines import Memory
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import Memory
+from clickhouse_orm.fields import Field, Int16Field
+from clickhouse_orm.models import Model


class CustomFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()
@@ -19,15 +19,18 @@ class CustomFieldsTest(unittest.TestCase):
            i = Int16Field()
            f = BooleanField()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values
-        for index, value in enumerate([1, '1', True, 0, '0', False]):
+        for index, value in enumerate([1, "1", True, 0, "0", False]):
            rec = TestModel(i=index, f=value)
            self.database.insert([rec])
-        self.assertEqual([rec.f for rec in TestModel.objects_in(self.database).order_by('i')],
-                         [True, True, True, False, False, False])
+        self.assertEqual(
+            [rec.f for rec in TestModel.objects_in(self.database).order_by("i")],
+            [True, True, True, False, False, False],
+        )
        # Check invalid values
-        for value in [None, 'zzz', -5, 7]:
+        for value in [None, "zzz", -5, 7]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)
@@ -35,21 +38,20 @@ class CustomFieldsTest(unittest.TestCase):

class BooleanField(Field):
    # The ClickHouse column type to use
-    db_type = 'UInt8'
+    db_type = "UInt8"

    # The default value if empty
    class_default = False

    def to_python(self, value, timezone_in_use):
        # Convert valid values to bool
-        if value in (1, '1', True):
+        if value in (1, "1", True):
            return True
-        elif value in (0, '0', False):
+        elif value in (0, "0", False):
            return False
        else:
-            raise ValueError('Invalid value for BooleanField: %r' % value)
+            raise ValueError("Invalid value for BooleanField: %r" % value)

    def to_db_string(self, value, quote=True):
        # The value was already converted by to_python, so it's a bool
-        return '1' if value else '0'
+        return "1" if value else "0"
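
A custom field such as the BooleanField above plugs into a model exactly like the built-in fields; a minimal usage sketch (the model name is hypothetical):

class FlagModel(Model):
    i = Int16Field()
    f = BooleanField(default=True)  # an explicit default overrides class_default
    engine = Memory()

rec = FlagModel(i=1, f="1")  # to_python() normalizes "1" to True on assignment
assert rec.f is True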
@@ -0,0 +1,187 @@
import datetime
import unittest
import pytz
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, DateTime64Field, DateTimeField
from clickhouse_orm.models import Model
class DateFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest("ClickHouse version too old")
self.database.create_table(ModelWithDate)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert(
[
ModelWithDate(
date_field="2016-08-30",
datetime_field="2016-08-30 03:50:00",
datetime64_field="2016-08-30 03:50:00.123456",
datetime64_3_field="2016-08-30 03:50:00.123456",
),
ModelWithDate(
date_field="2016-08-31",
datetime_field="2016-08-31 01:30:00",
datetime64_field="2016-08-31 01:30:00.123456",
datetime64_3_field="2016-08-31 01:30:00.123456",
),
]
)
# toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here too
query = "SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field"
results = list(self.database.select(query))
self.assertEqual(len(results), 2)
self.assertEqual(results[0].date_field, datetime.date(2016, 8, 30))
self.assertEqual(
results[0].datetime_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].hour_start,
datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(results[1].date_field, datetime.date(2016, 8, 31))
self.assertEqual(
results[1].datetime_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].hour_start,
datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, 123456, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_3_field,
datetime.datetime(2016, 8, 30, 3, 50, 0, 123000, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, 123456, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_3_field,
datetime.datetime(2016, 8, 31, 1, 30, 0, 123000, tzinfo=pytz.UTC),
)
class ModelWithDate(Model):
date_field = DateField()
datetime_field = DateTimeField()
datetime64_field = DateTime64Field()
datetime64_3_field = DateTime64Field(precision=3)
engine = MergeTree("date_field", ("date_field",))
class ModelWithTz(Model):
datetime_no_tz_field = DateTimeField() # server tz
datetime_tz_field = DateTimeField(timezone="Europe/Madrid")
datetime64_tz_field = DateTime64Field(timezone="Europe/Madrid")
datetime_utc_field = DateTimeField(timezone=pytz.UTC)
engine = MergeTree("datetime_no_tz_field", ("datetime_no_tz_field",))
class DateTimeFieldWithTzTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
if self.database.server_version < (20, 1, 2, 4):
raise unittest.SkipTest("ClickHouse version too old")
self.database.create_table(ModelWithTz)
def tearDown(self):
self.database.drop_database()
def test_ad_hoc_model(self):
self.database.insert(
[
ModelWithTz(
datetime_no_tz_field="2020-06-11 04:00:00",
datetime_tz_field="2020-06-11 04:00:00",
datetime64_tz_field="2020-06-11 04:00:00",
datetime_utc_field="2020-06-11 04:00:00",
),
ModelWithTz(
datetime_no_tz_field="2020-06-11 07:00:00+0300",
datetime_tz_field="2020-06-11 07:00:00+0300",
datetime64_tz_field="2020-06-11 07:00:00+0300",
datetime_utc_field="2020-06-11 07:00:00+0300",
),
]
)
query = "SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field"
results = list(self.database.select(query))
self.assertEqual(
results[0].datetime_no_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime64_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_utc_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_no_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime64_tz_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[1].datetime_utc_field,
datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC),
)
self.assertEqual(
results[0].datetime_no_tz_field.tzinfo.zone,
self.database.server_timezone.zone,
)
self.assertEqual(
results[0].datetime_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(
results[0].datetime64_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)
self.assertEqual(
results[1].datetime_no_tz_field.tzinfo.zone,
self.database.server_timezone.zone,
)
self.assertEqual(
results[1].datetime_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(
results[1].datetime64_tz_field.tzinfo.zone,
pytz.timezone("Europe/Madrid").zone,
)
self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone("UTC").zone)
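
The two inserted rows above are deliberately the same instant written two ways: 07:00 at UTC+3 is 04:00 UTC, which is why every per-row assertion compares equal. In plain Python:

import datetime
import pytz

local = datetime.datetime(2020, 6, 11, 7, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=3)))
assert local.astimezone(pytz.utc) == datetime.datetime(2020, 6, 11, 4, 0, tzinfo=pytz.utc)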
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
import unittest
from decimal import Decimal
from clickhouse_orm.database import Database, ServerError
from clickhouse_orm.engines import Memory
from clickhouse_orm.fields import DateField, Decimal32Field, Decimal64Field, Decimal128Field, DecimalField
from clickhouse_orm.models import Model
class DecimalFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
try:
self.database.create_table(DecimalModel)
except ServerError as e:
# This ClickHouse version does not support decimals yet
raise unittest.SkipTest(str(e))
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert(
[
DecimalModel(date_field="2016-08-20"),
DecimalModel(date_field="2016-08-21", dec=Decimal("1.234")),
DecimalModel(date_field="2016-08-22", dec32=Decimal("12342.2345")),
DecimalModel(date_field="2016-08-23", dec64=Decimal("12342.23456")),
DecimalModel(date_field="2016-08-24", dec128=Decimal("-4545456612342.234567")),
]
)
def _assert_sample_data(self, results):
self.assertEqual(len(results), 5)
self.assertEqual(results[0].dec, Decimal(0))
self.assertEqual(results[0].dec32, Decimal(17))
self.assertEqual(results[1].dec, Decimal("1.234"))
self.assertEqual(results[2].dec32, Decimal("12342.2345"))
self.assertEqual(results[3].dec64, Decimal("12342.23456"))
self.assertEqual(results[4].dec128, Decimal("-4545456612342.234567"))
def test_insert_and_select(self):
self._insert_sample_data()
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, DecimalModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = "SELECT * from decimalmodel ORDER BY date_field"
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_rounding(self):
d = Decimal("11111.2340000000000000001")
self.database.insert([DecimalModel(date_field="2016-08-20", dec=d, dec32=d, dec64=d, dec128=d)])
m = DecimalModel.objects_in(self.database)[0]
for val in (m.dec, m.dec32, m.dec64, m.dec128):
self.assertEqual(val, Decimal("11111.234"))
def test_assignment_ok(self):
for value in (True, False, 17, 3.14, "20.5", Decimal("20.5")):
DecimalModel(dec=value)
def test_assignment_error(self):
for value in ("abc", u"זה ארוך", None, float("NaN"), Decimal("-Infinity")):
with self.assertRaises(ValueError):
DecimalModel(dec=value)
def test_aggregation(self):
self._insert_sample_data()
result = DecimalModel.objects_in(self.database).aggregate(m="min(dec)", n="max(dec)")
self.assertEqual(result[0].m, Decimal(0))
self.assertEqual(result[0].n, Decimal("1.234"))
def test_precision_and_scale(self):
# Go over all valid combinations
for precision in range(1, 39):
for scale in range(0, precision + 1):
DecimalField(precision, scale)
# Some invalid combinations
for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
with self.assertRaises(AssertionError):
DecimalField(precision, scale)
def test_min_max(self):
# In range
f = DecimalField(3, 1)
f.validate(f.to_python("99.9", None))
f.validate(f.to_python("-99.9", None))
# In range after rounding
f.validate(f.to_python("99.94", None))
f.validate(f.to_python("-99.94", None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python("99.99", None))
with self.assertRaises(ValueError):
f.validate(f.to_python("-99.99", None))
# In range
f = Decimal32Field(4)
f.validate(f.to_python("99999.9999", None))
f.validate(f.to_python("-99999.9999", None))
# In range after rounding
f.validate(f.to_python("99999.99994", None))
f.validate(f.to_python("-99999.99994", None))
# Out of range
with self.assertRaises(ValueError):
f.validate(f.to_python("100000", None))
with self.assertRaises(ValueError):
f.validate(f.to_python("-100000", None))
class DecimalModel(Model):
date_field = DateField()
dec = DecimalField(15, 3)
dec32 = Decimal32Field(4, default=17)
dec64 = Decimal64Field(5)
dec128 = Decimal128Field(6)
engine = Memory()
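
The boundaries exercised in test_min_max follow from the storage widths: a Decimal32 holds 9 significant digits, so Decimal32Field(4) spans up to 10**(9-4) - 10**-4, i.e. 99999.9999, and DecimalField(3, 1) likewise tops out at 99.9. The worked arithmetic:

from decimal import Decimal

assert Decimal(10) ** (9 - 4) - Decimal("0.0001") == Decimal("99999.9999")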
@@ -1,17 +1,15 @@
import unittest
-
-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
-
from enum import Enum

+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import ArrayField, DateField, Enum8Field, Enum16Field
+from clickhouse_orm.models import Model


class EnumFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithEnum)
        self.database.create_table(ModelWithEnumArray)
@@ -19,12 +17,14 @@ class EnumFieldsTest(unittest.TestCase):
        self.database.drop_database()

    def test_insert_and_select(self):
-        self.database.insert([
-            ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
-            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange),
-            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry)
-        ])
-        query = 'SELECT * from $table ORDER BY date_field'
+        self.database.insert(
+            [
+                ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
+                ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
+                ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
+            ]
+        )
+        query = "SELECT * from $table ORDER BY date_field"
        results = list(self.database.select(query, ModelWithEnum))
        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].enum_field, Fruit.apple)
@@ -32,12 +32,14 @@ class EnumFieldsTest(unittest.TestCase):
        self.assertEqual(results[2].enum_field, Fruit.cherry)

    def test_ad_hoc_model(self):
-        self.database.insert([
-            ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
-            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange),
-            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.cherry)
-        ])
-        query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
+        self.database.insert(
+            [
+                ModelWithEnum(date_field="2016-08-30", enum_field=Fruit.apple),
+                ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.orange),
+                ModelWithEnum(date_field="2016-08-31", enum_field=Fruit.cherry),
+            ]
+        )
+        query = "SELECT * from $db.modelwithenum ORDER BY date_field"
        results = list(self.database.select(query))
        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].enum_field.name, Fruit.apple.name)
@@ -50,11 +52,11 @@ class EnumFieldsTest(unittest.TestCase):
    def test_conversion(self):
        self.assertEqual(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
        self.assertEqual(ModelWithEnum(enum_field=-7).enum_field, Fruit.cherry)
-        self.assertEqual(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
+        self.assertEqual(ModelWithEnum(enum_field="apple").enum_field, Fruit.apple)
        self.assertEqual(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)

    def test_assignment_error(self):
-        for value in (0, 17, 'pear', '', None, 99.9):
+        for value in (0, 17, "pear", "", None, 99.9):
            with self.assertRaises(ValueError):
                ModelWithEnum(enum_field=value)
@@ -63,15 +65,15 @@ class EnumFieldsTest(unittest.TestCase):
        self.assertEqual(instance.enum_field, Fruit.apple)

    def test_enum_array(self):
-        instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
+        instance = ModelWithEnumArray(date_field="2016-08-30", enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
        self.database.insert([instance])
-        query = 'SELECT * from $table ORDER BY date_field'
+        query = "SELECT * from $table ORDER BY date_field"
        results = list(self.database.select(query, ModelWithEnumArray))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].enum_array, instance.enum_array)


-Fruit = Enum('Fruit', [('apple', 1), ('banana', 2), ('orange', 3), ('cherry', -7)])
+Fruit = Enum("Fruit", [("apple", 1), ("banana", 2), ("orange", 3), ("cherry", -7)])


class ModelWithEnum(Model):
@@ -79,7 +81,7 @@ class ModelWithEnum(Model):
    date_field = DateField()
    enum_field = Enum8Field(Fruit)

-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))


class ModelWithEnumArray(Model):
@@ -87,5 +89,4 @@ class ModelWithEnumArray(Model):
    date_field = DateField()
    enum_array = ArrayField(Enum16Field(Fruit))

-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))
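
The functional Enum declaration above is interchangeable with the class form; either can be handed to Enum8Field / Enum16Field. As a sketch (the alternative name is hypothetical):

class FruitAlt(Enum):
    apple = 1
    banana = 2
    orange = 3
    cherry = -7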
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
import unittest
from clickhouse_orm.database import Database
from clickhouse_orm.engines import MergeTree
from clickhouse_orm.fields import DateField, FixedStringField
from clickhouse_orm.models import Model
class FixedStringFieldsTest(unittest.TestCase):
def setUp(self):
self.database = Database("test-db", log_statements=True)
self.database.create_table(FixedStringModel)
def tearDown(self):
self.database.drop_database()
def _insert_sample_data(self):
self.database.insert(
[
FixedStringModel(date_field="2016-08-30", fstr_field=""),
FixedStringModel(date_field="2016-08-30"),
FixedStringModel(date_field="2016-08-31", fstr_field="foo"),
FixedStringModel(date_field="2016-08-31", fstr_field=u"לילה"),
]
)
def _assert_sample_data(self, results):
self.assertEqual(len(results), 4)
self.assertEqual(results[0].fstr_field, "")
self.assertEqual(results[1].fstr_field, "ABCDEFGHIJK")
self.assertEqual(results[2].fstr_field, "foo")
self.assertEqual(results[3].fstr_field, u"לילה")
def test_insert_and_select(self):
self._insert_sample_data()
query = "SELECT * from $table ORDER BY date_field"
results = list(self.database.select(query, FixedStringModel))
self._assert_sample_data(results)
def test_ad_hoc_model(self):
self._insert_sample_data()
query = "SELECT * from $db.fixedstringmodel ORDER BY date_field"
results = list(self.database.select(query))
self._assert_sample_data(results)
def test_assignment_error(self):
for value in (17, "this is too long", u"זה ארוך", None, 99.9):
with self.assertRaises(ValueError):
FixedStringModel(fstr_field=value)
class FixedStringModel(Model):
date_field = DateField()
fstr_field = FixedStringField(12, default="ABCDEFGHIJK")
engine = MergeTree("date_field", ("date_field",))
@@ -1,60 +1,59 @@
import unittest
from ipaddress import IPv4Address, IPv6Address

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.fields import Int16Field, IPv4Field, IPv6Field
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.engines import Memory
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import Memory
+from clickhouse_orm.fields import Int16Field, IPv4Field, IPv6Field
+from clickhouse_orm.models import Model


class IPFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()

    def test_ipv4_field(self):
        if self.database.server_version < (19, 17):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = IPv4Field()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same ip)
-        values = [
-            '1.2.3.4',
-            b'\x01\x02\x03\x04',
-            16909060,
-            IPv4Address('1.2.3.4')
-        ]
+        values = ["1.2.3.4", b"\x01\x02\x03\x04", 16909060, IPv4Address("1.2.3.4")]
        for index, value in enumerate(values):
            rec = TestModel(i=index, f=value)
            self.database.insert([rec])
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, IPv4Address(values[0]))
        # Check invalid values
-        for value in [None, 'zzz', -1, '123']:
+        for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)

    def test_ipv6_field(self):
        if self.database.server_version < (19, 17):
-            raise unittest.SkipTest('ClickHouse version too old')
+            raise unittest.SkipTest("ClickHouse version too old")
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = IPv6Field()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same ip)
        values = [
-            '2a02:e980:1e::1',
-            b'*\x02\xe9\x80\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01',
+            "2a02:e980:1e::1",
+            b"*\x02\xe9\x80\x00\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
            55842696359362256756849388082849382401,
-            IPv6Address('2a02:e980:1e::1')
+            IPv6Address("2a02:e980:1e::1"),
        ]
        for index, value in enumerate(values):
            rec = TestModel(i=index, f=value)
@@ -62,7 +61,6 @@ class IPFieldsTest(unittest.TestCase):
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, IPv6Address(values[0]))
        # Check invalid values
-        for value in [None, 'zzz', -1, '123']:
+        for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)
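
The four "same ip" representations in test_ipv4_field really are one address: 16909060 is 0x01020304, the packed big-endian bytes of 1.2.3.4. A one-line check:

from ipaddress import IPv4Address

assert int(IPv4Address("1.2.3.4")) == 16909060 == 0x01020304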
@@ -1,32 +1,29 @@
import unittest
from datetime import date

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
-from infi.clickhouse_orm.funcs import F
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import DateField, DateTimeField, Int32Field, StringField
+from clickhouse_orm.funcs import F
+from clickhouse_orm.models import NO_VALUE, Model


class MaterializedFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithMaterializedFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
-        instance = ModelWithMaterializedFields(
-            date_time_field='2016-08-30 11:00:00',
-            int_field=-10,
-            str_field='TEST'
-        )
+        instance = ModelWithMaterializedFields(date_time_field="2016-08-30 11:00:00", int_field=-10, str_field="TEST")
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
-        query = 'SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str, mat_func' \
-                ' FROM $db.%s ORDER BY mat_date' % ModelWithMaterializedFields.table_name()
+        query = (
+            "SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str, mat_func"
+            " FROM $db.%s ORDER BY mat_date" % ModelWithMaterializedFields.table_name()
+        )
        for model_cls in (ModelWithMaterializedFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
@@ -41,7 +38,7 @@ class MaterializedFieldsTest(unittest.TestCase):
    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithMaterializedFields()
-        for value in ('x', [date.today()], ['aaa'], [None]):
+        for value in ("x", [date.today()], ["aaa"], [None]):
            with self.assertRaises(ValueError):
                instance.mat_date = value
@@ -51,10 +48,10 @@ class MaterializedFieldsTest(unittest.TestCase):
    def test_duplicate_default(self):
        with self.assertRaises(AssertionError):
-            StringField(materialized='str_field', default='with default')
+            StringField(materialized="str_field", default="with default")
        with self.assertRaises(AssertionError):
-            StringField(materialized='str_field', alias='str_field')
+            StringField(materialized="str_field", alias="str_field")

    def test_default_value(self):
        instance = ModelWithMaterializedFields()
@@ -66,9 +63,9 @@ class ModelWithMaterializedFields(Model):
    date_time_field = DateTimeField()
    str_field = StringField()

-    mat_str = StringField(materialized='lower(str_field)')
-    mat_int = Int32Field(materialized='abs(int_field)')
-    mat_date = DateField(materialized=u'toDate(date_time_field)')
+    mat_str = StringField(materialized="lower(str_field)")
+    mat_int = Int32Field(materialized="abs(int_field)")
+    mat_date = DateField(materialized=u"toDate(date_time_field)")
    mat_func = StringField(materialized=F.lower(str_field))

-    engine = MergeTree('mat_date', ('mat_date',))
+    engine = MergeTree("mat_date", ("mat_date",))
@@ -1,19 +1,35 @@
import unittest
+from datetime import date, datetime

import pytz

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.engines import *
-from infi.clickhouse_orm.utils import comma_join
-
-from datetime import date, datetime
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import MergeTree
+from clickhouse_orm.fields import (
+    BaseFloatField,
+    BaseIntField,
+    DateField,
+    DateTimeField,
+    Float32Field,
+    Float64Field,
+    Int8Field,
+    Int16Field,
+    Int32Field,
+    Int64Field,
+    NullableField,
+    StringField,
+    UInt8Field,
+    UInt16Field,
+    UInt32Field,
+    UInt64Field,
+)
+from clickhouse_orm.models import Model
+from clickhouse_orm.utils import comma_join


class NullableFieldsTest(unittest.TestCase):
    def setUp(self):
-        self.database = Database('test-db', log_statements=True)
+        self.database = Database("test-db", log_statements=True)
        self.database.create_table(ModelWithNullable)

    def tearDown(self):
@@ -23,18 +39,20 @@ class NullableFieldsTest(unittest.TestCase):
        f = NullableField(DateTimeField())
        epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
        # Valid values
-        for value in (date(1970, 1, 1),
-                      datetime(1970, 1, 1),
-                      epoch,
-                      epoch.astimezone(pytz.timezone('US/Eastern')),
-                      epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
-                      '1970-01-01 00:00:00',
-                      '1970-01-17 00:00:17',
-                      '0000-00-00 00:00:00',
-                      0,
-                      '\\N'):
+        for value in (
+            date(1970, 1, 1),
+            datetime(1970, 1, 1),
+            epoch,
+            epoch.astimezone(pytz.timezone("US/Eastern")),
+            epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
+            "1970-01-01 00:00:00",
+            "1970-01-17 00:00:17",
+            "0000-00-00 00:00:00",
+            0,
+            "\\N",
+        ):
            dt = f.to_python(value, pytz.utc)
-            if value == '\\N':
+            if value == "\\N":
                self.assertIsNone(dt)
            else:
                self.assertTrue(dt.tzinfo)
@@ -42,32 +60,32 @@ class NullableFieldsTest(unittest.TestCase):
                dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
                self.assertEqual(dt, dt2)
        # Invalid values
-        for value in ('nope', '21/7/1999', 0.5):
+        for value in ("nope", "21/7/1999", 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_uint8_field(self):
        f = NullableField(UInt8Field())
        # Valid values
-        for value in (17, '17', 17.0, '\\N'):
+        for value in (17, "17", 17.0, "\\N"):
            python_value = f.to_python(value, pytz.utc)
-            if value == '\\N':
+            if value == "\\N":
                self.assertIsNone(python_value)
                self.assertEqual(value, f.to_db_string(python_value))
            else:
                self.assertEqual(python_value, 17)
        # Invalid values
-        for value in ('nope', date.today()):
+        for value in ("nope", date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_string_field(self):
        f = NullableField(StringField())
        # Valid values
-        for value in ('\\\\N', 'N', 'some text', '\\N'):
+        for value in ("\\\\N", "N", "some text", "\\N"):
            python_value = f.to_python(value, pytz.utc)
-            if value == '\\N':
+            if value == "\\N":
                self.assertIsNone(python_value)
                self.assertEqual(value, f.to_db_string(python_value))
            else:
@@ -78,7 +96,16 @@ class NullableFieldsTest(unittest.TestCase):
            f = NullableField(field())
            self.assertTrue(f.isinstance(field))
            self.assertTrue(f.isinstance(NullableField))
-        for field in (Int8Field, Int16Field, Int32Field, Int64Field, UInt8Field, UInt16Field, UInt32Field, UInt64Field):
+        for field in (
+            Int8Field,
+            Int16Field,
+            Int32Field,
+            Int64Field,
+            UInt8Field,
+            UInt16Field,
+            UInt32Field,
+            UInt64Field,
+        ):
            f = NullableField(field())
            self.assertTrue(f.isinstance(BaseIntField))
        for field in (Float32Field, Float64Field):
@@ -91,12 +118,25 @@ class NullableFieldsTest(unittest.TestCase):

    def _insert_sample_data(self):
        dt = date(1970, 1, 1)
-        self.database.insert([
-            ModelWithNullable(date_field='2016-08-30', null_str='', null_int=42, null_date=dt),
-            ModelWithNullable(date_field='2016-08-30', null_str='nothing', null_int=None, null_date=None),
-            ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=42, null_date=dt),
-            ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=None, null_date=None, null_default=None)
-        ])
+        self.database.insert(
+            [
+                ModelWithNullable(date_field="2016-08-30", null_str="", null_int=42, null_date=dt),
+                ModelWithNullable(
+                    date_field="2016-08-30",
+                    null_str="nothing",
+                    null_int=None,
+                    null_date=None,
+                ),
+                ModelWithNullable(date_field="2016-08-31", null_str=None, null_int=42, null_date=dt),
+                ModelWithNullable(
+                    date_field="2016-08-31",
+                    null_str=None,
+                    null_int=None,
+                    null_date=None,
+                    null_default=None,
+                ),
+            ]
+        )

    def _assert_sample_data(self, results):
        for r in results:
@@ -110,7 +150,7 @@ class NullableFieldsTest(unittest.TestCase):
        self.assertEqual(results[0].null_materialized, 420)
        self.assertEqual(results[0].null_date, dt)
        self.assertIsNone(results[1].null_date)
-        self.assertEqual(results[1].null_str, 'nothing')
+        self.assertEqual(results[1].null_str, "nothing")
        self.assertIsNone(results[1].null_date)
        self.assertIsNone(results[2].null_str)
        self.assertEqual(results[2].null_date, dt)
@@ -128,14 +168,14 @@ class NullableFieldsTest(unittest.TestCase):
    def test_insert_and_select(self):
        self._insert_sample_data()
        fields = comma_join(ModelWithNullable.fields().keys())
-        query = 'SELECT %s from $table ORDER BY date_field' % fields
+        query = "SELECT %s from $table ORDER BY date_field" % fields
        results = list(self.database.select(query, ModelWithNullable))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        fields = comma_join(ModelWithNullable.fields().keys())
-        query = 'SELECT %s from $db.modelwithnullable ORDER BY date_field' % fields
+        query = "SELECT %s from $db.modelwithnullable ORDER BY date_field" % fields
        results = list(self.database.select(query))
        self._assert_sample_data(results)
@@ -143,11 +183,11 @@ class NullableFieldsTest(unittest.TestCase):

class ModelWithNullable(Model):
    date_field = DateField()
-    null_str = NullableField(StringField(), extra_null_values={''})
+    null_str = NullableField(StringField(), extra_null_values={""})
    null_int = NullableField(Int32Field())
    null_date = NullableField(DateField())
    null_default = NullableField(Int32Field(), default=7)
-    null_alias = NullableField(Int32Field(), alias='null_int/2')
-    null_materialized = NullableField(Int32Field(), alias='null_int*10')
+    null_alias = NullableField(Int32Field(), alias="null_int/2")
+    null_materialized = NullableField(Int32Field(), alias="null_int*10")

-    engine = MergeTree('date_field', ('date_field',))
+    engine = MergeTree("date_field", ("date_field",))
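
The recurring "\\N" literal in these tests is the TabSeparated escape ClickHouse uses for NULL, so the Nullable fields translate it to None on read and back to "\\N" on write. A round-trip sketch built from the imports above:

f = NullableField(Int32Field())
assert f.to_python("\\N", pytz.utc) is None
assert f.to_db_string(None) == "\\N"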
@ -1,19 +1,30 @@
import unittest import unittest
from infi.clickhouse_orm.fields import *
from datetime import date, datetime from datetime import date, datetime
import pytz import pytz
from clickhouse_orm.fields import DateField, DateTime64Field, DateTimeField, UInt8Field
class SimpleFieldsTest(unittest.TestCase): class SimpleFieldsTest(unittest.TestCase):
epoch = datetime(1970, 1, 1, tzinfo=pytz.utc) epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
# Valid values # Valid values
dates = [ dates = [
date(1970, 1, 1), datetime(1970, 1, 1), epoch, date(1970, 1, 1),
epoch.astimezone(pytz.timezone('US/Eastern')), epoch.astimezone(pytz.timezone('Asia/Jerusalem')), datetime(1970, 1, 1),
'1970-01-01 00:00:00', '1970-01-17 00:00:17', '0000-00-00 00:00:00', 0, epoch,
'2017-07-26T08:31:05', '2017-07-26T08:31:05Z', '2017-07-26 08:31', epoch.astimezone(pytz.timezone("US/Eastern")),
'2017-07-26T13:31:05+05', '2017-07-26 13:31:05+0500' epoch.astimezone(pytz.timezone("Asia/Jerusalem")),
"1970-01-01 00:00:00",
"1970-01-17 00:00:17",
"0000-00-00 00:00:00",
0,
"2017-07-26T08:31:05",
"2017-07-26T08:31:05Z",
"2017-07-26 08:31",
"2017-07-26T13:31:05+05",
"2017-07-26 13:31:05+0500",
] ]
def test_datetime_field(self): def test_datetime_field(self):
@ -25,8 +36,14 @@ class SimpleFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc) dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2) self.assertEqual(dt, dt2)
# Invalid values # Invalid values
for value in ('nope', '21/7/1999', 0.5, for value in (
'2017-01 15:06:00', '2017-01-01X15:06:00', '2017-13-01T15:06:00'): "nope",
"21/7/1999",
0.5,
"2017-01 15:06:00",
"2017-01-01X15:06:00",
"2017-13-01T15:06:00",
):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f.to_python(value, pytz.utc) f.to_python(value, pytz.utc)
@ -35,10 +52,16 @@ class SimpleFieldsTest(unittest.TestCase):
# Valid values # Valid values
for value in self.dates + [ for value in self.dates + [
datetime(1970, 1, 1, microsecond=100000), datetime(1970, 1, 1, microsecond=100000),
pytz.timezone('US/Eastern').localize(datetime(1970, 1, 1, microsecond=100000)), pytz.timezone("US/Eastern").localize(datetime(1970, 1, 1, microsecond=100000)),
'1970-01-01 00:00:00.1', '1970-01-17 00:00:17.1', '0000-00-00 00:00:00.1', 0.1, "1970-01-01 00:00:00.1",
'2017-07-26T08:31:05.1', '2017-07-26T08:31:05.1Z', '2017-07-26 08:31.1', "1970-01-17 00:00:17.1",
'2017-07-26T13:31:05.1+05', '2017-07-26 13:31:05.1+0500' "0000-00-00 00:00:00.1",
0.1,
"2017-07-26T08:31:05.1",
"2017-07-26T08:31:05.1Z",
"2017-07-26 08:31.1",
"2017-07-26T13:31:05.1+05",
"2017-07-26 13:31:05.1+0500",
]: ]:
dt = f.to_python(value, pytz.utc) dt = f.to_python(value, pytz.utc)
self.assertTrue(dt.tzinfo) self.assertTrue(dt.tzinfo)
@ -46,8 +69,13 @@ class SimpleFieldsTest(unittest.TestCase):
dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc) dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
self.assertEqual(dt, dt2) self.assertEqual(dt, dt2)
# Invalid values # Invalid values
for value in ('nope', '21/7/1999', for value in (
'2017-01 15:06:00', '2017-01-01X15:06:00', '2017-13-01T15:06:00'): "nope",
"21/7/1999",
"2017-01 15:06:00",
"2017-01-01X15:06:00",
"2017-13-01T15:06:00",
):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f.to_python(value, pytz.utc) f.to_python(value, pytz.utc)
@ -56,21 +84,21 @@ class SimpleFieldsTest(unittest.TestCase):
            f = DateTime64Field(precision=precision, timezone=pytz.utc)
            dt = f.to_python(datetime(2000, 1, 1, microsecond=123456), pytz.utc)
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            m = round(123456, precision - 6)  # round rightmost microsecond digits according to precision
            self.assertEqual(dt2, dt.replace(microsecond=m))

    def test_date_field(self):
        f = DateField()
        epoch = date(1970, 1, 1)
        # Valid values
-       for value in (datetime(1970, 1, 1), epoch, '1970-01-01', '0000-00-00', 0):
+       for value in (datetime(1970, 1, 1), epoch, "1970-01-01", "0000-00-00", 0):
            d = f.to_python(value, pytz.utc)
            self.assertEqual(d, epoch)
            # Verify that conversion to and from db string does not change value
            d2 = f.to_python(f.to_db_string(d, quote=False), pytz.utc)
            self.assertEqual(d, d2)
        # Invalid values
-       for value in ('nope', '21/7/1999', 0.5):
+       for value in ("nope", "21/7/1999", 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check
@ -81,29 +109,29 @@ class SimpleFieldsTest(unittest.TestCase):
    def test_date_field_timezone(self):
        # Verify that conversion of timezone-aware datetime is correct
        f = DateField()
-       dt = datetime(2017, 10, 5, tzinfo=pytz.timezone('Asia/Jerusalem'))
+       dt = datetime(2017, 10, 5, tzinfo=pytz.timezone("Asia/Jerusalem"))
        self.assertEqual(f.to_python(dt, pytz.utc), date(2017, 10, 4))

    def test_datetime_field_timezone(self):
        # Verify that conversion of timezone-aware datetime is correct
        f = DateTimeField()
        utc_value = datetime(2017, 7, 26, 8, 31, 5, tzinfo=pytz.UTC)
        for value in (
-           '2017-07-26T08:31:05',
-           '2017-07-26T08:31:05Z',
-           '2017-07-26T11:31:05+03',
-           '2017-07-26 11:31:05+0300',
-           '2017-07-26T03:31:05-0500',
+           "2017-07-26T08:31:05",
+           "2017-07-26T08:31:05Z",
+           "2017-07-26T11:31:05+03",
+           "2017-07-26 11:31:05+0300",
+           "2017-07-26T03:31:05-0500",
        ):
            self.assertEqual(f.to_python(value, pytz.utc), utc_value)

    def test_uint8_field(self):
        f = UInt8Field()
        # Valid values
-       for value in (17, '17', 17.0):
+       for value in (17, "17", 17.0):
            self.assertEqual(f.to_python(value, pytz.utc), 17)
        # Invalid values
-       for value in ('nope', date.today()):
+       for value in ("nope", date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check
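Note: the round-trip invariant these assertions rely on (parse, serialize, parse again) can be exercised standalone. A minimal sketch, assuming the renamed clickhouse_orm package and pytz are importable:

import pytz
from clickhouse_orm.fields import DateTimeField

f = DateTimeField()
dt = f.to_python("2017-07-26T13:31:05+05", pytz.utc)  # string with an explicit UTC offset
assert dt.tzinfo is not None  # to_python always returns an aware datetime
# Serializing and re-parsing must not change the value
assert f.to_python(f.to_db_string(dt, quote=False), pytz.utc) == dt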
View File
@ -1,35 +1,37 @@
import unittest
from uuid import UUID

-from infi.clickhouse_orm.database import Database
-from infi.clickhouse_orm.fields import Int16Field, UUIDField
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.engines import Memory
+from clickhouse_orm.database import Database
+from clickhouse_orm.engines import Memory
+from clickhouse_orm.fields import Int16Field, UUIDField
+from clickhouse_orm.models import Model


class UUIDFieldsTest(unittest.TestCase):

    def setUp(self):
-       self.database = Database('test-db', log_statements=True)
+       self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()

    def test_uuid_field(self):
        if self.database.server_version < (18, 1):
-           raise unittest.SkipTest('ClickHouse version too old')
+           raise unittest.SkipTest("ClickHouse version too old")

        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = UUIDField()
            engine = Memory()

        self.database.create_table(TestModel)
        # Check valid values (all values are the same UUID)
        values = [
-           '12345678-1234-5678-1234-567812345678',
-           '{12345678-1234-5678-1234-567812345678}',
-           '12345678123456781234567812345678',
-           'urn:uuid:12345678-1234-5678-1234-567812345678',
-           b'\x12\x34\x56\x78'*4,
+           "12345678-1234-5678-1234-567812345678",
+           "{12345678-1234-5678-1234-567812345678}",
+           "12345678123456781234567812345678",
+           "urn:uuid:12345678-1234-5678-1234-567812345678",
+           b"\x12\x34\x56\x78" * 4,
            (0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678),
            0x12345678123456781234567812345678,
            UUID(int=0x12345678123456781234567812345678),
@ -40,7 +42,6 @@ class UUIDFieldsTest(unittest.TestCase):
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, UUID(values[0]))
        # Check invalid values
-       for value in [None, 'zzz', -1, '123']:
+       for value in [None, "zzz", -1, "123"]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)
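Note: every entry in values denotes the same UUID, which is why each stored row can be compared against UUID(values[0]). A stdlib-only illustration of that equivalence:

from uuid import UUID

canonical = UUID("12345678-1234-5678-1234-567812345678")
assert UUID("12345678123456781234567812345678") == canonical  # no dashes
assert UUID("urn:uuid:12345678-1234-5678-1234-567812345678") == canonical  # URN form
assert UUID(bytes=b"\x12\x34\x56\x78" * 4) == canonical  # raw 16 bytes
assert UUID(int=0x12345678123456781234567812345678) == canonical  # as an integer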
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(Model1)
-]
+operations = [migrations.CreateTable(Model1)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.DropTable(Model1)
-]
+operations = [migrations.DropTable(Model1)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(Model1)
-]
+operations = [migrations.CreateTable(Model1)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTable(Model2)
-]
+operations = [migrations.AlterTable(Model2)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTable(Model3)
-]
+operations = [migrations.AlterTable(Model3)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(EnumModel1)
-]
+operations = [migrations.CreateTable(EnumModel1)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTable(EnumModel2)
-]
+operations = [migrations.AlterTable(EnumModel2)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(MaterializedModel)
-]
+operations = [migrations.CreateTable(MaterializedModel)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(AliasModel)
-]
+operations = [migrations.CreateTable(AliasModel)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(Model4Buffer)
-]
+operations = [migrations.CreateTable(Model4Buffer)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTableWithBuffer(Model4Buffer_changed)
-]
+operations = [migrations.AlterTableWithBuffer(Model4Buffer_changed)]
View File
@ -1,9 +1,11 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

operations = [
    migrations.RunSQL("INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-01', 1, 1, 'test') "),
-   migrations.RunSQL([
-       "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-02', 2, 2, 'test2') ",
-       "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-03', 3, 3, 'test3') ",
-   ])
+   migrations.RunSQL(
+       [
+           "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-02', 2, 2, 'test2') ",
+           "INSERT INTO `mig` (date, f1, f3, f4) VALUES ('2016-01-03', 3, 3, 'test3') ",
+       ]
+   ),
]
View File
@ -1,15 +1,12 @@
import datetime

-from infi.clickhouse_orm import migrations
from test_migrations import Model3

+from clickhouse_orm import migrations


def forward(database):
-   database.insert([
-       Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4='test4')
-   ])
+   database.insert([Model3(date=datetime.date(2016, 1, 4), f1=4, f3=1, f4="test4")])


-operations = [
-    migrations.RunPython(forward)
-]
+operations = [migrations.RunPython(forward)]
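Note: these operation modules are not executed directly; the ORM's migration runner imports them in order and applies their operations lists. A sketch of how such a package is applied, assuming the Database.migrate API and an illustrative package path:

from clickhouse_orm.database import Database

db = Database("test-db")
# Imports each numbered migration module in the package, applies its
# `operations`, and records the applied migrations on the server
db.migrate("tests.sample_migrations")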
View File
@ -1,7 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTable(MaterializedModel1),
-    migrations.AlterTable(AliasModel1)
-]
+operations = [migrations.AlterTable(MaterializedModel1), migrations.AlterTable(AliasModel1)]
View File
@ -1,7 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterTable(Model4_compressed),
-    migrations.AlterTable(Model2LowCardinality)
-]
+operations = [migrations.AlterTable(Model4_compressed), migrations.AlterTable(Model2LowCardinality)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(ModelWithConstraints)
-]
+operations = [migrations.CreateTable(ModelWithConstraints)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterConstraints(ModelWithConstraints2)
-]
+operations = [migrations.AlterConstraints(ModelWithConstraints2)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.CreateTable(ModelWithIndex)
-]
+operations = [migrations.CreateTable(ModelWithIndex)]
View File
@ -1,6 +1,5 @@
-from infi.clickhouse_orm import migrations
+from clickhouse_orm import migrations

from ..test_migrations import *

-operations = [
-    migrations.AlterIndexes(ModelWithIndex2, reindex=True)
-]
+operations = [migrations.AlterIndexes(ModelWithIndex2, reindex=True)]
View File
@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-
-import unittest
-
-from infi.clickhouse_orm.models import BufferModel
-from infi.clickhouse_orm.engines import *
-from .base_test_with_data import *
+from clickhouse_orm.engines import Buffer
+from clickhouse_orm.models import BufferModel
+from .base_test_with_data import Person, TestCaseWithData, data


class BufferTestCase(TestCaseWithData):

    def _insert_and_check_buffer(self, data, count):
        self.database.insert(data)
        self.assertEqual(count, self.database.count(PersonBuffer))
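Note: PersonBuffer itself is defined outside this hunk. A plausible definition, following the ORM's documented BufferModel pattern (the bare Buffer(Person) engine call below is an illustrative assumption relying on the engine's defaults):

class PersonBuffer(BufferModel, Person):
    # A Buffer table accumulates inserts in memory and flushes them
    # to the underlying `person` table in the background
    engine = Buffer(Person)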
View File
@ -1,123 +0,0 @@
import unittest
import datetime
import pytz

from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model, NO_VALUE
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *
from infi.clickhouse_orm.utils import parse_tsv


class CompressedFieldsTestCase(unittest.TestCase):

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(CompressedModel)

    def tearDown(self):
        self.database.drop_database()

    def test_defaults(self):
        # Check that all fields have their explicit or implicit defaults
        instance = CompressedModel()
        self.database.insert([instance])
        self.assertEqual(instance.date_field, datetime.date(1970, 1, 1))
        self.assertEqual(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
        self.assertEqual(instance.string_field, 'dozo')
        self.assertEqual(instance.int64_field, 42)
        self.assertEqual(instance.float_field, 0)
        self.assertEqual(instance.nullable_field, None)
        self.assertEqual(instance.array_field, [])

    def test_assignment(self):
        # Check that all fields are assigned during construction
        kwargs = dict(
            uint64_field=217,
            date_field=datetime.date(1973, 12, 6),
            datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
            string_field='aloha',
            int64_field=-50,
            float_field=3.14,
            nullable_field=-2.718281,
            array_field=['123456789123456', '', 'a']
        )
        instance = CompressedModel(**kwargs)
        self.database.insert([instance])
        for name, value in kwargs.items():
            self.assertEqual(kwargs[name], getattr(instance, name))

    def test_string_conversion(self):
        # Check field conversion from string during construction
        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', nullable_field=None, array_field='[a,b,c]')
        self.assertEqual(instance.date_field, datetime.date(1973, 12, 6))
        self.assertEqual(instance.int64_field, 100)
        self.assertEqual(instance.float_field, 7)
        self.assertEqual(instance.nullable_field, None)
        self.assertEqual(instance.array_field, ['a', 'b', 'c'])
        # Check field conversion from string during assignment
        instance.int64_field = '99'
        self.assertEqual(instance.int64_field, 99)

    def test_to_dict(self):
        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
        self.assertDictEqual(instance.to_dict(), {
            "date_field": datetime.date(1973, 12, 6),
            "int64_field": 100,
            "float_field": 7.0,
            "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
            "alias_field": NO_VALUE,
            'string_field': 'dozo',
            'nullable_field': None,
            'uint64_field': 0,
            'array_field': ['a', 'b', 'c']
        })
        self.assertDictEqual(instance.to_dict(include_readonly=False), {
            "date_field": datetime.date(1973, 12, 6),
            "int64_field": 100,
            "float_field": 7.0,
            "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
            'string_field': 'dozo',
            'nullable_field': None,
            'uint64_field': 0,
            'array_field': ['a', 'b', 'c']
        })
        self.assertDictEqual(
            instance.to_dict(include_readonly=False, field_names=('int64_field', 'alias_field', 'datetime_field')), {
                "int64_field": 100,
                "datetime_field": datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
            })

    def test_confirm_compression_codec(self):
        if self.database.server_version < (19, 17):
            raise unittest.SkipTest('ClickHouse version too old')
        instance = CompressedModel(date_field='1973-12-06', int64_field='100', float_field='7', array_field='[a,b,c]')
        self.database.insert([instance])
        r = self.database.raw("select name, compression_codec from system.columns where table = '{}' and database='{}' FORMAT TabSeparatedWithNamesAndTypes".format(instance.table_name(), self.database.db_name))
        lines = r.splitlines()
        field_names = parse_tsv(lines[0])
        field_types = parse_tsv(lines[1])
        data = [tuple(parse_tsv(line)) for line in lines[2:]]
        self.assertListEqual(data, [('uint64_field', 'CODEC(ZSTD(10))'),
                                    ('datetime_field', 'CODEC(Delta(4), ZSTD(1))'),
                                    ('date_field', 'CODEC(Delta(4), ZSTD(22))'),
                                    ('int64_field', 'CODEC(LZ4)'),
                                    ('string_field', 'CODEC(LZ4HC(10))'),
                                    ('nullable_field', 'CODEC(ZSTD(1))'),
                                    ('array_field', 'CODEC(Delta(2), LZ4HC(0))'),
                                    ('float_field', 'CODEC(NONE)'),
                                    ('alias_field', 'CODEC(ZSTD(4))')])


class CompressedModel(Model):
    uint64_field = UInt64Field(codec='ZSTD(10)')
    datetime_field = DateTimeField(codec='Delta,ZSTD')
    date_field = DateField(codec='Delta(4),ZSTD(22)')
    int64_field = Int64Field(default=42, codec='LZ4')
    string_field = StringField(default='dozo', codec='LZ4HC(10)')
    nullable_field = NullableField(Float32Field(), codec='ZSTD')
    array_field = ArrayField(FixedStringField(length=15), codec='Delta(2),LZ4HC')
    float_field = Float32Field(codec='NONE')
    alias_field = Float32Field(alias='float_field', codec='ZSTD(4)')

    engine = MergeTree('datetime_field', ('uint64_field', 'datetime_field'))
View File
@ -1,44 +1,63 @@
import unittest

-from infi.clickhouse_orm import *
+from clickhouse_orm import Constraint, Database, F, ServerError

from .base_test_with_data import Person


class ConstraintsTest(unittest.TestCase):

    def setUp(self):
-       self.database = Database('test-db', log_statements=True)
+       self.database = Database("test-db", log_statements=True)
        if self.database.server_version < (19, 14, 3, 3):
-           raise unittest.SkipTest('ClickHouse version too old')
+           raise unittest.SkipTest("ClickHouse version too old")
        self.database.create_table(PersonWithConstraints)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_valid_values(self):
-       self.database.insert([
-           PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2000-01-01", height=1.66)
-       ])
+       self.database.insert(
+           [
+               PersonWithConstraints(
+                   first_name="Mike",
+                   last_name="Caruzo",
+                   birthday="2000-01-01",
+                   height=1.66,
+               )
+           ]
+       )

    def test_insert_invalid_values(self):
        with self.assertRaises(ServerError) as e:
-           self.database.insert([
-               PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="2100-01-01", height=1.66)
-           ])
+           self.database.insert(
+               [
+                   PersonWithConstraints(
+                       first_name="Mike",
+                       last_name="Caruzo",
+                       birthday="2100-01-01",
+                       height=1.66,
+                   )
+               ]
+           )
        self.assertEqual(e.code, 469)
-       self.assertTrue('Constraint `birthday_in_the_past`' in e.message)
+       self.assertTrue("Constraint `birthday_in_the_past`" in str(e))

        with self.assertRaises(ServerError) as e:
-           self.database.insert([
-               PersonWithConstraints(first_name="Mike", last_name="Caruzo", birthday="1970-01-01", height=3)
-           ])
+           self.database.insert(
+               [
+                   PersonWithConstraints(
+                       first_name="Mike",
+                       last_name="Caruzo",
+                       birthday="1970-01-01",
+                       height=3,
+                   )
+               ]
+           )
        self.assertEqual(e.code, 469)
-       self.assertTrue('Constraint `max_height`' in e.message)
+       self.assertTrue("Constraint `max_height`" in str(e))


class PersonWithConstraints(Person):
    birthday_in_the_past = Constraint(Person.birthday <= F.today())
    max_height = Constraint(Person.height <= 2.75)
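Note: the two Constraint attributes above are emitted as CONSTRAINT ... CHECK clauses in the generated table DDL. One way to inspect this, assuming the create_table_sql classmethod and a connected Database instance named db:

# Expected to contain, roughly:
#   CONSTRAINT `birthday_in_the_past` CHECK birthday <= today()
#   CONSTRAINT `max_height` CHECK height <= 2.75
print(PersonWithConstraints.create_table_sql(db))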
View File
@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
-import unittest
import datetime
+import unittest

-from infi.clickhouse_orm.database import ServerError, DatabaseException
-from infi.clickhouse_orm.models import Model
-from infi.clickhouse_orm.engines import Memory
-from infi.clickhouse_orm.fields import *
-from infi.clickhouse_orm.funcs import F
-from infi.clickhouse_orm.query import Q
-from .base_test_with_data import *
+from clickhouse_orm.database import Database, DatabaseException, ServerError
+from clickhouse_orm.engines import Memory
+from clickhouse_orm.fields import DateField, DateTimeField, Float32Field, Int32Field, StringField
+from clickhouse_orm.funcs import F
+from clickhouse_orm.models import Model
+from clickhouse_orm.query import Q
+from .base_test_with_data import Person, TestCaseWithData, data


class DatabaseTestCase(TestCaseWithData):

    def test_insert__generator(self):
        self._insert_and_check(self._sample_data(), len(data))
@ -33,17 +33,19 @@ class DatabaseTestCase(TestCaseWithData):
    def test_insert__funcs_as_default_values(self):
        if self.database.server_version < (20, 1, 2, 4):
-           raise unittest.SkipTest('Buggy in server versions before 20.1.2.4')
+           raise unittest.SkipTest("Buggy in server versions before 20.1.2.4")

        class TestModel(Model):
            a = DateTimeField(default=datetime.datetime(2020, 1, 1))
            b = DateField(default=F.toDate(a))
            c = Int32Field(default=7)
            d = Int32Field(default=c * 5)
            engine = Memory()

        self.database.create_table(TestModel)
        self.database.insert([TestModel()])
        t = TestModel.objects_in(self.database)[0]
-       self.assertEqual(str(t.b), '2020-01-01')
+       self.assertEqual(str(t.b), "2020-01-01")
        self.assertEqual(t.d, 35)

    def test_count(self):
@ -63,9 +65,9 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person)) results = list(self.database.select(query, Person))
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0].last_name, 'Durham') self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 1.72) self.assertEqual(results[0].height, 1.72)
self.assertEqual(results[1].last_name, 'Scott') self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 1.70) self.assertEqual(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database) self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database) self.assertEqual(results[1].get_database(), self.database)
@ -79,10 +81,10 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query, Person)) results = list(self.database.select(query, Person))
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0].last_name, 'Durham') self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 0) # default value self.assertEqual(results[0].height, 0) # default value
self.assertEqual(results[1].last_name, 'Scott') self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 0) # default value self.assertEqual(results[1].height, 0) # default value
self.assertEqual(results[0].get_database(), self.database) self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database) self.assertEqual(results[1].get_database(), self.database)
@ -91,10 +93,10 @@ class DatabaseTestCase(TestCaseWithData):
query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name" query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
results = list(self.database.select(query)) results = list(self.database.select(query))
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0].__class__.__name__, 'AdHocModel') self.assertEqual(results[0].__class__.__name__, "AdHocModel")
self.assertEqual(results[0].last_name, 'Durham') self.assertEqual(results[0].last_name, "Durham")
self.assertEqual(results[0].height, 1.72) self.assertEqual(results[0].height, 1.72)
self.assertEqual(results[1].last_name, 'Scott') self.assertEqual(results[1].last_name, "Scott")
self.assertEqual(results[1].height, 1.70) self.assertEqual(results[1].height, 1.70)
self.assertEqual(results[0].get_database(), self.database) self.assertEqual(results[0].get_database(), self.database)
self.assertEqual(results[1].get_database(), self.database) self.assertEqual(results[1].get_database(), self.database)
@ -116,7 +118,7 @@ class DatabaseTestCase(TestCaseWithData):
        page_num = 1
        instances = set()
        while True:
-           page = self.database.paginate(Person, 'first_name, last_name', page_num, page_size)
+           page = self.database.paginate(Person, "first_name, last_name", page_num, page_size)
            self.assertEqual(page.number_of_objects, len(data))
            self.assertGreater(page.pages_total, 0)
            [instances.add(obj.to_tsv()) for obj in page.objects]
@ -131,15 +133,23 @@ class DatabaseTestCase(TestCaseWithData):
# Try different page sizes
        for page_size in (1, 2, 7, 10, 30, 100, 150):
            # Ask for the last page in two different ways and verify equality
-           page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
-           page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
+           page_a = self.database.paginate(Person, "first_name, last_name", -1, page_size)
+           page_b = self.database.paginate(Person, "first_name, last_name", page_a.pages_total, page_size)
            self.assertEqual(page_a[1:], page_b[1:])
-           self.assertEqual([obj.to_tsv() for obj in page_a.objects],
-                            [obj.to_tsv() for obj in page_b.objects])
+           self.assertEqual(
+               [obj.to_tsv() for obj in page_a.objects],
+               [obj.to_tsv() for obj in page_b.objects],
+           )

    def test_pagination_empty_page(self):
        for page_num in (-1, 1, 2):
-           page = self.database.paginate(Person, 'first_name, last_name', page_num, 10, conditions="first_name = 'Ziggy'")
+           page = self.database.paginate(
+               Person,
+               "first_name, last_name",
+               page_num,
+               10,
+               conditions="first_name = 'Ziggy'",
+           )
            self.assertEqual(page.number_of_objects, 0)
            self.assertEqual(page.objects, [])
            self.assertEqual(page.pages_total, 0)
@ -149,22 +159,28 @@ class DatabaseTestCase(TestCaseWithData):
        self._insert_and_check(self._sample_data(), len(data))
        for page_num in (0, -2, -100):
            with self.assertRaises(ValueError):
-               self.database.paginate(Person, 'first_name, last_name', page_num, 100)
+               self.database.paginate(Person, "first_name, last_name", page_num, 100)

    def test_pagination_with_conditions(self):
        self._insert_and_check(self._sample_data(), len(data))
        # Conditions as string
-       page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions="first_name < 'Ava'")
+       page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions="first_name < 'Ava'")
        self.assertEqual(page.number_of_objects, 10)
        # Conditions as expression
-       page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Person.first_name < 'Ava')
+       page = self.database.paginate(
+           Person,
+           "first_name, last_name",
+           1,
+           100,
+           conditions=Person.first_name < "Ava",
+       )
        self.assertEqual(page.number_of_objects, 10)
        # Conditions as Q object
-       page = self.database.paginate(Person, 'first_name, last_name', 1, 100, conditions=Q(first_name__lt='Ava'))
+       page = self.database.paginate(Person, "first_name, last_name", 1, 100, conditions=Q(first_name__lt="Ava"))
        self.assertEqual(page.number_of_objects, 10)

    def test_special_chars(self):
-       s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
+       s = u"אבגד \\'\"`,.;éåäöšž\n\t\0\b\r"
        p = Person(first_name=s)
        self.database.insert([p])
        p = list(self.database.select("SELECT * from $table", Person))[0]
@ -174,29 +190,32 @@ class DatabaseTestCase(TestCaseWithData):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = self.database.raw(query)
-       self.assertEqual(results, "Whitney\tDurham\t1977-09-15\t1.72\t\\N\nWhitney\tScott\t1971-07-04\t1.7\t\\N\n")
+       self.assertEqual(
+           results,
+           "Whitney\tDurham\t1977-09-15\t1.72\t\\N\nWhitney\tScott\t1971-07-04\t1.7\t\\N\n",
+       )

    def test_invalid_user(self):
        with self.assertRaises(ServerError) as cm:
-           Database(self.database.db_name, username='default', password='wrong')
+           Database(self.database.db_name, username="default", password="wrong")
        exc = cm.exception
        if exc.code == 193:  # ClickHouse version < 20.3
-           self.assertTrue(exc.message.startswith('Wrong password for user default'))
+           self.assertTrue(exc.message.startswith("Wrong password for user default"))
        elif exc.code == 516:  # ClickHouse version >= 20.3
-           self.assertTrue(exc.message.startswith('default: Authentication failed'))
+           self.assertTrue(exc.message.startswith("default: Authentication failed"))
        else:
-           raise Exception('Unexpected error code - %s' % exc.code)
+           raise Exception("Unexpected error code - %s" % exc.code)

    def test_nonexisting_db(self):
-       db = Database('db_not_here', autocreate=False)
+       db = Database("db_not_here", autocreate=False)
        with self.assertRaises(ServerError) as cm:
            db.create_table(Person)
        exc = cm.exception
        self.assertEqual(exc.code, 81)
        self.assertTrue(exc.message.startswith("Database db_not_here doesn't exist"))
        # Create and delete the db twice, to ensure db_exists gets updated
-       for i in range(2):
+       for _ in range(2):
            # Now create the database - should succeed
            db.create_database()
            self.assertTrue(db.db_exists)
@ -212,25 +231,28 @@ class DatabaseTestCase(TestCaseWithData):
    def test_missing_engine(self):
        class EnginelessModel(Model):
            float_field = Float32Field()

        with self.assertRaises(DatabaseException) as cm:
            self.database.create_table(EnginelessModel)
-       self.assertEqual(str(cm.exception), 'EnginelessModel class must define an engine')
+       self.assertEqual(str(cm.exception), "EnginelessModel class must define an engine")

    def test_potentially_problematic_field_names(self):
        class Model1(Model):
            system = StringField()
            readonly = StringField()
            engine = Memory()

-       instance = Model1(system='s', readonly='r')
-       self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
+       instance = Model1(system="s", readonly="r")
+       self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))
        self.database.create_table(Model1)
        self.database.insert([instance])
        instance = Model1.objects_in(self.database)[0]
-       self.assertEqual(instance.to_dict(), dict(system='s', readonly='r'))
+       self.assertEqual(instance.to_dict(), dict(system="s", readonly="r"))

    def test_does_table_exist(self):
        class Person2(Person):
            pass

        self.assertTrue(self.database.does_table_exist(Person))
        self.assertFalse(self.database.does_table_exist(Person2))
@ -239,32 +261,31 @@ class DatabaseTestCase(TestCaseWithData):
        with self.assertRaises(AssertionError):
            self.database.add_setting(0, 1)
        # Add a setting and see that it makes the query fail
-       self.database.add_setting('max_columns_to_read', 1)
+       self.database.add_setting("max_columns_to_read", 1)
        with self.assertRaises(ServerError):
-           list(self.database.select('SELECT * from system.tables'))
+           list(self.database.select("SELECT * from system.tables"))
        # Remove the setting and see that now it works
-       self.database.add_setting('max_columns_to_read', None)
-       list(self.database.select('SELECT * from system.tables'))
+       self.database.add_setting("max_columns_to_read", None)
+       list(self.database.select("SELECT * from system.tables"))

    def test_create_ad_hoc_field(self):
        # Tests that create_ad_hoc_field works for all column types in the database
-       from infi.clickhouse_orm.models import ModelBase
+       from clickhouse_orm.models import ModelBase

        query = "SELECT DISTINCT type FROM system.columns"
        for row in self.database.select(query):
            ModelBase.create_ad_hoc_field(row.type)

    def test_get_model_for_table(self):
        # Tests that get_model_for_table works for a non-system model
-       model = self.database.get_model_for_table('person')
+       model = self.database.get_model_for_table("person")
        self.assertFalse(model.is_system_model())
        self.assertFalse(model.is_read_only())
-       self.assertEqual(model.table_name(), 'person')
+       self.assertEqual(model.table_name(), "person")
        # Read a few records
        list(model.objects_in(self.database)[:10])
        # Inserts should work too
-       self.database.insert([
-           model(first_name='aaa', last_name='bbb', height=1.77)
-       ])
+       self.database.insert([model(first_name="aaa", last_name="bbb", height=1.77)])

    def test_get_model_for_table__system(self):
        # Tests that get_model_for_table works for all system tables
@ -275,11 +296,15 @@ class DatabaseTestCase(TestCaseWithData):
            self.assertTrue(model.is_system_model())
            self.assertTrue(model.is_read_only())
            self.assertEqual(model.table_name(), row.name)
+           if row.name == "distributed_ddl_queue":
+               continue  # Since zookeeper is not set up in our tests
            # Read a few records
            try:
                list(model.objects_in(self.database)[:10])
            except ServerError as e:
-               if 'Not enough privileges' in e.message:
+               if "Not enough privileges" in str(e):
                    pass
                else:
                    raise
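Note: paginate returns a page object exposing number_of_objects, pages_total and objects, which the tests above check individually. A minimal end-to-end usage sketch, assuming db is a connected Database:

page_num = 1
while True:
    page = db.paginate(Person, "first_name, last_name", page_num, 100)
    for person in page.objects:
        print(person.first_name, person.last_name)
    if page_num >= page.pages_total:
        break
    page_num += 1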
View File
@ -1,119 +0,0 @@
import unittest
import datetime
import pytz

from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *


class DateFieldsTest(unittest.TestCase):

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        if self.database.server_version < (20, 1, 2, 4):
            raise unittest.SkipTest('ClickHouse version too old')
        self.database.create_table(ModelWithDate)

    def tearDown(self):
        self.database.drop_database()

    def test_ad_hoc_model(self):
        self.database.insert([
            ModelWithDate(
                date_field='2016-08-30',
                datetime_field='2016-08-30 03:50:00',
                datetime64_field='2016-08-30 03:50:00.123456',
                datetime64_3_field='2016-08-30 03:50:00.123456'
            ),
            ModelWithDate(
                date_field='2016-08-31',
                datetime_field='2016-08-31 01:30:00',
                datetime64_field='2016-08-31 01:30:00.123456',
                datetime64_3_field='2016-08-31 01:30:00.123456')
        ])

        # toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
        query = 'SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field'
        results = list(self.database.select(query))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].date_field, datetime.date(2016, 8, 30))
        self.assertEqual(results[0].datetime_field, datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[0].hour_start, datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].date_field, datetime.date(2016, 8, 31))
        self.assertEqual(results[1].datetime_field, datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].hour_start, datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[0].datetime64_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123456, tzinfo=pytz.UTC))
        self.assertEqual(results[0].datetime64_3_field, datetime.datetime(2016, 8, 30, 3, 50, 0, 123000, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime64_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123456, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime64_3_field, datetime.datetime(2016, 8, 31, 1, 30, 0, 123000, tzinfo=pytz.UTC))


class ModelWithDate(Model):
    date_field = DateField()
    datetime_field = DateTimeField()
    datetime64_field = DateTime64Field()
    datetime64_3_field = DateTime64Field(precision=3)

    engine = MergeTree('date_field', ('date_field',))


class ModelWithTz(Model):
    datetime_no_tz_field = DateTimeField()  # server tz
    datetime_tz_field = DateTimeField(timezone='Europe/Madrid')
    datetime64_tz_field = DateTime64Field(timezone='Europe/Madrid')
    datetime_utc_field = DateTimeField(timezone=pytz.UTC)

    engine = MergeTree('datetime_no_tz_field', ('datetime_no_tz_field',))


class DateTimeFieldWithTzTest(unittest.TestCase):

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        if self.database.server_version < (20, 1, 2, 4):
            raise unittest.SkipTest('ClickHouse version too old')
        self.database.create_table(ModelWithTz)

    def tearDown(self):
        self.database.drop_database()

    def test_ad_hoc_model(self):
        self.database.insert([
            ModelWithTz(
                datetime_no_tz_field='2020-06-11 04:00:00',
                datetime_tz_field='2020-06-11 04:00:00',
                datetime64_tz_field='2020-06-11 04:00:00',
                datetime_utc_field='2020-06-11 04:00:00',
            ),
            ModelWithTz(
                datetime_no_tz_field='2020-06-11 07:00:00+0300',
                datetime_tz_field='2020-06-11 07:00:00+0300',
                datetime64_tz_field='2020-06-11 07:00:00+0300',
                datetime_utc_field='2020-06-11 07:00:00+0300',
            ),
        ])
        query = 'SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field'
        results = list(self.database.select(query))

        self.assertEqual(results[0].datetime_no_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[0].datetime_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[0].datetime64_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[0].datetime_utc_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime_no_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime64_tz_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(results[1].datetime_utc_field, datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC))

        self.assertEqual(results[0].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
        self.assertEqual(results[0].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
        self.assertEqual(results[0].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
        self.assertEqual(results[0].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)
        self.assertEqual(results[1].datetime_no_tz_field.tzinfo.zone, self.database.server_timezone.zone)
        self.assertEqual(results[1].datetime_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
        self.assertEqual(results[1].datetime64_tz_field.tzinfo.zone, pytz.timezone('Europe/Madrid').zone)
        self.assertEqual(results[1].datetime_utc_field.tzinfo.zone, pytz.timezone('UTC').zone)
View File
@ -1,121 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from decimal import Decimal

from infi.clickhouse_orm.database import Database, ServerError
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *


class DecimalFieldsTest(unittest.TestCase):

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        try:
            self.database.create_table(DecimalModel)
        except ServerError as e:
            # This ClickHouse version does not support decimals yet
            raise unittest.SkipTest(e.message)

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        self.database.insert([
            DecimalModel(date_field='2016-08-20'),
            DecimalModel(date_field='2016-08-21', dec=Decimal('1.234')),
            DecimalModel(date_field='2016-08-22', dec32=Decimal('12342.2345')),
            DecimalModel(date_field='2016-08-23', dec64=Decimal('12342.23456')),
            DecimalModel(date_field='2016-08-24', dec128=Decimal('-4545456612342.234567')),
        ])

    def _assert_sample_data(self, results):
        self.assertEqual(len(results), 5)
        self.assertEqual(results[0].dec, Decimal(0))
        self.assertEqual(results[0].dec32, Decimal(17))
        self.assertEqual(results[1].dec, Decimal('1.234'))
        self.assertEqual(results[2].dec32, Decimal('12342.2345'))
        self.assertEqual(results[3].dec64, Decimal('12342.23456'))
        self.assertEqual(results[4].dec128, Decimal('-4545456612342.234567'))

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, DecimalModel))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        query = 'SELECT * from decimalmodel ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)

    def test_rounding(self):
        d = Decimal('11111.2340000000000000001')
        self.database.insert([DecimalModel(date_field='2016-08-20', dec=d, dec32=d, dec64=d, dec128=d)])
        m = DecimalModel.objects_in(self.database)[0]
        for val in (m.dec, m.dec32, m.dec64, m.dec128):
            self.assertEqual(val, Decimal('11111.234'))

    def test_assignment_ok(self):
        for value in (True, False, 17, 3.14, '20.5', Decimal('20.5')):
            DecimalModel(dec=value)

    def test_assignment_error(self):
        for value in ('abc', u'זה ארוך', None, float('NaN'), Decimal('-Infinity')):
            with self.assertRaises(ValueError):
                DecimalModel(dec=value)

    def test_aggregation(self):
        self._insert_sample_data()
        result = DecimalModel.objects_in(self.database).aggregate(m='min(dec)', n='max(dec)')
        self.assertEqual(result[0].m, Decimal(0))
        self.assertEqual(result[0].n, Decimal('1.234'))

    def test_precision_and_scale(self):
        # Go over all valid combinations
        for precision in range(1, 39):
            for scale in range(0, precision + 1):
                f = DecimalField(precision, scale)
        # Some invalid combinations
        for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
            with self.assertRaises(AssertionError):
                f = DecimalField(precision, scale)

    def test_min_max(self):
        # In range
        f = DecimalField(3, 1)
        f.validate(f.to_python('99.9', None))
        f.validate(f.to_python('-99.9', None))
        # In range after rounding
        f.validate(f.to_python('99.94', None))
        f.validate(f.to_python('-99.94', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('99.99', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-99.99', None))
        # In range
        f = Decimal32Field(4)
        f.validate(f.to_python('99999.9999', None))
        f.validate(f.to_python('-99999.9999', None))
        # In range after rounding
        f.validate(f.to_python('99999.99994', None))
        f.validate(f.to_python('-99999.99994', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('100000', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-100000', None))


class DecimalModel(Model):
    date_field = DateField()
    dec = DecimalField(15, 3)
    dec32 = Decimal32Field(4, default=17)
    dec64 = Decimal64Field(5)
    dec128 = Decimal128Field(6)

    engine = Memory()
View File
@ -1,41 +1,43 @@
-import unittest
import logging
+import unittest

-from infi.clickhouse_orm import *
+from clickhouse_orm import Database, F, Memory, Model, StringField, UInt64Field


class DictionaryTestMixin:

    def setUp(self):
-       self.database = Database('test-db', log_statements=True)
+       self.database = Database("test-db", log_statements=True)
        if self.database.server_version < (20, 1, 11, 73):
-           raise unittest.SkipTest('ClickHouse version too old')
+           raise unittest.SkipTest("ClickHouse version too old")
        self._create_dictionary()

    def tearDown(self):
        self.database.drop_database()

-   def _test_func(self, func, expected_value):
-       sql = 'SELECT %s AS value' % func.to_sql()
+   def _call_func(self, func):
+       sql = "SELECT %s AS value" % func.to_sql()
        logging.info(sql)
        result = list(self.database.select(sql))
-       logging.info('\t==> %s', result[0].value if result else '<empty>')
-       print('Comparing %s to %s' % (result[0].value, expected_value))
-       self.assertEqual(result[0].value, expected_value)
-       return result[0].value if result else None
+       logging.info("\t==> %s", result[0].value if result else "<empty>")
+       return result
+
+   def _test_func(self, func, expected_value):
+       result = self._call_func(func)
+       print("Comparing %s to %s" % (result[0].value, expected_value))
+       assert result[0].value == expected_value


class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):

    def _create_dictionary(self):
        # Create a table to be used as source for the dictionary
        self.database.create_table(NumberName)
        self.database.insert(
            NumberName(number=i, name=name)
-           for i, name in enumerate('Zero One Two Three Four Five Six Seven Eight Nine Ten'.split())
+           for i, name in enumerate("Zero One Two Three Four Five Six Seven Eight Nine Ten".split())
        )
        # Create the dictionary
-       self.database.raw("""
+       self.database.raw(
+           """
            CREATE DICTIONARY numbers_dict(
                number UInt64,
                name String DEFAULT '?'
@ -46,16 +48,17 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
            ))
            LIFETIME(100)
            LAYOUT(HASHED());
-       """)
-       self.dict_name = 'test-db.numbers_dict'
+       """
+       )
+       self.dict_name = "test-db.numbers_dict"

    def test_dictget(self):
-       self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(3)), 'Three')
-       self._test_func(F.dictGet(self.dict_name, 'name', F.toUInt64(99)), '?')
+       self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(3)), "Three")
+       self._test_func(F.dictGet(self.dict_name, "name", F.toUInt64(99)), "?")

    def test_dictgetordefault(self):
-       self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(3), 'n/a'), 'Three')
-       self._test_func(F.dictGetOrDefault(self.dict_name, 'name', F.toUInt64(99), 'n/a'), 'n/a')
+       self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(3), "n/a"), "Three")
+       self._test_func(F.dictGetOrDefault(self.dict_name, "name", F.toUInt64(99), "n/a"), "n/a")

    def test_dicthas(self):
        self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -63,19 +66,21 @@ class SimpleDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):

    def _create_dictionary(self):
        # Create a table to be used as source for the dictionary
        self.database.create_table(Region)
-       self.database.insert([
-           Region(region_id=1, parent_region=0, region_name='Russia'),
-           Region(region_id=2, parent_region=1, region_name='Moscow'),
-           Region(region_id=3, parent_region=2, region_name='Center'),
-           Region(region_id=4, parent_region=0, region_name='Great Britain'),
-           Region(region_id=5, parent_region=4, region_name='London'),
-       ])
+       self.database.insert(
+           [
+               Region(region_id=1, parent_region=0, region_name="Russia"),
+               Region(region_id=2, parent_region=1, region_name="Moscow"),
+               Region(region_id=3, parent_region=2, region_name="Center"),
+               Region(region_id=4, parent_region=0, region_name="Great Britain"),
+               Region(region_id=5, parent_region=4, region_name="London"),
+           ]
+       )
        # Create the dictionary
-       self.database.raw("""
+       self.database.raw(
+           """
            CREATE DICTIONARY regions_dict(
                region_id UInt64,
                parent_region UInt64 HIERARCHICAL,
@ -87,17 +92,24 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
            ))
            LIFETIME(100)
            LAYOUT(HASHED());
-       """)
-       self.dict_name = 'test-db.regions_dict'
+       """
+       )
+       self.dict_name = "test-db.regions_dict"

    def test_dictget(self):
-       self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(3)), 'Center')
-       self._test_func(F.dictGet(self.dict_name, 'parent_region', F.toUInt64(3)), 2)
-       self._test_func(F.dictGet(self.dict_name, 'region_name', F.toUInt64(99)), '?')
+       self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(3)), "Center")
+       self._test_func(F.dictGet(self.dict_name, "parent_region", F.toUInt64(3)), 2)
+       self._test_func(F.dictGet(self.dict_name, "region_name", F.toUInt64(99)), "?")

    def test_dictgetordefault(self):
-       self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(3), 'n/a'), 'Center')
-       self._test_func(F.dictGetOrDefault(self.dict_name, 'region_name', F.toUInt64(99), 'n/a'), 'n/a')
+       self._test_func(
+           F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(3), "n/a"),
+           "Center",
+       )
+       self._test_func(
+           F.dictGetOrDefault(self.dict_name, "region_name", F.toUInt64(99), "n/a"),
+           "n/a",
+       )

    def test_dicthas(self):
        self._test_func(F.dictHas(self.dict_name, F.toUInt64(3)), 1)
@ -105,7 +117,10 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
    def test_dictgethierarchy(self):
        self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(3)), [3, 2, 1])
-       self._test_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)), [99])
+       # Default behaviour changed in CH, but we're not really testing that
+       default = self._call_func(F.dictGetHierarchy(self.dict_name, F.toUInt64(99)))
+       assert isinstance(default, list)
+       assert len(default) <= 1  # either [] or [99]

    def test_dictisin(self):
        self._test_func(F.dictIsIn(self.dict_name, F.toUInt64(3), F.toUInt64(1)), 1)
@ -114,7 +129,7 @@ class HierarchicalDictionaryTest(DictionaryTestMixin, unittest.TestCase):
class NumberName(Model):
-   ''' A table to act as a source for the dictionary '''
+   """A table to act as a source for the dictionary"""

    number = UInt64Field()
    name = StringField()
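Note: _call_func works because every F expression can render itself to SQL via to_sql(), which the diff above relies on. The same mechanism outside the test harness, reusing the dictionary created earlier:

from clickhouse_orm.funcs import F

expr = F.dictGetOrDefault("test-db.numbers_dict", "name", F.toUInt64(3), "n/a")
sql = "SELECT %s AS value" % expr.to_sql()  # same pattern as _call_func
# sql is now a plain string that can be passed to Database.select(...)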
View File
@ -1,16 +1,29 @@
-import unittest
import datetime
-from infi.clickhouse_orm import *
import logging
+import unittest

+from clickhouse_orm.database import Database, DatabaseException, ServerError
+from clickhouse_orm.engines import (
+    CollapsingMergeTree,
+    Log,
+    Memory,
+    Merge,
+    MergeTree,
+    ReplacingMergeTree,
+    SummingMergeTree,
+    TinyLog,
+)
+from clickhouse_orm.fields import DateField, Int8Field, UInt8Field, UInt16Field, UInt32Field
+from clickhouse_orm.funcs import F
+from clickhouse_orm.models import Distributed, DistributedModel, MergeModel, Model
+from clickhouse_orm.system_models import SystemPart

logging.getLogger("requests").setLevel(logging.WARNING)


class _EnginesHelperTestCase(unittest.TestCase):

    def setUp(self):
-       self.database = Database('test-db', log_statements=True)
+       self.database = Database("test-db", log_statements=True)

    def tearDown(self):
        self.database.drop_database()
@ -19,32 +32,57 @@ class _EnginesHelperTestCase(unittest.TestCase):
class EnginesTestCase(_EnginesHelperTestCase): class EnginesTestCase(_EnginesHelperTestCase):
def _create_and_insert(self, model_class): def _create_and_insert(self, model_class):
self.database.create_table(model_class) self.database.create_table(model_class)
self.database.insert([ self.database.insert(
model_class(date='2017-01-01', event_id=23423, event_group=13, event_count=7, event_version=1) [
]) model_class(
date="2017-01-01",
event_id=23423,
event_group=13,
event_count=7,
event_version=1,
)
]
)
    def test_merge_tree(self):
        class TestModel(SampleModel):
            engine = MergeTree("date", ("date", "event_id", "event_group"))

        self._create_and_insert(TestModel)

    def test_merge_tree_with_sampling(self):
        class TestModel(SampleModel):
            engine = MergeTree(
                "date",
                ("date", "event_id", "event_group", "intHash32(event_id)"),
                sampling_expr="intHash32(event_id)",
            )

        self._create_and_insert(TestModel)

    def test_merge_tree_with_sampling__funcs(self):
        class TestModel(SampleModel):
            engine = MergeTree(
                "date",
                ("date", "event_id", "event_group", F.intHash32(SampleModel.event_id)),
                sampling_expr=F.intHash32(SampleModel.event_id),
            )

        self._create_and_insert(TestModel)

    def test_merge_tree_with_granularity(self):
        class TestModel(SampleModel):
            engine = MergeTree("date", ("date", "event_id", "event_group"), index_granularity=4096)

        self._create_and_insert(TestModel)

    def test_replicated_merge_tree(self):
        engine = MergeTree(
            "date",
            ("date", "event_id", "event_group"),
            replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
            replica_name="{replica}",
        )
        # In ClickHouse 1.1.54310 a custom partitioning key was introduced and the new syntax is used
        if self.database.server_version >= (1, 1, 54310):
            expected = "ReplicatedMergeTree('/clickhouse/tables/{layer}-{shard}/hits', '{replica}') PARTITION BY (toYYYYMM(`date`)) ORDER BY (date, event_id, event_group) SETTINGS index_granularity=8192"
@ -54,38 +92,48 @@ class EnginesTestCase(_EnginesHelperTestCase):
    def test_replicated_merge_tree_incomplete(self):
        with self.assertRaises(AssertionError):
            MergeTree(
                "date",
                ("date", "event_id", "event_group"),
                replica_table_path="/clickhouse/tables/{layer}-{shard}/hits",
            )
        with self.assertRaises(AssertionError):
            MergeTree("date", ("date", "event_id", "event_group"), replica_name="{replica}")

    def test_collapsing_merge_tree(self):
        class TestModel(SampleModel):
            engine = CollapsingMergeTree("date", ("date", "event_id", "event_group"), "event_version")

        self._create_and_insert(TestModel)

    def test_summing_merge_tree(self):
        class TestModel(SampleModel):
            engine = SummingMergeTree("date", ("date", "event_group"), ("event_count",))

        self._create_and_insert(TestModel)

    def test_replacing_merge_tree(self):
        class TestModel(SampleModel):
            engine = ReplacingMergeTree("date", ("date", "event_id", "event_group"), "event_uversion")

        self._create_and_insert(TestModel)

    def test_tiny_log(self):
        class TestModel(SampleModel):
            engine = TinyLog()

        self._create_and_insert(TestModel)

    def test_log(self):
        class TestModel(SampleModel):
            engine = Log()

        self._create_and_insert(TestModel)

    def test_memory(self):
        class TestModel(SampleModel):
            engine = Memory()

        self._create_and_insert(TestModel)
    def test_merge(self):
@ -96,7 +144,7 @@ class EnginesTestCase(_EnginesHelperTestCase):
            engine = TinyLog()

        class TestMergeModel(MergeModel, SampleModel):
            engine = Merge("^testmodel")

        self.database.create_table(TestModel1)
        self.database.create_table(TestModel2)
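        # The Merge engine stores no data of its own; it reads from every table in the
        # database whose name matches the given regular expression ("^testmodel" here).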
@ -104,54 +152,87 @@ class EnginesTestCase(_EnginesHelperTestCase):
        # Insert operations are restricted for this model type
        with self.assertRaises(DatabaseException):
            self.database.insert(
                [
                    TestMergeModel(
                        date="2017-01-01",
                        event_id=23423,
                        event_group=13,
                        event_count=7,
                        event_version=1,
                    )
                ]
            )

        # Testing select
        self.database.insert(
            [
                TestModel1(
                    date="2017-01-01",
                    event_id=1,
                    event_group=1,
                    event_count=1,
                    event_version=1,
                )
            ]
        )
        self.database.insert(
            [
                TestModel2(
                    date="2017-01-02",
                    event_id=2,
                    event_group=2,
                    event_count=2,
                    event_version=2,
                )
            ]
        )
        # event_uversion is a materialized field, so * won't select it and it will be zero
        res = self.database.select(
            "SELECT *, _table, event_uversion FROM $table ORDER BY event_id",
            model_class=TestMergeModel,
        )
        res = list(res)
        self.assertEqual(2, len(res))
        self.assertDictEqual(
            {
                "_table": "testmodel1",
                "date": datetime.date(2017, 1, 1),
                "event_id": 1,
                "event_group": 1,
                "event_count": 1,
                "event_version": 1,
                "event_uversion": 1,
            },
            res[0].to_dict(include_readonly=True),
        )
        self.assertDictEqual(
            {
                "_table": "testmodel2",
                "date": datetime.date(2017, 1, 2),
                "event_id": 2,
                "event_group": 2,
                "event_count": 2,
                "event_version": 2,
                "event_uversion": 2,
            },
            res[1].to_dict(include_readonly=True),
        )
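        # Naming event_uversion explicitly in the SELECT returns its real value;
        # a bare SELECT * skips materialized columns, so the model would see zero.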
    def test_custom_partitioning(self):
        class TestModel(SampleModel):
            engine = MergeTree(
                order_by=("date", "event_id", "event_group"),
                partition_key=("toYYYYMM(date)", "event_group"),
            )

        class TestCollapseModel(SampleModel):
            sign = Int8Field(default=-1)

            engine = CollapsingMergeTree(
                sign_col="sign",
                order_by=("date", "event_id", "event_group"),
                partition_key=("toYYYYMM(date)", "event_group"),
            )

        self._create_and_insert(TestModel)
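        # CollapsingMergeTree requires a sign column holding 1 (state row) or -1
        # (cancelled row); the field default above simply gives inserted test rows
        # a valid sign.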
@ -161,30 +242,30 @@ class EnginesTestCase(_EnginesHelperTestCase):
        parts = sorted(list(SystemPart.get(self.database)), key=lambda x: x.table)
        self.assertEqual(2, len(parts))
        self.assertEqual("testcollapsemodel", parts[0].table)
        self.assertEqual("(201701, 13)".replace(" ", ""), parts[0].partition.replace(" ", ""))
        self.assertEqual("testmodel", parts[1].table)
        self.assertEqual("(201701, 13)".replace(" ", ""), parts[1].partition.replace(" ", ""))

    def test_custom_primary_key(self):
        if self.database.server_version < (18, 1):
            raise unittest.SkipTest("ClickHouse version too old")

        class TestModel(SampleModel):
            engine = MergeTree(
                order_by=("date", "event_id", "event_group"),
                partition_key=("toYYYYMM(date)",),
                primary_key=("date", "event_id"),
            )

        class TestCollapseModel(SampleModel):
            sign = Int8Field(default=1)

            engine = CollapsingMergeTree(
                sign_col="sign",
                order_by=("date", "event_id", "event_group"),
                partition_key=("toYYYYMM(date)",),
                primary_key=("date", "event_id"),
            )

        self._create_and_insert(TestModel)
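        # A PRIMARY KEY declared separately from ORDER BY (it must be a prefix of the
        # sorting key, as ("date", "event_id") is here) is only supported on newer
        # servers, hence the version gate above.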
@ -195,28 +276,28 @@ class EnginesTestCase(_EnginesHelperTestCase):
class SampleModel(Model):
    date = DateField()
    event_id = UInt32Field()
    event_group = UInt32Field()
    event_count = UInt16Field()
    event_version = Int8Field()
    event_uversion = UInt8Field(materialized="abs(event_version)")


class DistributedTestCase(_EnginesHelperTestCase):
    def test_without_table_name(self):
        engine = Distributed("my_cluster")
        with self.assertRaises(ValueError) as cm:
            engine.create_table_sql(self.database)
        exc = cm.exception
        self.assertEqual(str(exc), "Cannot create Distributed engine: specify an underlying table")

    def test_with_table_name(self):
        engine = Distributed("my_cluster", "foo")
        sql = engine.create_table_sql(self.database)
        self.assertEqual(sql, "Distributed(`my_cluster`, `test-db`, `foo`)")

    class TestModel(SampleModel):
        engine = TinyLog()
@ -231,7 +312,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
    def test_bad_cluster_name(self):
        with self.assertRaises(ServerError) as cm:
            d_model = self._create_distributed("cluster_name")
            self.database.count(d_model)
        exc = cm.exception
@ -243,7 +324,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
            engine = Log()

        class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
            engine = Distributed("test_shard_localhost", self.TestModel)

        self.database.create_table(self.TestModel)
        self.database.create_table(TestDistributedModel)
@ -251,7 +332,7 @@ class DistributedTestCase(_EnginesHelperTestCase):
    def test_minimal_engine(self):
        class TestDistributedModel(DistributedModel, self.TestModel):
            engine = Distributed("test_shard_localhost")

        self.database.create_table(self.TestModel)
        self.database.create_table(TestDistributedModel)
@ -263,64 +344,89 @@ class DistributedTestCase(_EnginesHelperTestCase):
            engine = Log()

        class TestDistributedModel(DistributedModel, self.TestModel, TestModel2):
            engine = Distributed("test_shard_localhost")

        self.database.create_table(self.TestModel)
        with self.assertRaises(TypeError) as cm:
            self.database.create_table(TestDistributedModel)
        exc = cm.exception
        self.assertEqual(
            str(exc),
            "When defining Distributed engine without the table_name ensure "
            "that your model has exactly one non-distributed superclass",
        )

    def test_minimal_engine_no_superclasses(self):
        class TestDistributedModel(DistributedModel):
            engine = Distributed("test_shard_localhost")

        self.database.create_table(self.TestModel)
        with self.assertRaises(TypeError) as cm:
            self.database.create_table(TestDistributedModel)
        exc = cm.exception
        self.assertEqual(
            str(exc),
            "When defining Distributed engine without the table_name ensure "
            "that your model has a parent model",
        )
    def _test_insert_select(self, local_to_distributed, test_model=TestModel, include_readonly=True):
        d_model = self._create_distributed("test_shard_localhost", underlying=test_model)
        if local_to_distributed:
            to_insert, to_select = test_model, d_model
        else:
            to_insert, to_select = d_model, test_model
        self.database.insert(
            [
                to_insert(
                    date="2017-01-01",
                    event_id=1,
                    event_group=1,
                    event_count=1,
                    event_version=1,
                ),
                to_insert(
                    date="2017-01-02",
                    event_id=2,
                    event_group=2,
                    event_count=2,
                    event_version=2,
                ),
            ]
        )
        # event_uversion is a materialized field, so * won't select it and it will be zero
        res = self.database.select(
            "SELECT *, event_uversion FROM $table ORDER BY event_id",
            model_class=to_select,
        )
        res = [row for row in res]
        self.assertEqual(2, len(res))
        self.assertDictEqual(
            {
                "date": datetime.date(2017, 1, 1),
                "event_id": 1,
                "event_group": 1,
                "event_count": 1,
                "event_version": 1,
                "event_uversion": 1,
            },
            res[0].to_dict(include_readonly=include_readonly),
        )
        self.assertDictEqual(
            {
                "date": datetime.date(2017, 1, 2),
                "event_id": 2,
                "event_group": 2,
                "event_count": 2,
                "event_version": 2,
                "event_uversion": 2,
            },
            res[1].to_dict(include_readonly=include_readonly),
        )
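    # _test_insert_select drives the same round trip in both directions: write to the
    # local table and read through the Distributed one, or vice versa, then verifies
    # that both rows (including the materialized column) survive the trip.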
@unittest.skip("Bad support of materialized fields in Distributed tables "
"https://groups.google.com/forum/#!topic/clickhouse/XEYRRwZrsSc")
def test_insert_distributed_select_local(self): def test_insert_distributed_select_local(self):
return self._test_insert_select(local_to_distributed=False) return self._test_insert_select(local_to_distributed=False)
View File
@ -1,57 +0,0 @@
# -*- coding: utf-8 -*-
import unittest

from infi.clickhouse_orm.database import Database
from infi.clickhouse_orm.models import Model
from infi.clickhouse_orm.fields import *
from infi.clickhouse_orm.engines import *


class FixedStringFieldsTest(unittest.TestCase):
    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(FixedStringModel)

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        self.database.insert([
            FixedStringModel(date_field='2016-08-30', fstr_field=''),
            FixedStringModel(date_field='2016-08-30'),
            FixedStringModel(date_field='2016-08-31', fstr_field='foo'),
            FixedStringModel(date_field='2016-08-31', fstr_field=u'לילה')
        ])

    def _assert_sample_data(self, results):
        self.assertEqual(len(results), 4)
        self.assertEqual(results[0].fstr_field, '')
        self.assertEqual(results[1].fstr_field, 'ABCDEFGHIJK')
        self.assertEqual(results[2].fstr_field, 'foo')
        self.assertEqual(results[3].fstr_field, u'לילה')

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, FixedStringModel))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        query = 'SELECT * from $db.fixedstringmodel ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)

    def test_assignment_error(self):
        for value in (17, 'this is too long', u'זה ארוך', None, 99.9):
            with self.assertRaises(ValueError):
                FixedStringModel(fstr_field=value)


class FixedStringModel(Model):
    date_field = DateField()
    fstr_field = FixedStringField(12, default='ABCDEFGHIJK')
    engine = MergeTree('date_field', ('date_field',))
View File
@ -1,19 +1,21 @@
import logging
import unittest
from datetime import date, datetime, timedelta
from decimal import Decimal
from ipaddress import IPv4Address, IPv6Address

import pytz

from clickhouse_orm.database import ServerError
from clickhouse_orm.fields import DateTimeField
from clickhouse_orm.funcs import F
from clickhouse_orm.utils import NO_VALUE

from .base_test_with_data import Person, TestCaseWithData
from .test_querysets import SampleModel
class FuncsTestCase(TestCaseWithData):
    def setUp(self):
        super(FuncsTestCase, self).setUp()
        self.database.insert(self._sample_data())
@ -23,70 +25,75 @@ class FuncsTestCase(TestCaseWithData):
        count = 0
        for instance in qs:
            count += 1
            logging.info("\t[%d]\t%s" % (count, instance.to_dict()))
        self.assertEqual(count, expected_count)
        self.assertEqual(qs.count(), expected_count)

    def _call_func(self, func):
        sql = "SELECT %s AS value" % func.to_sql()
        logging.info(sql)
        try:
            result = list(self.database.select(sql))
            logging.info("\t==> %s", result[0].value if result else "<empty>")
            return result[0].value if result else None
        except ServerError as e:
            if "Unknown function" in str(e):
                logging.warning(str(e))
                return  # ignore functions that don't exist in the used ClickHouse version
            raise

    def _test_func(self, func, expected_value=NO_VALUE):
        result = self._call_func(func)
        if expected_value != NO_VALUE:
            print("Comparing %s to %s" % (result, expected_value))
            self.assertEqual(result, expected_value)
        return result if result else None
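    # _call_func runs a single function and returns its value, quietly returning None
    # when the server reports "Unknown function" so the suite tolerates older
    # ClickHouse versions; _test_func layers the optional expected-value check on top.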
    def _test_aggr(self, func, expected_value=NO_VALUE):
        qs = Person.objects_in(self.database).aggregate(value=func)
        logging.info(qs.as_sql())
        try:
            result = list(qs)
            logging.info("\t==> %s", result[0].value if result else "<empty>")
            if expected_value != NO_VALUE:
                self.assertEqual(result[0].value, expected_value)
            return result[0].value if result else None
        except ServerError as e:
            if "Unknown function" in str(e):
                logging.warning(str(e))
                return  # ignore functions that don't exist in the used ClickHouse version
            raise
    def test_func_to_sql(self):
        # No args
        self.assertEqual(F("func").to_sql(), "func()")
        # String args
        self.assertEqual(F("func", "Wendy's", u"Wendy's").to_sql(), "func('Wendy\\'s', 'Wendy\\'s')")
        # Numeric args
        self.assertEqual(F("func", 1, 1.1, Decimal("3.3")).to_sql(), "func(1, 1.1, 3.3)")
        # Date args
        self.assertEqual(F("func", date(2018, 12, 31)).to_sql(), "func(toDate('2018-12-31'))")
        # Datetime args
        self.assertEqual(F("func", datetime(2018, 12, 31)).to_sql(), "func(toDateTime('1546214400'))")
        # Boolean args
        self.assertEqual(F("func", True, False).to_sql(), "func(1, 0)")
        # Timezone args
        self.assertEqual(F("func", pytz.utc).to_sql(), "func('UTC')")
        self.assertEqual(F("func", pytz.timezone("Europe/Athens")).to_sql(), "func('Europe/Athens')")
        # Null args
        self.assertEqual(F("func", None).to_sql(), "func(NULL)")
        # Fields as args
        self.assertEqual(F("func", SampleModel.color).to_sql(), "func(`color`)")
        # Funcs as args
        self.assertEqual(F("func", F("sqrt", 25)).to_sql(), "func(sqrt(25))")
        # Iterables as args
        x = [1, "z", F("foo", 17)]
        for y in [x, iter(x)]:
            self.assertEqual(F("func", y, 5).to_sql(), "func([1, 'z', foo(17)], 5)")
        # Tuples as args
        self.assertEqual(F("func", [(1, 2), (3, 4)]).to_sql(), "func([(1, 2), (3, 4)])")
        self.assertEqual(F("func", tuple(x), 5).to_sql(), "func((1, 'z', foo(17)), 5)")
        # Binary operator functions
        self.assertEqual(F.plus(1, 2).to_sql(), "(1 + 2)")
        self.assertEqual(F.lessOrEquals(1, 2).to_sql(), "(1 <= 2)")
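        # Note that no query is executed here: to_sql() only renders the call tree,
        # converting Python values (strings, dates, Decimals, timezones, None, nested
        # F objects) into their ClickHouse literal forms as asserted above.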
@ -106,32 +113,32 @@ class FuncsTestCase(TestCaseWithData):
    def test_filter_date_field(self):
        qs = Person.objects_in(self.database)
        # People born on the 30th
        self._test_qs(qs.filter(F("equals", F("toDayOfMonth", Person.birthday), 30)), 3)
        self._test_qs(qs.filter(F("toDayOfMonth", Person.birthday) == 30), 3)
        self._test_qs(qs.filter(F.toDayOfMonth(Person.birthday) == 30), 3)
        # People born on Sunday
        self._test_qs(qs.filter(F("equals", F("toDayOfWeek", Person.birthday), 7)), 18)
        self._test_qs(qs.filter(F("toDayOfWeek", Person.birthday) == 7), 18)
        self._test_qs(qs.filter(F.toDayOfWeek(Person.birthday) == 7), 18)
        # People born on 1976-10-01
        self._test_qs(qs.filter(F("equals", Person.birthday, "1976-10-01")), 1)
        self._test_qs(qs.filter(F("equals", Person.birthday, date(1976, 10, 1))), 1)
        self._test_qs(qs.filter(Person.birthday == date(1976, 10, 1)), 1)

    def test_func_as_field_value(self):
        qs = Person.objects_in(self.database)
        self._test_qs(qs.filter(height__gt=F.plus(1, 0.61)), 96)
        self._test_qs(qs.exclude(birthday=F.today()), 100)
        self._test_qs(qs.filter(birthday__between=["1970-01-01", F.today()]), 100)

    def test_in_and_not_in(self):
        qs = Person.objects_in(self.database)
        self._test_qs(qs.filter(Person.first_name.isIn(["Ciaran", "Elton"])), 4)
        self._test_qs(qs.filter(~Person.first_name.isIn(["Ciaran", "Elton"])), 96)
        self._test_qs(qs.filter(Person.first_name.isNotIn(["Ciaran", "Elton"])), 96)
        self._test_qs(qs.exclude(Person.first_name.isIn(["Ciaran", "Elton"])), 96)
        # In subquery
        subquery = qs.filter(F.startsWith(Person.last_name, "M")).only(Person.first_name)
        self._test_qs(qs.filter(Person.first_name.isIn(subquery)), 4)
    def test_comparison_operators(self):
@ -213,14 +220,14 @@ class FuncsTestCase(TestCaseWithData):
        dt = datetime(2018, 12, 31, 11, 22, 33)
        self._test_func(F.toYear(d), 2018)
        self._test_func(F.toYear(dt), 2018)
        self._test_func(F.toISOYear(dt, "Europe/Athens"), 2019)  # 2018-12-31 is ISO year 2019, week 1, day 1
        self._test_func(F.toQuarter(d), 4)
        self._test_func(F.toQuarter(dt), 4)
        self._test_func(F.toMonth(d), 12)
        self._test_func(F.toMonth(dt), 12)
        self._test_func(F.toWeek(d), 52)
        self._test_func(F.toWeek(dt), 52)
        self._test_func(F.toISOWeek(d), 1)  # 2018-12-31 is ISO year 2019, week 1, day 1
        self._test_func(F.toISOWeek(dt), 1)
        self._test_func(F.toDayOfYear(d), 365)
        self._test_func(F.toDayOfYear(dt), 365)
@ -239,189 +246,256 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.toStartOfYear(d), date(2018, 1, 1))
        self._test_func(F.toStartOfYear(dt), date(2018, 1, 1))
        self._test_func(F.toStartOfMinute(dt), datetime(2018, 12, 31, 11, 22, 0, tzinfo=pytz.utc))
        self._test_func(
            F.toStartOfFiveMinute(dt),
            datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc),
        )
        self._test_func(
            F.toStartOfFifteenMinutes(dt),
            datetime(2018, 12, 31, 11, 15, 0, tzinfo=pytz.utc),
        )
        self._test_func(F.toStartOfHour(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.toStartOfISOYear(dt), date(2018, 12, 31))
        self._test_func(
            F.toStartOfTenMinutes(dt),
            datetime(2018, 12, 31, 11, 20, 0, tzinfo=pytz.utc),
        )
        self._test_func(F.toStartOfWeek(dt), date(2018, 12, 30))
        self._test_func(F.toTime(dt), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
        self._test_func(F.toUnixTimestamp(dt, "UTC"), int(dt.replace(tzinfo=pytz.utc).timestamp()))
        self._test_func(F.toYYYYMM(d), 201812)
        self._test_func(F.toYYYYMM(dt), 201812)
        self._test_func(F.toYYYYMM(dt, "Europe/Athens"), 201812)
        self._test_func(F.toYYYYMMDD(d), 20181231)
        self._test_func(F.toYYYYMMDD(dt), 20181231)
        self._test_func(F.toYYYYMMDD(dt, "Europe/Athens"), 20181231)
        self._test_func(F.toYYYYMMDDhhmmss(d), 20181231000000)
        self._test_func(F.toYYYYMMDDhhmmss(dt, "Europe/Athens"), 20181231132233)
        self._test_func(F.toRelativeYearNum(dt), 2018)
        self._test_func(F.toRelativeYearNum(dt, "Europe/Athens"), 2018)
        self._test_func(F.toRelativeMonthNum(dt), 2018 * 12 + 12)
        self._test_func(F.toRelativeMonthNum(dt, "Europe/Athens"), 2018 * 12 + 12)
        self._test_func(F.toRelativeWeekNum(dt), 2557)
        self._test_func(F.toRelativeWeekNum(dt, "Europe/Athens"), 2557)
        self._test_func(F.toRelativeDayNum(dt), 17896)
        self._test_func(F.toRelativeDayNum(dt, "Europe/Athens"), 17896)
        self._test_func(F.toRelativeHourNum(dt), 429515)
        self._test_func(F.toRelativeHourNum(dt, "Europe/Athens"), 429515)
        self._test_func(F.toRelativeMinuteNum(dt), 25770922)
        self._test_func(F.toRelativeMinuteNum(dt, "Europe/Athens"), 25770922)
        self._test_func(F.toRelativeSecondNum(dt), 1546255353)
        self._test_func(F.toRelativeSecondNum(dt, "Europe/Athens"), 1546255353)
        self._test_func(F.timeSlot(dt), datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.timeSlots(dt, 300), [datetime(2018, 12, 31, 11, 0, 0, tzinfo=pytz.utc)])
        self._test_func(F.formatDateTime(dt, "%D %T", "Europe/Athens"), "12/31/18 13:22:33")
        self._test_func(F.addDays(d, 7), date(2019, 1, 7))
        self._test_func(F.addDays(dt, 7, "Europe/Athens"))
        self._test_func(F.addHours(dt, 7, "Europe/Athens"))
        self._test_func(F.addMinutes(dt, 7, "Europe/Athens"))
        self._test_func(F.addMonths(d, 7), date(2019, 7, 31))
        self._test_func(F.addMonths(dt, 7, "Europe/Athens"))
        self._test_func(F.addQuarters(d, 7))
        self._test_func(F.addQuarters(dt, 7, "Europe/Athens"))
        self._test_func(F.addSeconds(d, 7))
        self._test_func(F.addSeconds(dt, 7, "Europe/Athens"))
        self._test_func(F.addWeeks(d, 7))
        self._test_func(F.addWeeks(dt, 7, "Europe/Athens"))
        self._test_func(F.addYears(d, 7))
        self._test_func(F.addYears(dt, 7, "Europe/Athens"))
        self._test_func(F.subtractDays(d, 3))
        self._test_func(F.subtractDays(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractHours(d, 3))
        self._test_func(F.subtractHours(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractMinutes(d, 3))
        self._test_func(F.subtractMinutes(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractMonths(d, 3))
        self._test_func(F.subtractMonths(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractQuarters(d, 3))
        self._test_func(F.subtractQuarters(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractSeconds(d, 3))
        self._test_func(F.subtractSeconds(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractWeeks(d, 3))
        self._test_func(F.subtractWeeks(dt, 3, "Europe/Athens"))
        self._test_func(F.subtractYears(d, 3))
        self._test_func(F.subtractYears(dt, 3, "Europe/Athens"))
        self._test_func(
            F.now() + F.toIntervalSecond(3) + F.toIntervalMinute(3) + F.toIntervalHour(3) + F.toIntervalDay(3)
        )
        self._test_func(
            F.now() + F.toIntervalWeek(3) + F.toIntervalMonth(3) + F.toIntervalQuarter(3) + F.toIntervalYear(3)
        )
        self._test_func(
            F.now() + F.toIntervalSecond(3000) - F.toIntervalDay(3000) == F.now() + timedelta(seconds=3000, days=-3000)
        )
    def test_date_functions_utc_only(self):
        if self.database.server_timezone != pytz.utc:
            raise unittest.SkipTest("This test must run with UTC as the server timezone")
        d = date(2018, 12, 31)
        dt = datetime(2018, 12, 31, 11, 22, 33)
        athens_tz = pytz.timezone("Europe/Athens")
        self._test_func(F.toHour(dt), 11)
        self._test_func(F.toStartOfDay(dt), datetime(2018, 12, 31, 0, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.toTime(dt, pytz.utc), datetime(1970, 1, 2, 11, 22, 33, tzinfo=pytz.utc))
        self._test_func(
            F.toTime(dt, "Europe/Athens"),
            athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)),
        )
        self._test_func(
            F.toTime(dt, athens_tz),
            athens_tz.localize(datetime(1970, 1, 2, 13, 22, 33)),
        )
        self._test_func(
            F.toTimeZone(dt, "Europe/Athens"),
            athens_tz.localize(datetime(2018, 12, 31, 13, 22, 33)),
        )
        self._test_func(F.today(), datetime.utcnow().date())
        self._test_func(F.yesterday(), datetime.utcnow().date() - timedelta(days=1))
        self._test_func(F.toYYYYMMDDhhmmss(dt), 20181231112233)
        self._test_func(F.formatDateTime(dt, "%D %T"), "12/31/18 11:22:33")
        self._test_func(F.addHours(d, 7), datetime(2018, 12, 31, 7, 0, 0, tzinfo=pytz.utc))
        self._test_func(F.addMinutes(d, 7), datetime(2018, 12, 31, 0, 7, 0, tzinfo=pytz.utc))
        actual = self._call_func(F.now())
        expected = datetime.utcnow().replace(tzinfo=pytz.utc, microsecond=0)
        self.assertLess((actual - expected).total_seconds(), 1e-3)
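        # The server's now() is sampled before the local timestamp, so the difference
        # should be zero or slightly negative; this tolerance check replaces the old
        # exact comparison, which could fail when the clock ticked between the calls.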
    def test_type_conversion_functions(self):
        for f in (
            F.toUInt8,
            F.toUInt16,
            F.toUInt32,
            F.toUInt64,
            F.toInt8,
            F.toInt16,
            F.toInt32,
            F.toInt64,
            F.toFloat32,
            F.toFloat64,
        ):
            self._test_func(f(17), 17)
            self._test_func(f("17"), 17)
        for f in (
            F.toUInt8OrZero,
            F.toUInt16OrZero,
            F.toUInt32OrZero,
            F.toUInt64OrZero,
            F.toInt8OrZero,
            F.toInt16OrZero,
            F.toInt32OrZero,
            F.toInt64OrZero,
            F.toFloat32OrZero,
            F.toFloat64OrZero,
        ):
            self._test_func(f("17"), 17)
            self._test_func(f("a"), 0)
        for f in (F.toDecimal32, F.toDecimal64, F.toDecimal128):
            self._test_func(f(17.17, 2), Decimal("17.17"))
            self._test_func(f("17.17", 2), Decimal("17.17"))
        self._test_func(F.toDate("2018-12-31"), date(2018, 12, 31))
        self._test_func(F.toString(123), "123")
        self._test_func(F.toFixedString("123", 5), "123")
        self._test_func(F.toStringCutToZero("123\0"), "123")
        self._test_func(F.CAST(17, "String"), "17")
        self._test_func(F.parseDateTimeBestEffort("31/12/2019 10:05AM", "Europe/Athens"))
        with self.assertRaises(ServerError):
            self._test_func(F.parseDateTimeBestEffort("foo"))
        self._test_func(F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM", "Europe/Athens"))
        self._test_func(F.parseDateTimeBestEffortOrNull("foo"), None)
        self._test_func(F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM", "Europe/Athens"))
        self._test_func(F.parseDateTimeBestEffortOrZero("foo"), DateTimeField.class_default)
    def test_type_conversion_functions__utc_only(self):
        if self.database.server_timezone != pytz.utc:
            raise unittest.SkipTest("This test must run with UTC as the server timezone")
        self._test_func(
            F.toDateTime("2018-12-31 11:22:33"),
            datetime(2018, 12, 31, 11, 22, 33, tzinfo=pytz.utc),
        )
        self._test_func(
            F.toDateTime64("2018-12-31 11:22:33.001", 6),
            datetime(2018, 12, 31, 11, 22, 33, 1000, tzinfo=pytz.utc),
        )
        self._test_func(
            F.parseDateTimeBestEffort("31/12/2019 10:05AM"),
            datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
        )
        self._test_func(
            F.parseDateTimeBestEffortOrNull("31/12/2019 10:05AM"),
            datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
        )
        self._test_func(
            F.parseDateTimeBestEffortOrZero("31/12/2019 10:05AM"),
            datetime(2019, 12, 31, 10, 5, tzinfo=pytz.utc),
        )
    def test_string_functions(self):
        self._test_func(F.empty(""), 1)
        self._test_func(F.empty("x"), 0)
        self._test_func(F.notEmpty(""), 0)
        self._test_func(F.notEmpty("x"), 1)
        self._test_func(F.length("x"), 1)
        self._test_func(F.lengthUTF8("x"), 1)
        self._test_func(F.lower("Ab"), "ab")
        self._test_func(F.upper("Ab"), "AB")
        self._test_func(F.lowerUTF8("Ab"), "ab")
        self._test_func(F.upperUTF8("Ab"), "AB")
        self._test_func(F.reverse("Ab"), "bA")
        self._test_func(F.reverseUTF8("Ab"), "bA")
        self._test_func(F.concat("Ab", "Cd", "Ef"), "AbCdEf")
        self._test_func(F.substring("123456", 3, 2), "34")
        self._test_func(F.substringUTF8("123456", 3, 2), "34")
        self._test_func(F.appendTrailingCharIfAbsent("Hello", "!"), "Hello!")
        self._test_func(F.appendTrailingCharIfAbsent("Hello!", "!"), "Hello!")
        self._test_func(
            F.convertCharset(F.convertCharset("Hello", "latin1", "utf16"), "utf16", "latin1"),
            "Hello",
        )
        self._test_func(F.startsWith("aaa", "aa"), True)
        self._test_func(F.startsWith("aaa", "bb"), False)
        self._test_func(F.endsWith("aaa", "aa"), True)
        self._test_func(F.endsWith("aaa", "bb"), False)
        self._test_func(F.trimLeft(" abc "), "abc ")
        self._test_func(F.trimRight(" abc "), " abc")
        self._test_func(F.trimBoth(" abc "), "abc")
        self._test_func(F.CRC32("whoops"), 3361378926)
    def test_string_search_functions(self):
        self._test_func(F.position("Hello, world!", "!"), 13)
        self._test_func(F.positionCaseInsensitive("Hello, world!", "hello"), 1)
        self._test_func(F.positionUTF8("Привет, мир!", "!"), 12)
        self._test_func(F.positionCaseInsensitiveUTF8("Привет, мир!", "Мир"), 9)
        self._test_func(F.like("Hello, world!", "%ll%"), 1)
        self._test_func(F.notLike("Hello, world!", "%ll%"), 0)
        self._test_func(F.match("Hello, world!", "[lmnop]{3}"), 1)
        self._test_func(F.extract("Hello, world!", "[lmnop]{3}"), "llo")
        self._test_func(F.extractAll("Hello, world!", "[a-z]+"), ["ello", "world"])
        self._test_func(F.ngramDistance("Hello", "Hello"), 0)
        self._test_func(F.ngramDistanceCaseInsensitive("Hello", "hello"), 0)
        self._test_func(F.ngramDistanceUTF8("Hello", "Hello"), 0)
        self._test_func(F.ngramDistanceCaseInsensitiveUTF8("Hello", "hello"), 0)
        self._test_func(F.ngramSearch("Hello", "Hello"), 1)
        self._test_func(F.ngramSearchCaseInsensitive("Hello", "hello"), 1)
        self._test_func(F.ngramSearchUTF8("Hello", "Hello"), 1)
        self._test_func(F.ngramSearchCaseInsensitiveUTF8("Hello", "hello"), 1)
    def test_base64_functions(self):
        try:
            self._test_func(F.base64Decode(F.base64Encode("Hello")), "Hello")
            self._test_func(F.tryBase64Decode(F.base64Encode("Hello")), "Hello")
            self._test_func(F.tryBase64Decode(":-)"))
        except ServerError as e:
            # ClickHouse version that doesn't support these functions
            raise unittest.SkipTest(str(e))
    def test_replace_functions(self):
        haystack = "hello"
        self._test_func(F.replace(haystack, "l", "L"), "heLLo")
        self._test_func(F.replaceAll(haystack, "l", "L"), "heLLo")
        self._test_func(F.replaceOne(haystack, "l", "L"), "heLlo")
        self._test_func(F.replaceRegexpAll(haystack, "[eo]", "X"), "hXllX")
        self._test_func(F.replaceRegexpOne(haystack, "[eo]", "X"), "hXllo")
        self._test_func(F.regexpQuoteMeta("[eo]"), "\\[eo\\]")
    def test_math_functions(self):
        x = 17
@ -515,15 +589,15 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.arrayDifference(arr), [0, 1, 1])
        self._test_func(F.arrayDistinct(arr + arr), arr)
        self._test_func(F.arrayIntersect(arr, [3, 4]), [3])
        self._test_func(F.arrayReduce("min", arr), 1)
        self._test_func(F.arrayReverse(arr), [3, 2, 1])

    def test_split_and_merge_functions(self):
        self._test_func(F.splitByChar("_", "a_b_c"), ["a", "b", "c"])
        self._test_func(F.splitByString("__", "a__b__c"), ["a", "b", "c"])
        self._test_func(F.arrayStringConcat(["a", "b", "c"]), "abc")
        self._test_func(F.arrayStringConcat(["a", "b", "c"], "_"), "a_b_c")
        self._test_func(F.alphaTokens("aaa.bbb.111"), ["aaa", "bbb"])
    def test_bit_functions(self):
        x = 17
@ -546,23 +620,44 @@ class FuncsTestCase(TestCaseWithData):
    def test_bitmap_functions(self):
        self._test_func(F.bitmapToArray(F.bitmapBuild([1, 2, 3])), [1, 2, 3])
        self._test_func(F.bitmapContains(F.bitmapBuild([1, 5, 7, 9]), F.toUInt32(9)), 1)
        self._test_func(F.bitmapHasAny(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 1)
        self._test_func(F.bitmapHasAll(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 0)
        self._test_func(
            F.bitmapToArray(F.bitmapAnd(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
            [3],
        )
        self._test_func(
            F.bitmapToArray(F.bitmapOr(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
            [1, 2, 3, 4, 5],
        )
        self._test_func(
            F.bitmapToArray(F.bitmapXor(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
            [1, 2, 4, 5],
        )
        self._test_func(
            F.bitmapToArray(F.bitmapAndnot(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5]))),
            [1, 2],
        )
        self._test_func(F.bitmapCardinality(F.bitmapBuild([1, 2, 3, 4, 5])), 5)
        self._test_func(
            F.bitmapAndCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
            1,
        )
        self._test_func(F.bitmapOrCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])), 5)
        self._test_func(
            F.bitmapXorCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
            4,
        )
        self._test_func(
            F.bitmapAndnotCardinality(F.bitmapBuild([1, 2, 3]), F.bitmapBuild([3, 4, 5])),
            2,
        )
    def test_hash_functions(self):
        args = ["x", "y", "z"]
        x = 17
        s = "hello"
        url = "http://example.com/a/b/c/d"
        self._test_func(F.hex(F.MD5(s)))
        self._test_func(F.hex(F.sipHash128(s)))
        self._test_func(F.hex(F.cityHash64(*args)))
@ -594,17 +689,18 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.rand(17))
        self._test_func(F.rand64())
        self._test_func(F.rand64(17))
        if self.database.server_version >= (19, 15):  # buggy in older versions
            self._test_func(F.randConstant())
            self._test_func(F.randConstant(17))
    def test_encoding_functions(self):
        self._test_func(F.hex(F.unhex("0FA1")), "0FA1")
        self._test_func(F.bitmaskToArray(17))
        self._test_func(F.bitmaskToList(18))
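The hex/unhex round trip above can be mirrored with the standard library; a one-line cross-check, illustrative only (bytes.fromhex is case-insensitive, while hex() emits lowercase):

assert bytes.fromhex("0FA1").hex().upper() == "0FA1"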
    def test_uuid_functions(self):
        from uuid import UUID

        uuid = self._test_func(F.generateUUIDv4())
        self.assertEqual(type(uuid), UUID)
        s = str(uuid)
@ -612,17 +708,30 @@ class FuncsTestCase(TestCaseWithData):
        self._test_func(F.UUIDNumToString(F.UUIDStringToNum(s)), s)
    def test_ip_funcs(self):
        self._test_func(F.IPv4NumToString(F.toUInt32(1)), "0.0.0.1")
        self._test_func(F.IPv4NumToStringClassC(F.toUInt32(1)), "0.0.0.xxx")
        self._test_func(F.IPv4StringToNum("0.0.0.17"), 17)
        self._test_func(
            F.IPv6NumToString(F.IPv4ToIPv6(F.IPv4StringToNum("192.168.0.1"))),
            "::ffff:192.168.0.1",
        )
        self._test_func(F.IPv6NumToString(F.IPv6StringToNum("2a02:6b8::11")), "2a02:6b8::11")
        self._test_func(F.toIPv4("10.20.30.40"), IPv4Address("10.20.30.40"))
        self._test_func(
            F.toIPv6("2001:438:ffff::407d:1bc1"),
            IPv6Address("2001:438:ffff::407d:1bc1"),
        )
        self._test_func(
            F.IPv4CIDRToRange(F.toIPv4("192.168.5.2"), 16),
            [IPv4Address("192.168.0.0"), IPv4Address("192.168.255.255")],
        )
        self._test_func(
            F.IPv6CIDRToRange(F.toIPv6("2001:0db8:0000:85a3:0000:0000:ac1f:8001"), 32),
            [
                IPv6Address("2001:db8::"),
                IPv6Address("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
            ],
        )
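The two CIDR expectations above agree with Python's stdlib ipaddress module, which is also where the IPv4Address/IPv6Address values in these tests come from. A cross-check sketch (reference only, not part of the test suite):

import ipaddress

net4 = ipaddress.ip_network("192.168.5.2/16", strict=False)
assert net4.network_address == ipaddress.IPv4Address("192.168.0.0")
assert net4.broadcast_address == ipaddress.IPv4Address("192.168.255.255")

net6 = ipaddress.ip_network("2001:0db8:0000:85a3:0000:0000:ac1f:8001/32", strict=False)
assert str(net6.network_address) == "2001:db8::"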
    def test_aggregate_funcs(self):
        self._test_aggr(F.any(Person.first_name))
@ -630,7 +739,10 @@ class FuncsTestCase(TestCaseWithData):
        self._test_aggr(F.anyLast(Person.first_name))
        self._test_aggr(F.argMin(Person.first_name, Person.height))
        self._test_aggr(F.argMax(Person.first_name, Person.height))
        self._test_aggr(
            F.round(F.avg(Person.height), 4),
            sum(p.height for p in self._sample_data()) / 100,
        )
        self._test_aggr(F.corr(Person.height, Person.height), 1)
        self._test_aggr(F.count(), 100)
        self._test_aggr(F.round(F.covarPop(Person.height, Person.height), 2), 0)
@ -649,32 +761,32 @@ class FuncsTestCase(TestCaseWithData):
        self._test_aggr(F.varSamp(Person.height))

    def test_aggregate_funcs__or_default(self):
        self.database.raw("TRUNCATE TABLE person")
        self._test_aggr(F.countOrDefault(), 0)
        self._test_aggr(F.maxOrDefault(Person.height), 0)

    def test_aggregate_funcs__or_null(self):
        self.database.raw("TRUNCATE TABLE person")
        self._test_aggr(F.countOrNull(), None)
        self._test_aggr(F.maxOrNull(Person.height), None)

    def test_aggregate_funcs__if(self):
        self._test_aggr(F.argMinIf(Person.first_name, Person.height, Person.last_name > "H"))
        self._test_aggr(F.countIf(Person.last_name > "H"), 57)
        self._test_aggr(F.minIf(Person.height, Person.last_name > "H"), 1.6)

    def test_aggregate_funcs__or_default_if(self):
        self._test_aggr(F.argMinOrDefaultIf(Person.first_name, Person.height, Person.last_name > "Z"))
        self._test_aggr(F.countOrDefaultIf(Person.last_name > "Z"), 0)
        self._test_aggr(F.minOrDefaultIf(Person.height, Person.last_name > "Z"), 0)

    def test_aggregate_funcs__or_null_if(self):
        self._test_aggr(F.argMinOrNullIf(Person.first_name, Person.height, Person.last_name > "Z"))
        self._test_aggr(F.countOrNullIf(Person.last_name > "Z"), None)
        self._test_aggr(F.minOrNullIf(Person.height, Person.last_name > "Z"), None)
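The method names in these four tests are base aggregates plus ClickHouse combinator suffixes (-If, -OrDefault, -OrNull), read right to left: minOrNullIf is min + -OrNull + -If. Assuming the ORM's usual to_sql() rendering, where the > operator maps onto greater(), such an expression should compile roughly as follows (sketch; the exact output is not shown in this diff):

expr = F.minOrNullIf(Person.height, Person.last_name > "Z")
print(expr.to_sql())  # expected: minOrNullIf(height, greater(last_name, 'Z'))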
    def test_quantile_funcs(self):
        cond = Person.last_name > "H"
        weight_expr = F.toUInt32(F.round(Person.height))
        # Quantile
        self._test_aggr(F.quantile(0.9)(Person.height))
@ -712,13 +824,13 @@ class FuncsTestCase(TestCaseWithData):
    def test_top_k_funcs(self):
        self._test_aggr(F.topK(3)(Person.height))
        self._test_aggr(F.topKOrDefault(3)(Person.height))
        self._test_aggr(F.topKIf(3)(Person.height, Person.last_name > "H"))
        self._test_aggr(F.topKOrDefaultIf(3)(Person.height, Person.last_name > "H"))
        weight_expr = F.toUInt32(F.round(Person.height))
        self._test_aggr(F.topKWeighted(3)(Person.height, weight_expr))
        self._test_aggr(F.topKWeightedOrDefault(3)(Person.height, weight_expr))
        self._test_aggr(F.topKWeightedIf(3)(Person.height, weight_expr, Person.last_name > "H"))
        self._test_aggr(F.topKWeightedOrDefaultIf(3)(Person.height, weight_expr, Person.last_name > "H"))
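ClickHouse's topK(3) returns (approximately) the three most frequent values; for small in-memory data the exact equivalent is collections.Counter. Illustrative only, with made-up sample heights:

from collections import Counter

heights = [1.70, 1.70, 1.65, 1.80, 1.65, 1.70]
top3 = [v for v, _ in Counter(heights).most_common(3)]
assert top3 == [1.70, 1.65, 1.80]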
    def test_null_funcs(self):
        self._test_func(F.ifNull(17, 18), 17)
@ -1,14 +1,14 @@
import unittest

from clickhouse_orm import Database, F, Index, MergeTree, Model
from clickhouse_orm.fields import DateField, Int32Field, StringField


class IndexesTest(unittest.TestCase):
    def setUp(self):
        self.database = Database("test-db", log_statements=True)
        if self.database.server_version < (20, 1, 2, 4):
            raise unittest.SkipTest("ClickHouse version too old")

    def tearDown(self):
        self.database.drop_database()
@ -29,4 +29,4 @@ class ModelWithIndexes(Model):
    i4 = Index(F.lower(f2), type=Index.tokenbf_v1(256, 2, 0), granularity=2)
    i5 = Index((F.toQuarter(date), f2), type=Index.bloom_filter(), granularity=3)

    engine = MergeTree("date", ("date",))
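For reference, data-skipping index definitions such as i4 and i5 map onto standard ClickHouse DDL clauses. The exact SQL the ORM emits is not shown in this diff; the fragments below are hand-written illustrations of the expected shape:

expected_fragments = [
    "INDEX i4 lower(f2) TYPE tokenbf_v1(256, 2, 0) GRANULARITY 2",
    "INDEX i5 (toQuarter(date), f2) TYPE bloom_filter() GRANULARITY 3",
]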
Some files were not shown because too many files have changed in this diff.