Getting started

This commit is contained in:
Itai Shirav 2016-06-23 14:11:20 +03:00
parent 2f99d58a61
commit 4da45b0be5
5 changed files with 315 additions and 2 deletions

View File

@ -9,7 +9,11 @@ parts =
name = infi.clickhouse_utils name = infi.clickhouse_utils
company = Infinidat company = Infinidat
namespace_packages = ['infi'] namespace_packages = ['infi']
install_requires = ['setuptools'] install_requires = [
'pytz',
'requests',
'setuptools'
]
version_file = src/infi/clickhouse_utils/__version__.py version_file = src/infi/clickhouse_utils/__version__.py
description = A Python library for working with the ClickHouse database description = A Python library for working with the ClickHouse database
long_description = A Python library for working with the ClickHouse database long_description = A Python library for working with the ClickHouse database

View File

@ -0,0 +1,64 @@
class Engine(object):
    """Abstract base class for table engine definitions."""

    def create_table_sql(self):
        """Return the engine expression used in a CREATE TABLE statement.

        The base class has no engine of its own, so this always raises
        NotImplementedError; subclasses are expected to override it.
        """
        raise NotImplementedError
class MergeTree(Engine):
    """Engine definition rendered as ``MergeTree(...)`` or ``ReplicatedMergeTree(...)``.

    The rendered parameters are, in order: the replica table path and replica
    name (only when replication is configured), the date column, the optional
    sampling expression, the parenthesized key columns and the index
    granularity.
    """

    def __init__(self, date_col, key_cols, sampling_expr=None,
                 index_granularity=8192, replica_table_path=None, replica_name=None):
        """Store the engine settings.

        :param date_col: name of the date column.
        :param key_cols: iterable of column names / expressions forming the key.
        :param sampling_expr: optional sampling expression.
        :param index_granularity: index granularity, rendered via ``str()``.
        :param replica_table_path: replica table path (replicated tables only).
        :param replica_name: replica name (replicated tables only).
        :raises ValueError: if only one of ``replica_table_path`` and
            ``replica_name`` is given.
        """
        # The two replica settings only make sense together: a name without a
        # path would render the literal 'None' into the SQL, and a path
        # without a name would silently produce a non-replicated table.
        if bool(replica_table_path) != bool(replica_name):
            raise ValueError(
                'replica_table_path and replica_name must be specified together')
        self.date_col = date_col
        self.key_cols = key_cols
        self.sampling_expr = sampling_expr
        self.index_granularity = index_granularity
        self.replica_table_path = replica_table_path
        self.replica_name = replica_name

    def create_table_sql(self):
        """Return the engine expression, e.g. ``MergeTree(date, (a, b), 8192)``."""
        # The engine name is derived from the class name, so subclasses get
        # their own (and their own "Replicated" variant) for free.
        name = self.__class__.__name__
        if self.replica_name:
            name = 'Replicated' + name
        params = self._build_sql_params()
        return '%s(%s)' % (name, ', '.join(params))

    def _build_sql_params(self):
        """Return the engine parameters as a list of SQL fragments, in order."""
        params = []
        if self.replica_name:
            params += ["'%s'" % self.replica_table_path, "'%s'" % self.replica_name]
        params.append(self.date_col)
        if self.sampling_expr:
            params.append(self.sampling_expr)
        params.append('(%s)' % ', '.join(self.key_cols))
        params.append(str(self.index_granularity))
        return params
class CollapsingMergeTree(MergeTree):
    """MergeTree variant that takes an additional sign column.

    The sign column is rendered as the last engine parameter.
    """

    def __init__(self, date_col, key_cols, sign_col, sampling_expr=None,
                 index_granularity=8192, replica_table_path=None, replica_name=None):
        """Store the MergeTree settings plus the name of the sign column."""
        super(CollapsingMergeTree, self).__init__(
            date_col,
            key_cols,
            sampling_expr=sampling_expr,
            index_granularity=index_granularity,
            replica_table_path=replica_table_path,
            replica_name=replica_name,
        )
        self.sign_col = sign_col

    def _build_sql_params(self):
        """Return the parent's parameters followed by the sign column."""
        base_params = super(CollapsingMergeTree, self)._build_sql_params()
        return base_params + [self.sign_col]
