Add Python 3 support

This commit is contained in:
Itai Shirav 2016-08-01 10:28:10 +03:00
parent 290f8da9ce
commit 071bcd283d
8 changed files with 62 additions and 41 deletions

View File

@ -3,7 +3,7 @@ prefer-final = false
newest = false newest = false
download-cache = .cache download-cache = .cache
develop = . develop = .
parts = parts =
[project] [project]
name = infi.clickhouse_orm name = infi.clickhouse_orm
@ -12,7 +12,8 @@ namespace_packages = ['infi']
install_requires = [ install_requires = [
'pytz', 'pytz',
'requests', 'requests',
'setuptools' 'setuptools',
'six'
] ]
version_file = src/infi/clickhouse_orm/__version__.py version_file = src/infi/clickhouse_orm/__version__.py
description = A Python library for working with the ClickHouse database description = A Python library for working with the ClickHouse database

View File

@ -16,7 +16,10 @@ SETUP_INFO = dict(
"License :: OSI Approved :: Python Software Foundation License", "License :: OSI Approved :: Python Software Foundation License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.4",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Database"
], ],
install_requires = ${project:install_requires}, install_requires = ${project:install_requires},

View File

@ -1,11 +1,12 @@
import requests import requests
from collections import namedtuple from collections import namedtuple
from models import ModelBase from .models import ModelBase
from utils import escape, parse_tsv, import_submodules from .utils import escape, parse_tsv, import_submodules
from math import ceil from math import ceil
import datetime import datetime
import logging import logging
from string import Template from string import Template
from six import PY3, string_types
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
@ -35,19 +36,20 @@ class Database(object):
self._send('DROP DATABASE `%s`' % self.db_name) self._send('DROP DATABASE `%s`' % self.db_name)
def insert(self, model_instances): def insert(self, model_instances):
from six import next
i = iter(model_instances) i = iter(model_instances)
try: try:
first_instance = i.next() first_instance = next(i)
except StopIteration: except StopIteration:
return # model_instances is empty return # model_instances is empty
model_class = first_instance.__class__ model_class = first_instance.__class__
def gen(): def gen():
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class) yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
yield first_instance.to_tsv() yield first_instance.to_tsv().encode('utf-8')
yield '\n' yield '\n'.encode('utf-8')
for instance in i: for instance in i:
yield instance.to_tsv() yield instance.to_tsv().encode('utf-8')
yield '\n' yield '\n'.encode('utf-8')
self._send(gen()) self._send(gen())
def count(self, model_class, conditions=None): def count(self, model_class, conditions=None):
@ -88,7 +90,7 @@ class Database(object):
) )
def migrate(self, migrations_package_name, up_to=9999): def migrate(self, migrations_package_name, up_to=9999):
from migrations import MigrationHistory from .migrations import MigrationHistory
logger = logging.getLogger('migrations') logger = logging.getLogger('migrations')
applied_migrations = self._get_applied_migrations(migrations_package_name) applied_migrations = self._get_applied_migrations(migrations_package_name)
modules = import_submodules(migrations_package_name) modules = import_submodules(migrations_package_name)
@ -102,13 +104,15 @@ class Database(object):
break break
def _get_applied_migrations(self, migrations_package_name): def _get_applied_migrations(self, migrations_package_name):
from migrations import MigrationHistory from .migrations import MigrationHistory
self.create_table(MigrationHistory) self.create_table(MigrationHistory)
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
query = self._substitute(query, MigrationHistory) query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query)) return set(obj.module_name for obj in self.select(query))
def _send(self, data, settings=None, stream=False): def _send(self, data, settings=None, stream=False):
if PY3 and isinstance(data, string_types):
data = data.encode('utf-8')
params = self._build_params(settings) params = self._build_params(settings)
r = requests.post(self.db_url, params=params, data=data, stream=stream) r = requests.post(self.db_url, params=params, data=data, stream=stream)
if r.status_code != 200: if r.status_code != 200:
@ -118,7 +122,7 @@ class Database(object):
def _build_params(self, settings): def _build_params(self, settings):
params = dict(settings or {}) params = dict(settings or {})
if self.username: if self.username:
params['username'] = self.username params['user'] = self.username
if self.password: if self.password:
params['password'] = self.password params['password'] = self.password
return params return params

View File

@ -1,3 +1,4 @@
from six import string_types, text_type, binary_type
import datetime import datetime
import pytz import pytz
import time import time
@ -48,17 +49,12 @@ class StringField(Field):
db_type = 'String' db_type = 'String'
def to_python(self, value): def to_python(self, value):
if isinstance(value, unicode): if isinstance(value, text_type):
return value return value
if isinstance(value, str): if isinstance(value, binary_type):
return value.decode('UTF-8') return value.decode('UTF-8')
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value)) raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
def get_db_prep_value(self, value):
if isinstance(value, unicode):
return value.encode('UTF-8')
return value
class DateField(Field): class DateField(Field):
@ -72,7 +68,7 @@ class DateField(Field):
return value return value
if isinstance(value, int): if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value) return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, basestring): if isinstance(value, string_types):
if value == '0000-00-00': if value == '0000-00-00':
return DateField.min_value return DateField.min_value
return datetime.datetime.strptime(value, '%Y-%m-%d').date() return datetime.datetime.strptime(value, '%Y-%m-%d').date()
@ -97,7 +93,7 @@ class DateTimeField(Field):
return datetime.datetime(value.year, value.month, value.day) return datetime.datetime(value.year, value.month, value.day)
if isinstance(value, int): if isinstance(value, int):
return datetime.datetime.fromtimestamp(value, pytz.utc) return datetime.datetime.fromtimestamp(value, pytz.utc)
if isinstance(value, basestring): if isinstance(value, string_types):
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))

