mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-26 02:33:45 +03:00
Add Python 3 support
This commit is contained in:
parent
290f8da9ce
commit
071bcd283d
|
@ -12,7 +12,8 @@ namespace_packages = ['infi']
|
||||||
install_requires = [
|
install_requires = [
|
||||||
'pytz',
|
'pytz',
|
||||||
'requests',
|
'requests',
|
||||||
'setuptools'
|
'setuptools',
|
||||||
|
'six'
|
||||||
]
|
]
|
||||||
version_file = src/infi/clickhouse_orm/__version__.py
|
version_file = src/infi/clickhouse_orm/__version__.py
|
||||||
description = A Python library for working with the ClickHouse database
|
description = A Python library for working with the ClickHouse database
|
||||||
|
|
3
setup.in
3
setup.in
|
@ -16,7 +16,10 @@ SETUP_INFO = dict(
|
||||||
"License :: OSI Approved :: Python Software Foundation License",
|
"License :: OSI Approved :: Python Software Foundation License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python",
|
"Programming Language :: Python",
|
||||||
|
"Programming Language :: Python :: 2.7",
|
||||||
|
"Programming Language :: Python :: 3.4",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
"Topic :: Database"
|
||||||
],
|
],
|
||||||
|
|
||||||
install_requires = ${project:install_requires},
|
install_requires = ${project:install_requires},
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import requests
|
import requests
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from models import ModelBase
|
from .models import ModelBase
|
||||||
from utils import escape, parse_tsv, import_submodules
|
from .utils import escape, parse_tsv, import_submodules
|
||||||
from math import ceil
|
from math import ceil
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from string import Template
|
from string import Template
|
||||||
|
from six import PY3, string_types
|
||||||
|
|
||||||
|
|
||||||
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
||||||
|
@ -35,19 +36,20 @@ class Database(object):
|
||||||
self._send('DROP DATABASE `%s`' % self.db_name)
|
self._send('DROP DATABASE `%s`' % self.db_name)
|
||||||
|
|
||||||
def insert(self, model_instances):
|
def insert(self, model_instances):
|
||||||
|
from six import next
|
||||||
i = iter(model_instances)
|
i = iter(model_instances)
|
||||||
try:
|
try:
|
||||||
first_instance = i.next()
|
first_instance = next(i)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return # model_instances is empty
|
return # model_instances is empty
|
||||||
model_class = first_instance.__class__
|
model_class = first_instance.__class__
|
||||||
def gen():
|
def gen():
|
||||||
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class)
|
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
|
||||||
yield first_instance.to_tsv()
|
yield first_instance.to_tsv().encode('utf-8')
|
||||||
yield '\n'
|
yield '\n'.encode('utf-8')
|
||||||
for instance in i:
|
for instance in i:
|
||||||
yield instance.to_tsv()
|
yield instance.to_tsv().encode('utf-8')
|
||||||
yield '\n'
|
yield '\n'.encode('utf-8')
|
||||||
self._send(gen())
|
self._send(gen())
|
||||||
|
|
||||||
def count(self, model_class, conditions=None):
|
def count(self, model_class, conditions=None):
|
||||||
|
@ -88,7 +90,7 @@ class Database(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
def migrate(self, migrations_package_name, up_to=9999):
|
def migrate(self, migrations_package_name, up_to=9999):
|
||||||
from migrations import MigrationHistory
|
from .migrations import MigrationHistory
|
||||||
logger = logging.getLogger('migrations')
|
logger = logging.getLogger('migrations')
|
||||||
applied_migrations = self._get_applied_migrations(migrations_package_name)
|
applied_migrations = self._get_applied_migrations(migrations_package_name)
|
||||||
modules = import_submodules(migrations_package_name)
|
modules = import_submodules(migrations_package_name)
|
||||||
|
@ -102,13 +104,15 @@ class Database(object):
|
||||||
break
|
break
|
||||||
|
|
||||||
def _get_applied_migrations(self, migrations_package_name):
|
def _get_applied_migrations(self, migrations_package_name):
|
||||||
from migrations import MigrationHistory
|
from .migrations import MigrationHistory
|
||||||
self.create_table(MigrationHistory)
|
self.create_table(MigrationHistory)
|
||||||
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
|
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
|
||||||
query = self._substitute(query, MigrationHistory)
|
query = self._substitute(query, MigrationHistory)
|
||||||
return set(obj.module_name for obj in self.select(query))
|
return set(obj.module_name for obj in self.select(query))
|
||||||
|
|
||||||
def _send(self, data, settings=None, stream=False):
|
def _send(self, data, settings=None, stream=False):
|
||||||
|
if PY3 and isinstance(data, string_types):
|
||||||
|
data = data.encode('utf-8')
|
||||||
params = self._build_params(settings)
|
params = self._build_params(settings)
|
||||||
r = requests.post(self.db_url, params=params, data=data, stream=stream)
|
r = requests.post(self.db_url, params=params, data=data, stream=stream)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
|
@ -118,7 +122,7 @@ class Database(object):
|
||||||
def _build_params(self, settings):
|
def _build_params(self, settings):
|
||||||
params = dict(settings or {})
|
params = dict(settings or {})
|
||||||
if self.username:
|
if self.username:
|
||||||
params['username'] = self.username
|
params['user'] = self.username
|
||||||
if self.password:
|
if self.password:
|
||||||
params['password'] = self.password
|
params['password'] = self.password
|
||||||
return params
|
return params
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from six import string_types, text_type, binary_type
|
||||||
import datetime
|
import datetime
|
||||||
import pytz
|
import pytz
|
||||||
import time
|
import time
|
||||||
|
@ -48,17 +49,12 @@ class StringField(Field):
|
||||||
db_type = 'String'
|
db_type = 'String'
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
if isinstance(value, unicode):
|
if isinstance(value, text_type):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, str):
|
if isinstance(value, binary_type):
|
||||||
return value.decode('UTF-8')
|
return value.decode('UTF-8')
|
||||||
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
|
||||||
if isinstance(value, unicode):
|
|
||||||
return value.encode('UTF-8')
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class DateField(Field):
|
class DateField(Field):
|
||||||
|
|
||||||
|
@ -72,7 +68,7 @@ class DateField(Field):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return DateField.class_default + datetime.timedelta(days=value)
|
return DateField.class_default + datetime.timedelta(days=value)
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, string_types):
|
||||||
if value == '0000-00-00':
|
if value == '0000-00-00':
|
||||||
return DateField.min_value
|
return DateField.min_value
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
||||||
|
@ -97,7 +93,7 @@ class DateTimeField(Field):
|
||||||
return datetime.datetime(value.year, value.month, value.day)
|
return datetime.datetime(value.year, value.month, value.day)
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, string_types):
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from models import Model
|
from .models import Model
|
||||||
from fields import DateField, StringField
|
from .fields import DateField, StringField
|
||||||
from engines import MergeTree
|
from .engines import MergeTree
|
||||||
from utils import escape
|
from .utils import escape
|
||||||
|
|
||||||
from itertools import izip
|
from six.moves import zip
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger('migrations')
|
logger = logging.getLogger('migrations')
|
||||||
|
@ -74,7 +74,7 @@ class AlterTable(Operation):
|
||||||
prev_name = name
|
prev_name = name
|
||||||
# Identify fields whose type was changed
|
# Identify fields whose type was changed
|
||||||
model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
|
model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
|
||||||
for model_field, table_field in izip(model_fields, self._get_table_fields(database)):
|
for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
|
||||||
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
|
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
|
||||||
if model_field[1] != table_field[1]:
|
if model_field[1] != table_field[1]:
|
||||||
logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1])
|
logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1])
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from utils import escape, parse_tsv
|
from .utils import escape, parse_tsv
|
||||||
from engines import *
|
from .engines import *
|
||||||
from fields import Field
|
from .fields import Field
|
||||||
|
|
||||||
|
from six import with_metaclass
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(type):
|
class ModelBase(type):
|
||||||
|
@ -26,9 +28,10 @@ class ModelBase(type):
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_ad_hoc_model(cls, fields):
|
def create_ad_hoc_model(cls, fields):
|
||||||
# fields is a list of tuples (name, db_type)
|
# fields is a list of tuples (name, db_type)
|
||||||
import fields as orm_fields
|
import infi.clickhouse_orm.fields as orm_fields
|
||||||
# Check if model exists in cache
|
# Check if model exists in cache
|
||||||
cache_key = unicode(fields)
|
fields = list(fields)
|
||||||
|
cache_key = str(fields)
|
||||||
if cache_key in cls.ad_hoc_model_cache:
|
if cache_key in cls.ad_hoc_model_cache:
|
||||||
return cls.ad_hoc_model_cache[cache_key]
|
return cls.ad_hoc_model_cache[cache_key]
|
||||||
# Create an ad hoc model class
|
# Create an ad hoc model class
|
||||||
|
@ -44,12 +47,11 @@ class ModelBase(type):
|
||||||
return model_class
|
return model_class
|
||||||
|
|
||||||
|
|
||||||
class Model(object):
|
class Model(with_metaclass(ModelBase)):
|
||||||
'''
|
'''
|
||||||
A base class for ORM models.
|
A base class for ORM models.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
__metaclass__ = ModelBase
|
|
||||||
engine = None
|
engine = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
@ -61,7 +63,7 @@ class Model(object):
|
||||||
'''
|
'''
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
# Assign field values from keyword arguments
|
# Assign field values from keyword arguments
|
||||||
for name, value in kwargs.iteritems():
|
for name, value in kwargs.items():
|
||||||
field = self.get_field(name)
|
field = self.get_field(name)
|
||||||
if field:
|
if field:
|
||||||
setattr(self, name, value)
|
setattr(self, name, value)
|
||||||
|
@ -126,11 +128,12 @@ class Model(object):
|
||||||
The field_names list must match the fields defined in the model, but does not have to include all of them.
|
The field_names list must match the fields defined in the model, but does not have to include all of them.
|
||||||
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
|
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
|
||||||
'''
|
'''
|
||||||
|
from six import next
|
||||||
field_names = field_names or [name for name, field in cls._fields]
|
field_names = field_names or [name for name, field in cls._fields]
|
||||||
values = iter(parse_tsv(line))
|
values = iter(parse_tsv(line))
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for name in field_names:
|
for name in field_names:
|
||||||
kwargs[name] = values.next()
|
kwargs[name] = next(values)
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
def to_tsv(self):
|
def to_tsv(self):
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
from six import string_types, binary_type, text_type, PY3
|
||||||
|
import codecs
|
||||||
|
|
||||||
|
|
||||||
SPECIAL_CHARS = {
|
SPECIAL_CHARS = {
|
||||||
"\b" : "\\b",
|
"\b" : "\\b",
|
||||||
|
@ -12,17 +15,19 @@ SPECIAL_CHARS = {
|
||||||
|
|
||||||
|
|
||||||
def escape(value, quote=True):
|
def escape(value, quote=True):
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, string_types):
|
||||||
chars = (SPECIAL_CHARS.get(c, c) for c in value)
|
chars = (SPECIAL_CHARS.get(c, c) for c in value)
|
||||||
return "'" + "".join(chars) + "'" if quote else "".join(chars)
|
value = "'" + "".join(chars) + "'" if quote else "".join(chars)
|
||||||
return str(value)
|
return text_type(value)
|
||||||
|
|
||||||
|
|
||||||
def unescape(value):
|
def unescape(value):
|
||||||
return value.decode('string_escape')
|
return codecs.escape_decode(value)[0].decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
def parse_tsv(line):
|
def parse_tsv(line):
|
||||||
|
if PY3 and isinstance(line, binary_type):
|
||||||
|
line = line.decode()
|
||||||
if line[-1] == '\n':
|
if line[-1] == '\n':
|
||||||
line = line[:-1]
|
line = line[:-1]
|
||||||
return [unescape(value) for value in line.split('\t')]
|
return [unescape(value) for value in line.split('\t')]
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from infi.clickhouse_orm.database import Database
|
from infi.clickhouse_orm.database import Database
|
||||||
|
@ -91,6 +93,13 @@ class DatabaseTestCase(unittest.TestCase):
|
||||||
# Verify that all instances were returned
|
# Verify that all instances were returned
|
||||||
self.assertEquals(len(instances), len(data))
|
self.assertEquals(len(instances), len(data))
|
||||||
|
|
||||||
|
def test_special_chars(self):
|
||||||
|
s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
|
||||||
|
p = Person(first_name=s)
|
||||||
|
self.database.insert([p])
|
||||||
|
p = list(self.database.select("SELECT * from $table", Person))[0]
|
||||||
|
self.assertEquals(p.first_name, s)
|
||||||
|
|
||||||
def _sample_data(self):
|
def _sample_data(self):
|
||||||
for entry in data:
|
for entry in data:
|
||||||
yield Person(**entry)
|
yield Person(**entry)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user