class SummingMergeTree(MergeTree):
    """MergeTree variant that takes an optional list of summing columns.

    When summing columns are given they are rendered, parenthesized, as the
    last engine parameter; otherwise the parameters match plain MergeTree.
    """

    def __init__(self, date_col, key_cols, summing_cols=None, sampling_expr=None,
                 index_granularity=8192, replica_table_path=None, replica_name=None):
        """Store the MergeTree settings plus the optional summing columns."""
        super(SummingMergeTree, self).__init__(
            date_col,
            key_cols,
            sampling_expr=sampling_expr,
            index_granularity=index_granularity,
            replica_table_path=replica_table_path,
            replica_name=replica_name,
        )
        self.summing_cols = summing_cols

    def _build_sql_params(self):
        """Return the parent's parameters, plus the summing columns if any."""
        base_params = super(SummingMergeTree, self)._build_sql_params()
        if not self.summing_cols:
            return base_params
        return base_params + ['(%s)' % ', '.join(self.summing_cols)]

View File

@ -0,0 +1,155 @@
import calendar
import datetime
import numbers
import time

import pytz
class Field(object):
    """Base class for model fields.

    Each instance records the order in which it was created, so that the
    fields of a model can later be sorted back into declaration order.
    """

    # Global counter; incremented every time a field is instantiated.
    creation_counter = 0
    # Default used when the caller does not supply one.
    class_default = 0
    # Name of the database column type; set by concrete subclasses.
    db_type = None

    def __init__(self, default=None):
        """Create the field.

        :param default: default value for the field; ``None`` means "use the
            class-level default".
        """
        self.creation_counter = Field.creation_counter
        Field.creation_counter += 1
        # Compare against None explicitly: a plain ``default or class_default``
        # would discard legitimate falsy defaults such as '' or 0.0.
        self.default = self.class_default if default is None else default

    def to_python(self, value):
        """
        Converts the input value into the expected Python data type, raising ValueError if the
        data can't be converted. Returns the converted value. Subclasses should override this.
        """
        return value

    def get_db_prep_value(self, value):
        """
        Returns the field's value prepared for interacting with the database.
        """
        return value
class StringField(Field):
    """Text field stored in a ``String`` column.

    Values are ``unicode`` on the Python side and UTF-8 encoded byte strings
    on the database side.
    """

    class_default = ''
    db_type = 'String'

    def to_python(self, value):
        """Return ``value`` as unicode, decoding byte strings as UTF-8.

        :raises ValueError: if ``value`` is neither ``unicode`` nor ``str``.
        """
        if isinstance(value, unicode):
            return value
        if isinstance(value, str):
            return value.decode('UTF-8')
        # Format the message explicitly; passing the arguments separately
        # (logging-style) leaves the placeholders unexpanded.
        raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))

    def get_db_prep_value(self, value):
        """Return ``value`` UTF-8 encoded if it is unicode, unchanged otherwise."""
        if isinstance(value, unicode):
            return value.encode('UTF-8')
        return value
class DateField(Field):
    """Date field stored in a ``Date`` column."""

    class_default = datetime.date(1970, 1, 1)
    db_type = 'Date'

    def to_python(self, value):
        """Convert ``value`` to a ``datetime.date``.

        Accepts a ``datetime.datetime`` (its date part is used), a
        ``datetime.date``, an integer number of days since 1970-01-01, or a
        ``YYYY-MM-DD`` string.

        :raises ValueError: if ``value`` has an unsupported type, or is a
            string that does not match the expected format.
        """
        # datetime.datetime is a subclass of datetime.date, so it has to be
        # handled first; otherwise the time part would leak into isoformat().
        if isinstance(value, datetime.datetime):
            return value.date()
        if isinstance(value, datetime.date):
            return value
        if isinstance(value, int):
            return DateField.class_default + datetime.timedelta(days=value)
        if isinstance(value, basestring):
            return datetime.datetime.strptime(value, '%Y-%m-%d').date()
        raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))

    def get_db_prep_value(self, value):
        """Return the date in ISO format (``YYYY-MM-DD``)."""
        return value.isoformat()
