mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-25 10:13:45 +03:00
Add Python 3 support
This commit is contained in:
parent
290f8da9ce
commit
071bcd283d
|
@ -12,7 +12,8 @@ namespace_packages = ['infi']
|
|||
install_requires = [
|
||||
'pytz',
|
||||
'requests',
|
||||
'setuptools'
|
||||
'setuptools',
|
||||
'six'
|
||||
]
|
||||
version_file = src/infi/clickhouse_orm/__version__.py
|
||||
description = A Python library for working with the ClickHouse database
|
||||
|
|
3
setup.in
3
setup.in
|
@ -16,7 +16,10 @@ SETUP_INFO = dict(
|
|||
"License :: OSI Approved :: Python Software Foundation License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 2.7",
|
||||
"Programming Language :: Python :: 3.4",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Database"
|
||||
],
|
||||
|
||||
install_requires = ${project:install_requires},
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import requests
|
||||
from collections import namedtuple
|
||||
from models import ModelBase
|
||||
from utils import escape, parse_tsv, import_submodules
|
||||
from .models import ModelBase
|
||||
from .utils import escape, parse_tsv, import_submodules
|
||||
from math import ceil
|
||||
import datetime
|
||||
import logging
|
||||
from string import Template
|
||||
from six import PY3, string_types
|
||||
|
||||
|
||||
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
||||
|
@ -35,19 +36,20 @@ class Database(object):
|
|||
self._send('DROP DATABASE `%s`' % self.db_name)
|
||||
|
||||
def insert(self, model_instances):
|
||||
from six import next
|
||||
i = iter(model_instances)
|
||||
try:
|
||||
first_instance = i.next()
|
||||
first_instance = next(i)
|
||||
except StopIteration:
|
||||
return # model_instances is empty
|
||||
model_class = first_instance.__class__
|
||||
def gen():
|
||||
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class)
|
||||
yield first_instance.to_tsv()
|
||||
yield '\n'
|
||||
yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
|
||||
yield first_instance.to_tsv().encode('utf-8')
|
||||
yield '\n'.encode('utf-8')
|
||||
for instance in i:
|
||||
yield instance.to_tsv()
|
||||
yield '\n'
|
||||
yield instance.to_tsv().encode('utf-8')
|
||||
yield '\n'.encode('utf-8')
|
||||
self._send(gen())
|
||||
|
||||
def count(self, model_class, conditions=None):
|
||||
|
@ -88,7 +90,7 @@ class Database(object):
|
|||
)
|
||||
|
||||
def migrate(self, migrations_package_name, up_to=9999):
|
||||
from migrations import MigrationHistory
|
||||
from .migrations import MigrationHistory
|
||||
logger = logging.getLogger('migrations')
|
||||
applied_migrations = self._get_applied_migrations(migrations_package_name)
|
||||
modules = import_submodules(migrations_package_name)
|
||||
|
@ -102,13 +104,15 @@ class Database(object):
|
|||
break
|
||||
|
||||
def _get_applied_migrations(self, migrations_package_name):
|
||||
from migrations import MigrationHistory
|
||||
from .migrations import MigrationHistory
|
||||
self.create_table(MigrationHistory)
|
||||
query = "SELECT module_name from $table WHERE package_name = '%s'" % migrations_package_name
|
||||
query = self._substitute(query, MigrationHistory)
|
||||
return set(obj.module_name for obj in self.select(query))
|
||||
|
||||
def _send(self, data, settings=None, stream=False):
|
||||
if PY3 and isinstance(data, string_types):
|
||||
data = data.encode('utf-8')
|
||||
params = self._build_params(settings)
|
||||
r = requests.post(self.db_url, params=params, data=data, stream=stream)
|
||||
if r.status_code != 200:
|
||||
|
@ -118,7 +122,7 @@ class Database(object):
|
|||
def _build_params(self, settings):
|
||||
params = dict(settings or {})
|
||||
if self.username:
|
||||
params['username'] = self.username
|
||||
params['user'] = self.username
|
||||
if self.password:
|
||||
params['password'] = self.password
|
||||
return params
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from six import string_types, text_type, binary_type
|
||||
import datetime
|
||||
import pytz
|
||||
import time
|
||||
|
@ -48,17 +49,12 @@ class StringField(Field):
|
|||
db_type = 'String'
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, unicode):
|
||||
if isinstance(value, text_type):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
if isinstance(value, binary_type):
|
||||
return value.decode('UTF-8')
|
||||
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||
|
||||
def get_db_prep_value(self, value):
|
||||
if isinstance(value, unicode):
|
||||
return value.encode('UTF-8')
|
||||
return value
|
||||
|
||||
|
||||
class DateField(Field):
|
||||
|
||||
|
@ -72,7 +68,7 @@ class DateField(Field):
|
|||
return value
|
||||
if isinstance(value, int):
|
||||
return DateField.class_default + datetime.timedelta(days=value)
|
||||
if isinstance(value, basestring):
|
||||
if isinstance(value, string_types):
|
||||
if value == '0000-00-00':
|
||||
return DateField.min_value
|
||||
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
||||
|
@ -97,7 +93,7 @@ class DateTimeField(Field):
|
|||
return datetime.datetime(value.year, value.month, value.day)
|
||||
if isinstance(value, int):
|
||||
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
||||
if isinstance(value, basestring):
|
||||
if isinstance(value, string_types):
|
||||
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from models import Model
|
||||
from fields import DateField, StringField
|
||||
from engines import MergeTree
|
||||
from utils import escape
|
||||
from .models import Model
|
||||
from .fields import DateField, StringField
|
||||
from .engines import MergeTree
|
||||
from .utils import escape
|
||||
|
||||
from itertools import izip
|
||||
from six.moves import zip
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger('migrations')
|
||||
|
@ -74,7 +74,7 @@ class AlterTable(Operation):
|
|||
prev_name = name
|
||||
# Identify fields whose type was changed
|
||||
model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
|
||||
for model_field, table_field in izip(model_fields, self._get_table_fields(database)):
|
||||
for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
|
||||
assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
|
||||
if model_field[1] != table_field[1]:
|
||||
logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1])
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from utils import escape, parse_tsv
|
||||
from engines import *
|
||||
from fields import Field
|
||||
from .utils import escape, parse_tsv
|
||||
from .engines import *
|
||||
from .fields import Field
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
|
||||
class ModelBase(type):
|
||||
|
@ -26,9 +28,10 @@ class ModelBase(type):
|
|||
@classmethod
|
||||
def create_ad_hoc_model(cls, fields):
|
||||
# fields is a list of tuples (name, db_type)
|
||||
import fields as orm_fields
|
||||
import infi.clickhouse_orm.fields as orm_fields
|
||||
# Check if model exists in cache
|
||||
cache_key = unicode(fields)
|
||||
fields = list(fields)
|
||||
cache_key = str(fields)
|
||||
if cache_key in cls.ad_hoc_model_cache:
|
||||
return cls.ad_hoc_model_cache[cache_key]
|
||||
# Create an ad hoc model class
|
||||
|
@ -44,12 +47,11 @@ class ModelBase(type):
|
|||
return model_class
|
||||
|
||||
|
||||
class Model(object):
|
||||
class Model(with_metaclass(ModelBase)):
|
||||
'''
|
||||
A base class for ORM models.
|
||||
'''
|
||||
|
||||
__metaclass__ = ModelBase
|
||||
engine = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -61,7 +63,7 @@ class Model(object):
|
|||
'''
|
||||
super(Model, self).__init__()
|
||||
# Assign field values from keyword arguments
|
||||
for name, value in kwargs.iteritems():
|
||||
for name, value in kwargs.items():
|
||||
field = self.get_field(name)
|
||||
if field:
|
||||
setattr(self, name, value)
|
||||
|
@ -126,11 +128,12 @@ class Model(object):
|
|||
The field_names list must match the fields defined in the model, but does not have to include all of them.
|
||||
If omitted, it is assumed to be the names of all fields in the model, in order of definition.
|
||||
'''
|
||||
from six import next
|
||||
field_names = field_names or [name for name, field in cls._fields]
|
||||
values = iter(parse_tsv(line))
|
||||
kwargs = {}
|
||||
for name in field_names:
|
||||
kwargs[name] = values.next()
|
||||
kwargs[name] = next(values)
|
||||
return cls(**kwargs)
|
||||
|
||||
def to_tsv(self):
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from six import string_types, binary_type, text_type, PY3
|
||||
import codecs
|
||||
|
||||
|
||||
SPECIAL_CHARS = {
|
||||
"\b" : "\\b",
|
||||
|
@ -12,17 +15,19 @@ SPECIAL_CHARS = {
|
|||
|
||||
|
||||
def escape(value, quote=True):
|
||||
if isinstance(value, basestring):
|
||||
if isinstance(value, string_types):
|
||||
chars = (SPECIAL_CHARS.get(c, c) for c in value)
|
||||
return "'" + "".join(chars) + "'" if quote else "".join(chars)
|
||||
return str(value)
|
||||
value = "'" + "".join(chars) + "'" if quote else "".join(chars)
|
||||
return text_type(value)
|
||||
|
||||
|
||||
def unescape(value):
|
||||
return value.decode('string_escape')
|
||||
return codecs.escape_decode(value)[0].decode('utf-8')
|
||||
|
||||
|
||||
def parse_tsv(line):
|
||||
if PY3 and isinstance(line, binary_type):
|
||||
line = line.decode()
|
||||
if line[-1] == '\n':
|
||||
line = line[:-1]
|
||||
return [unescape(value) for value in line.split('\t')]
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
||||
from infi.clickhouse_orm.database import Database
|
||||
|
@ -91,6 +93,13 @@ class DatabaseTestCase(unittest.TestCase):
|
|||
# Verify that all instances were returned
|
||||
self.assertEquals(len(instances), len(data))
|
||||
|
||||
def test_special_chars(self):
|
||||
s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
|
||||
p = Person(first_name=s)
|
||||
self.database.insert([p])
|
||||
p = list(self.database.select("SELECT * from $table", Person))[0]
|
||||
self.assertEquals(p.first_name, s)
|
||||
|
||||
def _sample_data(self):
|
||||
for entry in data:
|
||||
yield Person(**entry)
|
||||
|
|
Loading…
Reference in New Issue
Block a user