Remove usage of six

This commit is contained in:
Itai Shirav 2019-12-15 19:14:16 +02:00
parent 39f34b7c85
commit ef30f1d1bd
11 changed files with 52 additions and 68 deletions

View File

@ -14,8 +14,7 @@ install_requires = [
'iso8601 >= 0.1.12',
'pytz',
'requests',
'setuptools',
'six'
'setuptools'
]
version_file = src/infi/clickhouse_orm/__version__.py
description = A Python library for working with the ClickHouse database

View File

@ -8,7 +8,6 @@ from .utils import escape, parse_tsv, import_submodules
from math import ceil
import datetime
from string import Template
from six import PY3, string_types
import pytz
import logging
@ -174,7 +173,7 @@ class Database(object):
The name must be string, and the value is converted to string in case
it isn't. To remove a setting, pass `None` as the value.
'''
assert isinstance(name, string_types), 'Setting name must be a string'
assert isinstance(name, str), 'Setting name must be a string'
if value is None:
self.settings.pop(name, None)
else:
@ -187,7 +186,6 @@ class Database(object):
- `model_instances`: any iterable containing instances of a single model class.
- `batch_size`: number of records to send per chunk (use a lower number if your records are very large).
'''
from six import next
from io import BytesIO
i = iter(model_instances)
try:
@ -338,7 +336,7 @@ class Database(object):
return set(obj.module_name for obj in self.select(query))
def _send(self, data, settings=None, stream=False):
if isinstance(data, string_types):
if isinstance(data, str):
data = data.encode('utf-8')
if self.log_statements:
logger.info(data)

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
import logging
import six
from .utils import comma_join
@ -37,7 +36,7 @@ class MergeTree(Engine):
def __init__(self, date_col=None, order_by=(), sampling_expr=None,
index_granularity=8192, replica_table_path=None, replica_name=None, partition_key=None):
assert type(order_by) in (list, tuple), 'order_by must be a list or tuple'
assert date_col is None or isinstance(date_col, six.string_types), 'date_col must be string if present'
assert date_col is None or isinstance(date_col, str), 'date_col must be string if present'
assert partition_key is None or type(partition_key) in (list, tuple),\
'partition_key must be tuple or list if present'
assert (replica_table_path is None) == (replica_name is None), \
@ -198,7 +197,7 @@ class Merge(Engine):
"""
def __init__(self, table_regex):
assert isinstance(table_regex, six.string_types), "'table_regex' parameter must be string"
assert isinstance(table_regex, str), "'table_regex' parameter must be string"
self.table_regex = table_regex
def create_table_sql(self, db):

View File

@ -1,5 +1,4 @@
from __future__ import unicode_literals
from six import string_types, text_type, binary_type, integer_types
import datetime
import iso8601
import pytz
@ -25,14 +24,14 @@ class Field(FunctionOperatorsMixin):
db_type = None # should be overridden by concrete subclasses
def __init__(self, default=None, alias=None, materialized=None, readonly=None, codec=None):
assert (None, None) in {(default, alias), (alias, materialized), (default, materialized)}, \
assert [default, alias, materialized].count(None) >= 2, \
"Only one of default, alias and materialized parameters can be given"
assert alias is None or isinstance(alias, F) or isinstance(alias, string_types) and alias != "",\
assert alias is None or isinstance(alias, F) or isinstance(alias, str) and alias != "",\
"Alias parameter must be a string or function object, if given"
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, string_types) and materialized != "",\
assert materialized is None or isinstance(materialized, F) or isinstance(materialized, str) and materialized != "",\
"Materialized parameter must be a string or function object, if given"
assert readonly is None or type(readonly) is bool, "readonly parameter must be bool if given"
assert codec is None or isinstance(codec, string_types) and codec != "", \
assert codec is None or isinstance(codec, str) and codec != "", \
"Codec field must be string, if given"
self.creation_counter = Field.creation_counter
@ -140,9 +139,9 @@ class StringField(Field):
db_type = 'String'
def to_python(self, value, timezone_in_use):
if isinstance(value, text_type):
if isinstance(value, str):
return value
if isinstance(value, binary_type):
if isinstance(value, bytes):
return value.decode('UTF-8')
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
@ -159,7 +158,7 @@ class FixedStringField(StringField):
return value.rstrip('\0')
def validate(self, value):
if isinstance(value, text_type):
if isinstance(value, str):
value = value.encode('UTF-8')
if len(value) > self._length:
raise ValueError('Value of %d bytes is too long for FixedStringField(%d)' % (len(value), self._length))
@ -179,7 +178,7 @@ class DateField(Field):
return value
if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, string_types):
if isinstance(value, str):
if value == '0000-00-00':
return DateField.min_value
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
@ -204,7 +203,7 @@ class DateTimeField(Field):
return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
if isinstance(value, int):
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, string_types):
if isinstance(value, str):
if value == '0000-00-00 00:00:00':
return self.class_default
if len(value) == 10:
@ -217,7 +216,7 @@ class DateTimeField(Field):
# left the date naive in case of no tzinfo set
dt = iso8601.parse_date(value, default_timezone=None)
except iso8601.ParseError as e:
raise ValueError(text_type(e))
raise ValueError(str(e))
# convert naive to aware
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
@ -242,7 +241,7 @@ class BaseIntField(Field):
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return text_type(value)
return str(value)
def validate(self, value):
self._range_check(value, self.min_value, self.max_value)
@ -318,7 +317,7 @@ class BaseFloatField(Field):
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return text_type(value)
return str(value)
class Float32Field(BaseFloatField):
@ -362,7 +361,7 @@ class DecimalField(Field):
def to_db_string(self, value, quote=True):
# There's no need to call escape since numbers do not contain
# special characters, and never need quoting
return text_type(value)
return str(value)
def _round(self, value):
return value.quantize(self.exp)
@ -407,9 +406,9 @@ class BaseEnumField(Field):
if isinstance(value, self.enum_cls):
return value
try:
if isinstance(value, text_type):
if isinstance(value, str):
return self.enum_cls[value]
if isinstance(value, binary_type):
if isinstance(value, bytes):
return self.enum_cls[value.decode('UTF-8')]
if isinstance(value, int):
return self.enum_cls(value)
@ -467,9 +466,9 @@ class ArrayField(Field):
super(ArrayField, self).__init__(default, alias, materialized, readonly, codec)
def to_python(self, value, timezone_in_use):
if isinstance(value, text_type):
if isinstance(value, str):
value = parse_array(value)
elif isinstance(value, binary_type):
elif isinstance(value, bytes):
value = parse_array(value.decode('UTF-8'))
elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
@ -498,11 +497,11 @@ class UUIDField(Field):
def to_python(self, value, timezone_in_use):
if isinstance(value, UUID):
return value
elif isinstance(value, binary_type):
elif isinstance(value, bytes):
return UUID(bytes=value)
elif isinstance(value, string_types):
elif isinstance(value, str):
return UUID(value)
elif isinstance(value, integer_types):
elif isinstance(value, int):
return UUID(int=value)
elif isinstance(value, tuple):
return UUID(fields=value)
@ -521,7 +520,7 @@ class IPv4Field(Field):
def to_python(self, value, timezone_in_use):
if isinstance(value, IPv4Address):
return value
elif isinstance(value, (binary_type,) + string_types + integer_types):
elif isinstance(value, (bytes, str, int)):
return IPv4Address(value)
else:
raise ValueError('Invalid value for IPv4Address: %r' % value)
@ -538,7 +537,7 @@ class IPv6Field(Field):
def to_python(self, value, timezone_in_use):
if isinstance(value, IPv6Address):
return value
elif isinstance(value, (binary_type,) + string_types + integer_types):
elif isinstance(value, (bytes, str, int)):
return IPv6Address(value)
else:
raise ValueError('Invalid value for IPv6Address: %r' % value)

View File

@ -1,4 +1,3 @@
import six
from datetime import date, datetime, tzinfo
import functools
@ -144,21 +143,21 @@ class F(Cond, FunctionOperatorsMixin):
return arg.to_sql()
if isinstance(arg, Field):
return "`%s`" % arg.name
if isinstance(arg, six.string_types):
if isinstance(arg, str):
return StringField().to_db_string(arg)
if isinstance(arg, datetime):
return "toDateTime(%s)" % DateTimeField().to_db_string(arg)
if isinstance(arg, date):
return "toDate('%s')" % arg.isoformat()
if isinstance(arg, bool):
return six.text_type(int(arg))
return str(int(arg))
if isinstance(arg, tzinfo):
return StringField().to_db_string(arg.tzname(None))
if arg is None:
return 'NULL'
if is_iterable(arg):
return '[' + comma_join(F.arg_to_sql(x) for x in arg) + ']'
return six.text_type(arg)
return str(arg)
# Arithmetic functions

View File

@ -1,13 +1,8 @@
import six
from .models import Model, BufferModel
from .fields import DateField, StringField
from .engines import MergeTree
from .utils import escape
from six.moves import zip
from six import iteritems
import logging
logger = logging.getLogger('migrations')
@ -74,7 +69,7 @@ class AlterTable(Operation):
# Identify fields that were added to the model
prev_name = None
for name, field in iteritems(self.model_class.fields()):
for name, field in self.model_class.fields().items():
is_regular_field = not (field.materialized or field.alias)
if name not in table_fields:
logger.info(' Add column %s', name)
@ -94,7 +89,7 @@ class AlterTable(Operation):
# Secondly, MATERIALIZED and ALIAS fields are always at the end of the DESC, so we can't expect them to save
# attribute position. Watch https://github.com/Infinidat/infi.clickhouse_orm/issues/47
model_fields = {name: field.get_sql(with_default_expression=False, db=database)
for name, field in iteritems(self.model_class.fields())}
for name, field in self.model_class.fields().items()}
for field_name, field_sql in self._get_table_fields(database):
# All fields must have been created and dropped by this moment
assert field_name in model_fields, 'Model fields and table columns in disagreement'
@ -156,7 +151,7 @@ class RunSQL(Operation):
'''
def __init__(self, sql):
if isinstance(sql, six.string_types):
if isinstance(sql, str):
sql = [sql]
assert isinstance(sql, list), "'sql' parameter must be string or list of strings"

View File

@ -3,7 +3,7 @@ import sys
from collections import OrderedDict
from logging import getLogger
from six import with_metaclass, reraise, iteritems
from six import reraise
import pytz
from .fields import Field, StringField
@ -31,8 +31,8 @@ class ModelBase(type):
fields = base_fields
# Build a list of fields, in the order they were listed in the class
fields.update({n: f for n, f in iteritems(attrs) if isinstance(f, Field)})
fields = sorted(iteritems(fields), key=lambda item: item[1].creation_counter)
fields.update({n: f for n, f in attrs.items() if isinstance(f, Field)})
fields = sorted(fields.items(), key=lambda item: item[1].creation_counter)
# Build a dictionary of default values
defaults = {n: f.to_python(f.default, pytz.UTC) for n, f in fields}
@ -102,7 +102,7 @@ class ModelBase(type):
return getattr(orm_fields, name)()
class Model(with_metaclass(ModelBase)):
class Model(metaclass=ModelBase):
'''
A base class for ORM models. Each model class represent a ClickHouse table. For example:
@ -134,7 +134,7 @@ class Model(with_metaclass(ModelBase)):
# Assign default values
self.__dict__.update(self._defaults)
# Assign field values from keyword arguments
for name, value in iteritems(kwargs):
for name, value in kwargs.items():
field = self.get_field(name)
if field:
setattr(self, name, value)
@ -154,7 +154,7 @@ class Model(with_metaclass(ModelBase)):
except ValueError:
tp, v, tb = sys.exc_info()
new_msg = "{} (field '{}')".format(v, name)
reraise(tp, tp(new_msg), tb)
raise tp.with_traceback(tp(new_msg), tb)
super(Model, self).__setattr__(name, value)
def set_database(self, db):
@ -196,7 +196,7 @@ class Model(with_metaclass(ModelBase)):
'''
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = []
for name, field in iteritems(cls.fields()):
for name, field in cls.fields().items():
cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols))
parts.append(')')
@ -221,7 +221,6 @@ class Model(with_metaclass(ModelBase)):
- `timezone_in_use`: the timezone to use when parsing dates and datetimes.
- `database`: if given, sets the database that this instance belongs to.
'''
from six import next
values = iter(parse_tsv(line))
kwargs = {}
for name in field_names:
@ -242,7 +241,7 @@ class Model(with_metaclass(ModelBase)):
'''
data = self.__dict__
fields = self.fields(writable=not include_readonly)
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in iteritems(fields))
return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in fields.items())
def to_dict(self, include_readonly=True, field_names=None):
'''
@ -321,7 +320,7 @@ class MergeModel(Model):
assert isinstance(cls.engine, Merge), "engine must be an instance of engines.Merge"
parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db.db_name, cls.table_name())]
cols = []
for name, field in iteritems(cls.fields()):
for name, field in cls.fields().items():
if name != '_table':
cols.append(' %s %s' % (name, field.get_sql(db=db)))
parts.append(',\n'.join(cols))

View File

@ -1,6 +1,5 @@
from __future__ import unicode_literals
import six
import pytz
from copy import copy, deepcopy
from math import ceil
@ -62,7 +61,7 @@ class InOperator(Operator):
field = getattr(model_cls, field_name)
if isinstance(value, QuerySet):
value = value.as_sql()
elif isinstance(value, six.string_types):
elif isinstance(value, str):
pass
else:
value = comma_join([self._value_to_sql(field, v) for v in value])
@ -197,7 +196,7 @@ class Q(object):
OR_MODE = 'OR'
def __init__(self, *filter_funcs, **filter_fields):
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in six.iteritems(filter_fields)]
self._conds = list(filter_funcs) + [self._build_cond(k, v) for k, v in filter_fields.items()]
self._children = []
self._negate = False
self._mode = self.AND_MODE
@ -283,7 +282,6 @@ class Q(object):
return q
@six.python_2_unicode_compatible
class QuerySet(object):
"""
A queryset is an object that represents a database query using a specific `Model`.
@ -328,12 +326,12 @@ class QuerySet(object):
return self.as_sql()
def __getitem__(self, s):
if isinstance(s, six.integer_types):
if isinstance(s, int):
# Single index
assert s >= 0, 'negative indexes are not supported'
qs = copy(self)
qs._limits = (s, 1)
return six.next(iter(qs))
return next(iter(qs))
else:
# Slice
assert s.step in (None, 1), 'step is not supported in slices'

