from __future__ import unicode_literals

import datetime
import logging
import re
from collections import namedtuple
from math import ceil
from string import Template

import pytz
import requests
from six import PY3, string_types

from .models import ModelBase
from .utils import escape, parse_tsv, import_submodules

logger = logging.getLogger('clickhouse_orm')


Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')


class DatabaseException(Exception):
    '''
    Raised when a database operation fails.
    '''
    pass


class ServerError(DatabaseException):
    """
    Raised when a server returns an error.
    """

    def __init__(self, message):
        self.code = None
        processed = self.get_error_code_msg(message)
        if processed:
            self.code, self.message = processed
        else:
            # Fall back to the raw message if it does not match the standard format
            self.message = message

        super(ServerError, self).__init__(message)

    ERROR_PATTERN = re.compile(r'''
        Code:\ (?P<code>\d+),
        \ e\.displayText\(\)\ =\ (?P<type1>[^ \n]+):\ (?P<msg>.+?),
        \ e\.what\(\)\ =\ (?P<type2>[^ \n]+)
        ''', re.VERBOSE | re.DOTALL)

    @classmethod
    def get_error_code_msg(cls, full_error_message):
        """
        Extract the code and message of the exception that clickhouse-server generated.

        See the list of error codes here:
        https://github.com/yandex/ClickHouse/blob/master/dbms/src/Common/ErrorCodes.cpp
        """
        match = cls.ERROR_PATTERN.match(full_error_message)
        if match:
            # assert match.group('type1') == match.group('type2')
            return int(match.group('code')), match.group('msg')

        return 0, full_error_message

    def __str__(self):
        if self.code is not None:
            return "{} ({})".format(self.message, self.code)
        return self.message


class Database(object):
    '''
    Database instances connect to a specific ClickHouse database for running queries,
    inserting data and other operations.
    '''

    def __init__(self, db_name, db_url='http://localhost:8123/',
                 username=None, password=None, readonly=False, autocreate=True,
                 timeout=60, verify_ssl_cert=True):
        '''
        Initializes a database instance. Unless it's readonly, the database will be
        created on the ClickHouse server if it does not already exist.

        - `db_name`: name of the database to connect to.
        - `db_url`: URL of the ClickHouse server.
        - `username`: optional connection credentials.
        - `password`: optional connection credentials.
        - `readonly`: use a read-only connection.
        - `autocreate`: automatically create the database if it does not exist (unless in readonly mode).
        - `timeout`: the connection timeout in seconds.
        - `verify_ssl_cert`: whether to verify the server's certificate when connecting via HTTPS.
        '''
        self.db_name = db_name
        self.db_url = db_url
        self.username = username
        self.password = password
        self.readonly = False
        self.timeout = timeout
        self.request_session = requests.Session()
        self.request_session.verify = verify_ssl_cert
        self.settings = {}
        self.db_exists = False  # this is required before running _is_existing_database
        self.db_exists = self._is_existing_database()
        if readonly:
            if not self.db_exists:
                raise DatabaseException('Database does not exist, and cannot be created under readonly connection')
            self.connection_readonly = self._is_connection_readonly()
            self.readonly = True
        elif autocreate and not self.db_exists:
            self.create_database()
        self.server_version = self._get_server_version()
        # Versions 1.1.53981 and below don't have a timezone() function
        self.server_timezone = self._get_server_timezone() if self.server_version > (1, 1, 53981) else pytz.utc
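
    # Illustrative sketch (not part of the library): a typical connection.
    # The database name and credentials below are assumptions for the example.
    #
    #     db = Database('my_db', db_url='http://localhost:8123/',
    #                   username='default', password='')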

    def create_database(self):
        '''
        Creates the database on the ClickHouse server if it does not already exist.
        '''
        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
        self.db_exists = True

    def drop_database(self):
        '''
        Deletes the database on the ClickHouse server.
        '''
        self._send('DROP DATABASE `%s`' % self.db_name)
        self.db_exists = False

    def create_table(self, model_class):
        '''
        Creates a table for the given model class, if it does not already exist.
        '''
        if model_class.is_system_model():
            raise DatabaseException("You can't create a system table")
        if getattr(model_class, 'engine') is None:
            raise DatabaseException("%s class must define an engine" % model_class.__name__)
        self._send(model_class.create_table_sql(self))
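
    # Illustrative sketch (not part of the library): defining a model and
    # creating its table. The Person model and its fields are assumptions.
    #
    #     from infi.clickhouse_orm import models, fields, engines
    #
    #     class Person(models.Model):
    #         name = fields.StringField()
    #         birthday = fields.DateField()
    #         engine = engines.MergeTree('birthday', ('name',))
    #
    #     db.create_table(Person)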

    def drop_table(self, model_class):
        '''
        Drops the database table of the given model class, if it exists.
        '''
        if model_class.is_system_model():
            raise DatabaseException("You can't drop a system table")
        self._send(model_class.drop_table_sql(self))

    def does_table_exist(self, model_class):
        '''
        Checks whether a table for the given model class already exists.
        Note that this only checks for existence of a table with the expected name.
        '''
        sql = "SELECT count() FROM system.tables WHERE database = '%s' AND name = '%s'"
        r = self._send(sql % (self.db_name, model_class.table_name()))
        return r.text.strip() == '1'

    def add_setting(self, name, value):
        '''
        Adds a database setting that will be sent with every request.
        For example, `db.add_setting("max_execution_time", 10)` will
        limit query execution time to 10 seconds.
        The name must be a string, and the value is converted to a string
        if it isn't one already. To remove a setting, pass `None` as the value.
        '''
        assert isinstance(name, string_types), 'Setting name must be a string'
        if value is None:
            self.settings.pop(name, None)
        else:
            self.settings[name] = str(value)

    def insert(self, model_instances, batch_size=1000):
        '''
        Insert records into the database.

        - `model_instances`: any iterable containing instances of a single model class.
        - `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
        '''
        from six import next
        from io import BytesIO
        i = iter(model_instances)
        try:
            first_instance = next(i)
        except StopIteration:
            return  # model_instances is empty
        model_class = first_instance.__class__

        if first_instance.is_read_only() or first_instance.is_system_model():
            raise DatabaseException("You can't insert into read-only or system tables")

        fields_list = ','.join(
            ['`%s`' % name for name in first_instance.fields(writable=True)])

        def gen():
            buf = BytesIO()
            query = 'INSERT INTO $table (%s) FORMAT TabSeparated\n' % fields_list
            buf.write(self._substitute(query, model_class).encode('utf-8'))
            first_instance.set_database(self)
            buf.write(first_instance.to_tsv(include_readonly=False).encode('utf-8'))
            buf.write('\n'.encode('utf-8'))
            # Collect lines in batches of batch_size
            lines = 2
            for instance in i:
                instance.set_database(self)
                buf.write(instance.to_tsv(include_readonly=False).encode('utf-8'))
                buf.write('\n'.encode('utf-8'))
                lines += 1
                if lines >= batch_size:
                    # Return the current batch of lines
                    yield buf.getvalue()
                    # Start a new batch
                    buf = BytesIO()
                    lines = 0
            # Return any remaining lines in a partial batch
            if lines:
                yield buf.getvalue()
        self._send(gen())
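
    # Illustrative sketch (not part of the library): inserting records in bulk.
    # The Person model is the assumed example from above.
    #
    #     import datetime
    #     people = [Person(name='Eve', birthday=datetime.date(1990, 5, 17)),
    #               Person(name='Dan', birthday=datetime.date(1985, 2, 3))]
    #     db.insert(people, batch_size=1000)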

    def count(self, model_class, conditions=None):
        '''
        Counts the number of records in the model's table.

        - `model_class`: the model to count.
        - `conditions`: optional SQL conditions (contents of the WHERE clause).
        '''
        query = 'SELECT count() FROM $table'
        if conditions:
            query += ' WHERE ' + conditions
        query = self._substitute(query, model_class)
        r = self._send(query)
        return int(r.text) if r.text else 0

    def select(self, query, model_class=None, settings=None):
        '''
        Performs a query and returns a generator of model instances.

        - `query`: the SQL query to execute.
        - `model_class`: the model class matching the query's table,
          or `None` for getting back instances of an ad-hoc model.
        - `settings`: query settings to send as HTTP GET parameters.
        '''
        query += ' FORMAT TabSeparatedWithNamesAndTypes'
        query = self._substitute(query, model_class)
        r = self._send(query, settings, True)
        lines = r.iter_lines()
        field_names = parse_tsv(next(lines))
        field_types = parse_tsv(next(lines))
        model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
        for line in lines:
            # skip the blank line left by the WITH TOTALS modifier
            if line:
                yield model_class.from_tsv(line, field_names, self.server_timezone, self)
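
    # Illustrative sketch (not part of the library): selecting into model
    # instances, and an ad-hoc query without a model class. The default table
    # name 'person' for the assumed Person model is itself an assumption.
    #
    #     for person in db.select('SELECT * FROM $table', model_class=Person):
    #         print(person.name, person.birthday)
    #
    #     for row in db.select('SELECT name, toYear(birthday) AS year FROM $db.person'):
    #         print(row.name, row.year)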

    def raw(self, query, settings=None, stream=False):
        '''
        Performs a query and returns its output as text.

        - `query`: the SQL query to execute.
        - `settings`: query settings to send as HTTP GET parameters.
        - `stream`: if true, the HTTP response from ClickHouse will be streamed.
        '''
        query = self._substitute(query, None)
        return self._send(query, settings=settings, stream=stream).text

    def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
        '''
        Selects records and returns a single page of model instances.

        - `model_class`: the model class matching the query's table,
          or `None` for getting back instances of an ad-hoc model.
        - `order_by`: columns to use for sorting the query (contents of the ORDER BY clause).
        - `page_num`: the page number (1-based), or -1 to get the last page.
        - `page_size`: number of records to return per page.
        - `conditions`: optional SQL conditions (contents of the WHERE clause).
        - `settings`: query settings to send as HTTP GET parameters.

        The result is a namedtuple containing `objects` (list), `number_of_objects`,
        `pages_total`, `number` (of the current page), and `page_size`.
        '''
        count = self.count(model_class, conditions)
        pages_total = int(ceil(count / float(page_size)))
        if page_num == -1:
            page_num = pages_total
        elif page_num < 1:
            raise ValueError('Invalid page number: %d' % page_num)
        offset = (page_num - 1) * page_size
        query = 'SELECT * FROM $table'
        if conditions:
            query += ' WHERE ' + conditions
        query += ' ORDER BY %s' % order_by
        query += ' LIMIT %d, %d' % (offset, page_size)
        query = self._substitute(query, model_class)
        return Page(
            objects=list(self.select(query, model_class, settings)),
            number_of_objects=count,
            pages_total=pages_total,
            number=page_num,
            page_size=page_size
        )
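
    # Illustrative sketch (not part of the library): iterating over all pages,
    # using the Page namedtuple fields documented above.
    #
    #     page = db.paginate(Person, order_by='name', page_num=1, page_size=100)
    #     while True:
    #         for person in page.objects:
    #             print(person.name)
    #         if page.number >= page.pages_total:
    #             break
    #         page = db.paginate(Person, order_by='name',
    #                            page_num=page.number + 1, page_size=100)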

    def migrate(self, migrations_package_name, up_to=9999):
        '''
        Executes schema migrations.

        - `migrations_package_name` - fully qualified name of the Python package
          containing the migrations.
        - `up_to` - number of the last migration to apply.
        '''
        from .migrations import MigrationHistory
        logger = logging.getLogger('migrations')
        applied_migrations = self._get_applied_migrations(migrations_package_name)
        modules = import_submodules(migrations_package_name)
        unapplied_migrations = set(modules.keys()) - applied_migrations
        for name in sorted(unapplied_migrations):
            logger.info('Applying migration %s...', name)
            for operation in modules[name].operations:
                operation.apply(self)
            self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
            if int(name[:4]) >= up_to:
                break
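
    # Illustrative sketch (not part of the library): an assumed migrations
    # package layout and invocation. Module names must start with a four-digit
    # number, since `int(name[:4])` is compared against `up_to`.
    #
    #     # my_app/migrations/0001_initial.py:
    #     #     from infi.clickhouse_orm import migrations
    #     #     operations = [migrations.CreateTable(Person)]
    #
    #     db.migrate('my_app.migrations')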

    def _get_applied_migrations(self, migrations_package_name):
        from .migrations import MigrationHistory
        self.create_table(MigrationHistory)
        query = "SELECT module_name FROM $table WHERE package_name = '%s'" % migrations_package_name
        query = self._substitute(query, MigrationHistory)
        return set(obj.module_name for obj in self.select(query))

    def _send(self, data, settings=None, stream=False):
        if isinstance(data, string_types):
            data = data.encode('utf-8')
        params = self._build_params(settings)
        r = self.request_session.post(self.db_url, params=params, data=data, stream=stream, timeout=self.timeout)
        if r.status_code != 200:
            raise ServerError(r.text)
        return r

    def _build_params(self, settings):
        params = dict(settings or {})
        params.update(self.settings)
        if self.db_exists:
            params['database'] = self.db_name
        if self.username:
            params['user'] = self.username
        if self.password:
            params['password'] = self.password
        # Send the readonly flag, unless the connection is already readonly (to prevent a db error)
        if self.readonly and not self.connection_readonly:
            params['readonly'] = '1'
        return params

    def _substitute(self, query, model_class=None):
        '''
        Replaces $db and $table placeholders in the query.
        '''
        if '$' in query:
            mapping = dict(db="`%s`" % self.db_name)
            if model_class:
                mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
            query = Template(query).safe_substitute(mapping)
        return query
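
    # Illustrative sketch (not part of the library): what the substitution
    # produces, assuming db_name 'my_db' and the assumed Person model:
    #
    #     'SELECT * FROM $table'  ->  'SELECT * FROM `my_db`.`person`'
    #     'SHOW TABLES FROM $db'  ->  'SHOW TABLES FROM `my_db`'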

    def _get_server_timezone(self):
        try:
            r = self._send('SELECT timezone()')
            return pytz.timezone(r.text.strip())
        except ServerError as e:
            logger.exception('Cannot determine server timezone (%s), assuming UTC', e)
            return pytz.utc

    def _get_server_version(self, as_tuple=True):
        try:
            r = self._send('SELECT version();')
            ver = r.text
        except ServerError as e:
            logger.exception('Cannot determine server version (%s), assuming 1.1.0', e)
            ver = '1.1.0'
        return tuple(int(n) for n in ver.split('.')) if as_tuple else ver

    def _is_existing_database(self):
        r = self._send("SELECT count() FROM system.databases WHERE name = '%s'" % self.db_name)
        return r.text.strip() == '1'

    def _is_connection_readonly(self):
        r = self._send("SELECT value FROM system.settings WHERE name = 'readonly'")
        return r.text.strip() != '0'