class DateTimeField(Field):
    """Timestamp field stored in a ``DateTime`` column.

    The database-side representation is an integer number of seconds since
    the epoch.
    """

    class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
    db_type = 'DateTime'

    def to_python(self, value):
        """Convert ``value`` to a ``datetime.datetime``.

        Accepts a ``datetime.datetime``, a ``datetime.date`` (midnight is
        assumed), an integer epoch timestamp (converted to an aware UTC
        datetime), or a ``YYYY-MM-DD HH:MM:SS`` string (parsed as a naive
        datetime).

        :raises ValueError: if ``value`` has an unsupported type, or is a
            string that does not match the expected format.
        """
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
        if isinstance(value, int):
            return datetime.datetime.fromtimestamp(value, pytz.utc)
        if isinstance(value, basestring):
            # Time components are separated by colons, not dashes.
            return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
        raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))

    def get_db_prep_value(self, value):
        """Return ``value`` as an integer epoch timestamp.

        Timezone-aware values are converted via UTC. Naive values are
        interpreted in the local timezone.
        """
        if value.utcoffset() is not None:
            # time.mktime() would treat the fields as local time and shift
            # aware values (including the UTC class default) by the local
            # offset.
            return calendar.timegm(value.utctimetuple())
        return int(time.mktime(value.timetuple()))
class BaseIntField(Field):
    """Common behavior for the integer field types."""

    def to_python(self, value):
        """Convert ``value`` to an integer.

        Integral values are returned unchanged; strings are parsed with
        ``int()``.

        :raises ValueError: if ``value`` has an unsupported type, or is a
            string that is not a valid integer literal.
        """
        # numbers.Integral also covers Python 2 ``long``, which is what
        # int() returns for strings too large for a machine word.
        if isinstance(value, numbers.Integral):
            return value
        if isinstance(value, basestring):
            return int(value)
        raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
class UInt8Field(BaseIntField):
    """Integer field stored in a ``UInt8`` column."""

    db_type = 'UInt8'
class UInt16Field(BaseIntField):
    """Integer field stored in a ``UInt16`` column."""

    db_type = 'UInt16'
class UInt32Field(BaseIntField):
    """Integer field stored in a ``UInt32`` column."""

    db_type = 'UInt32'
class UInt64Field(BaseIntField):
    """Integer field stored in a ``UInt64`` column."""

    db_type = 'UInt64'
class Int8Field(BaseIntField):
    """Integer field stored in an ``Int8`` column."""

    db_type = 'Int8'
class Int16Field(BaseIntField):
    """Integer field stored in an ``Int16`` column."""

    db_type = 'Int16'
class Int32Field(BaseIntField):
    """Integer field stored in an ``Int32`` column."""

    db_type = 'Int32'
class Int64Field(BaseIntField):
    """Integer field stored in an ``Int64`` column."""

    db_type = 'Int64'
class BaseFloatField(Field):
    """Common behavior for the floating point field types."""

    def to_python(self, value):
        """Convert ``value`` to a ``float``.

        Accepts any real number (floats and integers) or a string that
        ``float()`` can parse.

        :raises ValueError: if ``value`` has an unsupported type, or is a
            string that is not a valid float literal.
        """
        # Integers must be accepted too: the default inherited from Field is
        # the integer 0, and rejecting it would make a float column left at
        # its default impossible to serialize.
        if isinstance(value, (numbers.Real, basestring)):
            return float(value)
        raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
class Float32Field(BaseFloatField):
    """Floating point field stored in a ``Float32`` column."""

    db_type = 'Float32'
class Float64Field(BaseFloatField):
    """Floating point field stored in a ``Float64`` column."""

    db_type = 'Float64'

View File