View File

@ -3,7 +3,6 @@ This file contains system readonly models that can be got from the database
https://clickhouse.yandex/docs/en/system_tables/
"""
from __future__ import unicode_literals
from six import string_types
from .database import Database
from .fields import *
@ -124,7 +123,7 @@ class SystemPart(Model):
:return: A list of SystemPart objects
"""
assert isinstance(database, Database), "database must be database.Database class instance"
assert isinstance(conditions, string_types), "conditions must be a string"
assert isinstance(conditions, str), "conditions must be a string"
if conditions:
conditions += " AND"
field_names = ','.join(cls.fields())

View File

@ -1,5 +1,4 @@
from __future__ import unicode_literals
from six import string_types, binary_type, text_type, PY3
import codecs
import re
@ -28,11 +27,11 @@ def escape(value, quote=True):
def escape_one(match):
return SPECIAL_CHARS[match.group(0)]
if isinstance(value, string_types):
if isinstance(value, str):
value = SPECIAL_CHARS_REGEX.sub(escape_one, value)
if quote:
value = "'" + value + "'"
return text_type(value)
return str(value)
def unescape(value):
@ -44,7 +43,7 @@ def string_or_func(obj):
def parse_tsv(line):
if PY3 and isinstance(line, binary_type):
if isinstance(line, bytes):
line = line.decode()
if line and line[-1] == '\n':
line = line[:-1]

View File

@ -86,7 +86,7 @@ class ModelTestCase(unittest.TestCase):
self.assertEqual(
"Invalid value for StringField: {} (field 'str_field')".format(repr(bad_value)),
text_type(cm.exception)
str(cm.exception)
)
def test_field_name_in_error_message_for_invalid_value_in_assignment(self):
@ -97,7 +97,7 @@ class ModelTestCase(unittest.TestCase):
self.assertEqual(
"Invalid value for Float32Field - {} (field 'float_field')".format(repr(bad_value)),
text_type(cm.exception)
str(cm.exception)
)