View File

@ -1,9 +1,9 @@
from models import Model from .models import Model
from fields import DateField, StringField from .fields import DateField, StringField
from engines import MergeTree from .engines import MergeTree
from utils import escape from .utils import escape
from itertools import izip from six.moves import zip
import logging import logging
logger = logging.getLogger('migrations') logger = logging.getLogger('migrations')
@ -74,7 +74,7 @@ class AlterTable(Operation):
prev_name = name prev_name = name
# Identify fields whose type was changed # Identify fields whose type was changed
model_fields = [(name, field.db_type) for name, field in self.model_class._fields] model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
for model_field, table_field in izip(model_fields, self._get_table_fields(database)): for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement' assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
if model_field[1] != table_field[1]: if model_field[1] != table_field[1]:
logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1]) logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1])

View File

@ -1,6 +1,8 @@
from utils import escape, parse_tsv from .utils import escape, parse_tsv
from engines import * from .engines import *
from fields import Field from .fields import Field
from six import with_metaclass
class ModelBase(type): class ModelBase(type):
@ -26,9 +28,10 @@ class ModelBase(type):
@classmethod @classmethod
def create_ad_hoc_model(cls, fields): def create_ad_hoc_model(cls, fields):
# fields is a list of tuples (name, db_type) # fields is a list of tuples (name, db_type)
import fields as orm_fields import infi.clickhouse_orm.fields as orm_fields
# Check if model exists in cache # Check if model exists in cache
cache_key = unicode(fields) fields = list(fields)
cache_key = str(fields)
if cache_key in cls.ad_hoc_model_cache: if cache_key in cls.ad_hoc_model_cache:
return cls.ad_hoc_model_cache[cache_key] return cls.ad_hoc_model_cache[cache_key]
# Create an ad hoc model class # Create an ad hoc model class
@ -44,12 +47,11 @@ class ModelBase(type):
return model_class return model_class
class Model(object): class Model(with_metaclass(ModelBase)):
''' '''
A base class for ORM models. A base class for ORM models.
''' '''
__metaclass__ = ModelBase
engine = None engine = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -61,7 +63,7 @@ class Model(object):
''' '''
super(Model, self).__init__() super(Model, self).__init__()
# Assign field values from keyword arguments # Assign field values from keyword arguments
for name, value in kwargs.iteritems(): for name, value in kwargs.items():
field = self.get_field(name) field = self.get_field(name)
if field: if field:
setattr(self, name, value) setattr(self, name, value)
@ -126,11 +128,12 @@ class Model(object):
The field_names list must match the fields defined in the model, but does not have to include all of them. The field_names list must match the fields defined in the model, but does not have to include all of them.
If omitted, it is assumed to be the names of all fields in the model, in order of definition. If omitted, it is assumed to be the names of all fields in the model, in order of definition.
''' '''
from six import next
field_names = field_names or [name for name, field in cls._fields] field_names = field_names or [name for name, field in cls._fields]
values = iter(parse_tsv(line)) values = iter(parse_tsv(line))
kwargs = {} kwargs = {}
for name in field_names: for name in field_names:
kwargs[name] = values.next() kwargs[name] = next(values)
return cls(**kwargs) return cls(**kwargs)
def to_tsv(self): def to_tsv(self):

View File

@ -1,3 +1,6 @@
from six import string_types, binary_type, text_type, PY3
import codecs
SPECIAL_CHARS = { SPECIAL_CHARS = {
"\b" : "\\b", "\b" : "\\b",
@ -12,17 +15,19 @@ SPECIAL_CHARS = {
def escape(value, quote=True): def escape(value, quote=True):
if isinstance(value, basestring): if isinstance(value, string_types):
chars = (SPECIAL_CHARS.get(c, c) for c in value) chars = (SPECIAL_CHARS.get(c, c) for c in value)
return "'" + "".join(chars) + "'" if quote else "".join(chars) value = "'" + "".join(chars) + "'" if quote else "".join(chars)
return str(value) return text_type(value)
def unescape(value): def unescape(value):
return value.decode('string_escape') return codecs.escape_decode(value)[0].decode('utf-8')
def parse_tsv(line): def parse_tsv(line):
if PY3 and isinstance(line, binary_type):
line = line.decode()
if line[-1] == '\n': if line[-1] == '\n':
line = line[:-1] line = line[:-1]
return [unescape(value) for value in line.split('\t')] return [unescape(value) for value in line.split('\t')]

View File

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
import unittest import unittest
from infi.clickhouse_orm.database import Database from infi.clickhouse_orm.database import Database
@ -91,6 +93,13 @@ class DatabaseTestCase(unittest.TestCase):
# Verify that all instances were returned # Verify that all instances were returned
self.assertEquals(len(instances), len(data)) self.assertEquals(len(instances), len(data))
def test_special_chars(self):
s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
p = Person(first_name=s)
self.database.insert([p])
p = list(self.database.select("SELECT * from $table", Person))[0]
self.assertEquals(p.first_name, s)
def _sample_data(self): def _sample_data(self):
for entry in data: for entry in data:
yield Person(**entry) yield Person(**entry)