From 071bcd283dbc8a1fd10e154f32a5edbe39b5faed Mon Sep 17 00:00:00 2001 From: Itai Shirav Date: Mon, 1 Aug 2016 10:28:10 +0300 Subject: [PATCH] Add Python 3 support --- buildout.cfg | 5 +++-- setup.in | 3 +++ src/infi/clickhouse_orm/database.py | 26 +++++++++++++++----------- src/infi/clickhouse_orm/fields.py | 14 +++++--------- src/infi/clickhouse_orm/migrations.py | 12 ++++++------ src/infi/clickhouse_orm/models.py | 21 ++++++++++++--------- src/infi/clickhouse_orm/utils.py | 13 +++++++++---- tests/test_database.py | 9 +++++++++ 8 files changed, 62 insertions(+), 41 deletions(-) diff --git a/buildout.cfg b/buildout.cfg index 567d999..ddc5939 100644 --- a/buildout.cfg +++ b/buildout.cfg @@ -3,7 +3,7 @@ prefer-final = false newest = false download-cache = .cache develop = . -parts = +parts = [project] name = infi.clickhouse_orm @@ -12,7 +12,8 @@ namespace_packages = ['infi'] install_requires = [ 'pytz', 'requests', - 'setuptools' + 'setuptools', + 'six' ] version_file = src/infi/clickhouse_orm/__version__.py description = A Python library for working with the ClickHouse database diff --git a/setup.in b/setup.in index deb9512..9b65942 100644 --- a/setup.in +++ b/setup.in @@ -16,7 +16,10 @@ SETUP_INFO = dict( "License :: OSI Approved :: Python Software Foundation License", "Operating System :: OS Independent", "Programming Language :: Python", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3.4", "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Database" ], install_requires = ${project:install_requires}, diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py index 913e54c..23f115b 100644 --- a/src/infi/clickhouse_orm/database.py +++ b/src/infi/clickhouse_orm/database.py @@ -1,11 +1,12 @@ import requests from collections import namedtuple -from models import ModelBase -from utils import escape, parse_tsv, import_submodules +from .models import ModelBase +from .utils import escape, parse_tsv, import_submodules from math import ceil import datetime import logging from string import Template +from six import PY3, string_types Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') @@ -35,19 +36,20 @@ class Database(object): self._send('DROP DATABASE `%s`' % self.db_name) def insert(self, model_instances): + from six import next i = iter(model_instances) try: - first_instance = i.next() + first_instance = next(i) except StopIteration: return # model_instances is empty model_class = first_instance.__class__ def gen(): - yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class) - yield first_instance.to_tsv() - yield '\n' + yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8') + yield first_instance.to_tsv().encode('utf-8') + yield '\n'.encode('utf-8') for instance in i: - yield instance.to_tsv() - yield '\n' + yield instance.to_tsv().encode('utf-8') + yield '\n'.encode('utf-8') self._send(gen()) def count(self, model_class, conditions=None): @@ -88,7 +90,7 @@ class Database(object): ) def migrate(self, migrations_package_name, up_to=9999): - from migrations import MigrationHistory + from .migrations import MigrationHistory logger = logging.getLogger('migrations') applied_migrations = self._get_applied_migrations(migrations_package_name) modules = import_submodules(migrations_package_name) @@ -102,13 +104,15 @@ class Database(object): break def _get_applied_migrations(self, migrations_package_name): - from migrations import MigrationHistory + from .migrations import MigrationHistory self.create_table(MigrationHistory) query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name query = self._substitute(query, MigrationHistory) return set(obj.module_name for obj in self.select(query)) def _send(self, data, settings=None, stream=False): + if PY3 and isinstance(data, string_types): + data = data.encode('utf-8') params = self._build_params(settings) r = requests.post(self.db_url, params=params, data=data, stream=stream) if r.status_code != 200: @@ -118,7 +122,7 @@ class Database(object): def _build_params(self, settings): params = dict(settings or {}) if self.username: - params['username'] = self.username + params['user'] = self.username if self.password: params['password'] = self.password return params diff --git a/src/infi/clickhouse_orm/fields.py b/src/infi/clickhouse_orm/fields.py index 44e833a..390c303 100644 --- a/src/infi/clickhouse_orm/fields.py +++ b/src/infi/clickhouse_orm/fields.py @@ -1,3 +1,4 @@ +from six import string_types, text_type, binary_type import datetime import pytz import time @@ -48,17 +49,12 @@ class StringField(Field): db_type = 'String' def to_python(self, value): - if isinstance(value, unicode): + if isinstance(value, text_type): return value - if isinstance(value, str): + if isinstance(value, binary_type): return value.decode('UTF-8') raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value)) - def get_db_prep_value(self, value): - if isinstance(value, unicode): - return value.encode('UTF-8') - return value - class DateField(Field): @@ -72,7 +68,7 @@ class DateField(Field): return value if isinstance(value, int): return DateField.class_default + datetime.timedelta(days=value) - if isinstance(value, basestring): + if isinstance(value, string_types): if value == '0000-00-00': return DateField.min_value return datetime.datetime.strptime(value, '%Y-%m-%d').date() @@ -97,7 +93,7 @@ class DateTimeField(Field): return datetime.datetime(value.year, value.month, value.day) if isinstance(value, int): return datetime.datetime.fromtimestamp(value, pytz.utc) - if isinstance(value, basestring): + if isinstance(value, string_types): return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) diff --git a/src/infi/clickhouse_orm/migrations.py b/src/infi/clickhouse_orm/migrations.py index 770a36e..203943f 100644 --- a/src/infi/clickhouse_orm/migrations.py +++ b/src/infi/clickhouse_orm/migrations.py @@ -1,9 +1,9 @@ -from models import Model -from fields import DateField, StringField -from engines import MergeTree -from utils import escape +from .models import Model +from .fields import DateField, StringField +from .engines import MergeTree +from .utils import escape -from itertools import izip +from six.moves import zip import logging logger = logging.getLogger('migrations') @@ -74,7 +74,7 @@ class AlterTable(Operation): prev_name = name # Identify fields whose type was changed model_fields = [(name, field.db_type) for name, field in self.model_class._fields] - for model_field, table_field in izip(model_fields, self._get_table_fields(database)): + for model_field, table_field in zip(model_fields, self._get_table_fields(database)): assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement' if model_field[1] != table_field[1]: logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1]) diff --git a/src/infi/clickhouse_orm/models.py b/src/infi/clickhouse_orm/models.py index de15590..c6a0ff8 100644 --- a/src/infi/clickhouse_orm/models.py +++ b/src/infi/clickhouse_orm/models.py @@ -1,6 +1,8 @@ -from utils import escape, parse_tsv -from engines import * -from fields import Field +from .utils import escape, parse_tsv +from .engines import * +from .fields import Field + +from six import with_metaclass class ModelBase(type): @@ -26,9 +28,10 @@ class ModelBase(type): @classmethod def create_ad_hoc_model(cls, fields): # fields is a list of tuples (name, db_type) - import fields as orm_fields + import infi.clickhouse_orm.fields as orm_fields # Check if model exists in cache - cache_key = unicode(fields) + fields = list(fields) + cache_key = str(fields) if cache_key in cls.ad_hoc_model_cache: return cls.ad_hoc_model_cache[cache_key] # Create an ad hoc model class @@ -44,12 +47,11 @@ class ModelBase(type): return model_class -class Model(object): +class Model(with_metaclass(ModelBase)): ''' A base class for ORM models. ''' - __metaclass__ = ModelBase engine = None def __init__(self, **kwargs): @@ -61,7 +63,7 @@ class Model(object): ''' super(Model, self).__init__() # Assign field values from keyword arguments - for name, value in kwargs.iteritems(): + for name, value in kwargs.items(): field = self.get_field(name) if field: setattr(self, name, value) @@ -126,11 +128,12 @@ class Model(object): The field_names list must match the fields defined in the model, but does not have to include all of them. If omitted, it is assumed to be the names of all fields in the model, in order of definition. ''' + from six import next field_names = field_names or [name for name, field in cls._fields] values = iter(parse_tsv(line)) kwargs = {} for name in field_names: - kwargs[name] = values.next() + kwargs[name] = next(values) return cls(**kwargs) def to_tsv(self): diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index 0ac8ca4..f52a7ae 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -1,3 +1,6 @@ +from six import string_types, binary_type, text_type, PY3 +import codecs + SPECIAL_CHARS = { "\b" : "\\b", @@ -12,17 +15,19 @@ SPECIAL_CHARS = { def escape(value, quote=True): - if isinstance(value, basestring): + if isinstance(value, string_types): chars = (SPECIAL_CHARS.get(c, c) for c in value) - return "'" + "".join(chars) + "'" if quote else "".join(chars) - return str(value) + value = "'" + "".join(chars) + "'" if quote else "".join(chars) + return text_type(value) def unescape(value): - return value.decode('string_escape') + return codecs.escape_decode(value)[0].decode('utf-8') def parse_tsv(line): + if PY3 and isinstance(line, binary_type): + line = line.decode() if line[-1] == '\n': line = line[:-1] return [unescape(value) for value in line.split('\t')] diff --git a/tests/test_database.py b/tests/test_database.py index 6950753..ce6f671 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import unittest from infi.clickhouse_orm.database import Database @@ -91,6 +93,13 @@ class DatabaseTestCase(unittest.TestCase): # Verify that all instances were returned self.assertEquals(len(instances), len(data)) + def test_special_chars(self): + s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r' + p = Person(first_name=s) + self.database.insert([p]) + p = list(self.database.select("SELECT * from $table", Person))[0] + self.assertEquals(p.first_name, s) + def _sample_data(self): for entry in data: yield Person(**entry)