@ -0,0 +1,62 @@
from fields import *
from utils import escape, parse_tsv
from engines import *
class ModelBase(type):
    """Metaclass that collects the Field attributes declared on a class.

    The collected ``(name, field)`` pairs are stored on the new class as
    ``_fields``, ordered by each field's creation counter, i.e. in the order
    the fields were instantiated.
    """

    def __new__(cls, name, bases, attrs):
        model_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
        # Only attributes defined directly in the class body are inspected.
        declared = []
        for attr_name, attr_value in attrs.items():
            if isinstance(attr_value, Field):
                declared.append((attr_name, attr_value))
        model_cls._fields = sorted(declared, key=lambda pair: pair[1].creation_counter)
        return model_cls
class Model(object):
    """Base class for table models.

    Subclasses declare their columns as Field attributes and set ``engine``
    to an engine instance. The metaclass gathers the declared fields into
    ``_fields`` as ordered ``(name, field)`` pairs.
    """

    __metaclass__ = ModelBase

    # Engine instance used by create_table_sql(); must be set by subclasses.
    engine = None

    def __init__(self, *args, **kwargs):
        """Create an instance, assigning each field from ``kwargs``.

        Fields missing from ``kwargs`` get the field's default. Positional
        arguments and keyword arguments that do not match a field name are
        ignored.
        """
        super(Model, self).__init__()
        for name, field in self._fields:
            val = kwargs.get(name, field.default)
            setattr(self, name, val)

    @classmethod
    def table_name(cls):
        """Return the table name: the lowercased class name."""
        return cls.__name__.lower()

    @classmethod
    def create_table_sql(cls, db):
        """Return the CREATE TABLE statement for this model.

        :param db: name of the database the table belongs to.
        """
        columns = []
        for name, field in cls._fields:
            default = field.get_db_prep_value(field.default)
            columns.append(' %s %s DEFAULT %s' % (name, field.db_type, escape(default)))
        parts = [
            'CREATE TABLE IF NOT EXISTS %s.%s (' % (db, cls.table_name()),
            # Join with separators rather than terminating every column with
            # a comma, so no trailing comma precedes the closing parenthesis.
            ',\n'.join(columns),
            ')',
            'ENGINE = ' + cls.engine.create_table_sql(),
        ]
        return '\n'.join(parts)

    @classmethod
    def from_tsv(cls, line):
        '''
        Create a model instance from a tab-separated line. The line may or may not include a newline.
        '''
        values = iter(parse_tsv(line))
        kwargs = {}
        for name, field in cls._fields:
            # Values are consumed positionally, in field declaration order.
            kwargs[name] = field.to_python(next(values))
        return cls(**kwargs)

    def to_tsv(self):
        '''
        Returns the instance's column values as a tab-separated line. A newline is not included.
        '''
        parts = []
        for name, field in self._fields:
            # Normalize through to_python() first so that attributes assigned
            # in a loose form are converted before serialization.
            value = field.get_db_prep_value(field.to_python(getattr(self, name)))
            parts.append(escape(value, quote=False))
        return '\t'.join(parts)

View File

@ -0,0 +1,28 @@
# Characters that escape() replaces, mapped to their backslash escape sequences.
SPECIAL_CHARS = {
    "\b": "\\b",
    "\f": "\\f",
    "\r": "\\r",
    "\n": "\\n",
    "\t": "\\t",
    "\0": "\\0",
    "\\": "\\\\",
    "'": "\\'",
}
def escape(value, quote=True):
    """Return ``value`` as text with special characters backslash-escaped.

    Strings have every character found in SPECIAL_CHARS replaced by its
    escape sequence and, when ``quote`` is true, are wrapped in single quotes.
    Any other value is simply converted with ``str()``; ``quote`` has no
    effect on it.
    """
    if not isinstance(value, basestring):
        return str(value)
    escaped = "".join([SPECIAL_CHARS.get(ch, ch) for ch in value])
    if quote:
        return "'" + escaped + "'"
    return escaped
def unescape(value):
    """Return ``value`` decoded with the ``string_escape`` codec."""
    decoded = value.decode('string_escape')
    return decoded
def parse_tsv(line):
    """Split a tab-separated line into a list of unescaped values.

    A single trailing newline, if present, is removed first. An empty line
    yields a list holding one empty string.
    """
    # endswith() is safe on an empty string, unlike indexing line[-1].
    if line.endswith('\n'):
        line = line[:-1]
    return [unescape(value) for value in line.split('\t')]