infi.clickhouse_orm/src/infi/clickhouse_orm/fields.py

296 lines
8.2 KiB
Python
Raw Normal View History

2016-08-01 10:28:10 +03:00
from six import string_types, text_type, binary_type
2016-06-23 14:11:20 +03:00
import datetime
import pytz
import time
2016-09-01 15:25:48 +03:00
from .utils import escape, parse_array
2016-06-23 14:11:20 +03:00
class Field(object):
    '''
    Base class for model fields. Holds the field's default value and database
    type, and defines the conversion/validation hooks that subclasses override.
    '''

    # Class-wide counter; each new field instance records the current value
    # and then increments it, so instances are numbered in creation order.
    creation_counter = 0
    # Default used when no explicit default is passed to the constructor.
    class_default = 0
    # Name of the database column type; set by concrete subclasses.
    db_type = None

    def __init__(self, default=None):
        '''
        Creates the field. When default is None the class-level default
        (class_default) is used instead.
        '''
        self.creation_counter = Field.creation_counter
        Field.creation_counter += 1
        if default is None:
            self.default = self.class_default
        else:
            self.default = default

    def to_python(self, value):
        '''
        Converts the given value to the Python type this field works with and
        returns it. A ValueError signals that conversion is impossible.
        The base implementation returns the value untouched.
        '''
        return value

    def validate(self, value):
        '''
        Hook invoked after to_python to check that the value fits the field's
        database type. The base implementation accepts everything.
        '''
        pass

    def _range_check(self, value, min_value, max_value):
        '''
        Helper that raises ValueError unless min_value <= value <= max_value.
        '''
        out_of_range = value < min_value or value > max_value
        if out_of_range:
            details = (self.__class__.__name__, value, min_value, max_value)
            raise ValueError('%s out of range - %s is not between %s and %s' % details)

    def to_db_string(self, value, quote=True):
        '''
        Returns the value in the form used when writing to the database.
        If quote is true, strings are wrapped in single quotes.
        '''
        return escape(value, quote)

    def get_sql(self, with_default=True):
        '''
        Returns the SQL fragment describing this field (e.g. for CREATE TABLE),
        optionally followed by a DEFAULT clause.
        '''
        if not with_default:
            return self.db_type
        default_sql = self.to_db_string(self.default)
        return '%s DEFAULT %s' % (self.db_type, default_sql)
2016-06-23 14:11:20 +03:00
class StringField(Field):
    '''
    Field with database type String. Python values are text strings;
    byte strings are decoded as UTF-8.
    '''

    db_type = 'String'
    class_default = ''

    def to_python(self, value):
        '''
        Returns the value as text. Byte strings are decoded as UTF-8, text is
        returned as is, and any other type raises ValueError.
        '''
        if isinstance(value, binary_type):
            return value.decode('UTF-8')
        if isinstance(value, text_type):
            return value
        field_name = self.__class__.__name__
        raise ValueError('Invalid value for %s: %r' % (field_name, value))
2016-06-23 14:11:20 +03:00
class DateField(Field):
    '''
    Field with database type Date. Python values are datetime.date objects
    between min_value and max_value (inclusive).
    '''

    min_value = datetime.date(1970, 1, 1)
    max_value = datetime.date(2038, 1, 19)
    class_default = min_value
    db_type = 'Date'

    def to_python(self, value):
        '''
        Converts the value to a datetime.date. Accepts:
        - datetime.datetime: the time part is dropped
        - datetime.date: returned as is
        - int: number of days since 1970-01-01
        - string in YYYY-MM-DD format ('0000-00-00' maps to min_value)
        Raises ValueError for other types or for unparsable strings.
        '''
        # datetime.datetime is a subclass of datetime.date, so it must be
        # handled first. Otherwise a datetime would pass through unchanged,
        # and comparing it with the date bounds in validate() fails.
        if isinstance(value, datetime.datetime):
            return value.date()
        if isinstance(value, datetime.date):
            return value
        if isinstance(value, int):
            return DateField.class_default + datetime.timedelta(days=value)
        if isinstance(value, string_types):
            if value == '0000-00-00':
                return DateField.min_value
            return datetime.datetime.strptime(value, '%Y-%m-%d').date()
        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))

    def validate(self, value):
        '''
        Raises ValueError if the date is outside [min_value, max_value].
        '''
        self._range_check(value, DateField.min_value, DateField.max_value)

    def to_db_string(self, value, quote=True):
        '''
        Returns the date in ISO format (YYYY-MM-DD), quoted when quote is true.
        '''
        return escape(value.isoformat(), quote)
2016-06-23 14:11:20 +03:00
class DateTimeField(Field):
    '''
    Field with database type DateTime. Python values are datetime.datetime
    objects; they are written to the database as integer Unix timestamps.
    '''

    # The Unix epoch as a UTC-aware datetime.
    class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
    db_type = 'DateTime'

    def to_python(self, value):
        '''
        Converts the value to a datetime.datetime. Accepts:
        - datetime.datetime: returned as is
        - datetime.date: midnight of that day (naive)
        - int: Unix timestamp, converted to a UTC-aware datetime
        - string in 'YYYY-MM-DD HH:MM:SS' format (naive);
          '0000-00-00 00:00:00' maps to class_default
        Raises ValueError for other types or for unparsable strings.
        Note that the result may be naive or timezone-aware depending on
        the input.
        '''
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
        if isinstance(value, int):
            return datetime.datetime.fromtimestamp(value, pytz.utc)
        if isinstance(value, string_types):
            if value == '0000-00-00 00:00:00':
                return self.class_default
            return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))

    def to_db_string(self, value, quote=True):
        '''
        Returns the datetime as an integer Unix timestamp.
        Timezone-aware values are converted using their own UTC offset.
        Naive values are interpreted in the machine's local timezone.
        '''
        import calendar
        if value.tzinfo is not None and value.utcoffset() is not None:
            # time.mktime() would treat the aware value's wall-clock fields as
            # local time, giving a wrong timestamp whenever the value's zone
            # differs from the machine's (e.g. the UTC class_default on a
            # non-UTC host). Convert through UTC instead.
            timestamp = calendar.timegm(value.utctimetuple())
        else:
            timestamp = time.mktime(value.timetuple())
        return escape(int(timestamp), quote)
2016-06-23 14:11:20 +03:00
class BaseIntField(Field):
    '''
    Base class for integer fields. Subclasses are expected to define
    min_value and max_value, which validate() uses as inclusive bounds.
    '''

    def to_python(self, value):
        '''
        Converts the value using int(). Raises ValueError if conversion fails.
        '''
        try:
            return int(value)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit propagate instead of being reported as bad input.
            raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))

    def validate(self, value):
        '''
        Raises ValueError if the value is outside [min_value, max_value].
        '''
        self._range_check(value, self.min_value, self.max_value)
2016-06-23 14:11:20 +03:00
class UInt8Field(BaseIntField):
    '''Unsigned 8-bit integer field (0 to 255).'''

    db_type = 'UInt8'
    min_value = 0
    max_value = (1 << 8) - 1
class UInt16Field(BaseIntField):
    '''Unsigned 16-bit integer field (0 to 65535).'''

    db_type = 'UInt16'
    min_value = 0
    max_value = (1 << 16) - 1
class UInt32Field(BaseIntField):
    '''Unsigned 32-bit integer field (0 to 2**32 - 1).'''

    db_type = 'UInt32'
    min_value = 0
    max_value = (1 << 32) - 1
class UInt64Field(BaseIntField):
    '''Unsigned 64-bit integer field (0 to 2**64 - 1).'''

    db_type = 'UInt64'
    min_value = 0
    max_value = (1 << 64) - 1
class Int8Field(BaseIntField):
    '''Signed 8-bit integer field (-128 to 127).'''

    db_type = 'Int8'
    min_value = -(1 << 7)
    max_value = (1 << 7) - 1
class Int16Field(BaseIntField):
    '''Signed 16-bit integer field (-32768 to 32767).'''

    db_type = 'Int16'
    min_value = -(1 << 15)
    max_value = (1 << 15) - 1
class Int32Field(BaseIntField):
    '''Signed 32-bit integer field (-2**31 to 2**31 - 1).'''

    db_type = 'Int32'
    min_value = -(1 << 31)
    max_value = (1 << 31) - 1
class Int64Field(BaseIntField):
    '''Signed 64-bit integer field (-2**63 to 2**63 - 1).'''

    db_type = 'Int64'
    min_value = -(1 << 63)
    max_value = (1 << 63) - 1
class BaseFloatField(Field):
    '''
    Base class for floating point fields.
    '''

    def to_python(self, value):
        '''
        Converts the value using float(). Raises ValueError if conversion fails.
        '''
        try:
            return float(value)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit propagate instead of being reported as bad input.
            raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
2016-06-23 14:11:20 +03:00
class Float32Field(BaseFloatField):
    '''Floating point field with database type Float32.'''

    db_type = 'Float32'
class Float64Field(BaseFloatField):
    '''Floating point field with database type Float64.'''

    db_type = 'Float64'
2016-08-31 15:26:28 +03:00
class BaseEnumField(Field):
    '''
    Base class for enum fields. Python values are members of the enum class
    given to the constructor; they are written to the database by name.
    '''

    def __init__(self, enum_cls, default=None):
        '''
        enum_cls is the Enum class whose members are the allowed values.
        When default is None, the first member of enum_cls is the default.
        '''
        self.enum_cls = enum_cls
        if default is None:
            default = list(enum_cls)[0]
        super(BaseEnumField, self).__init__(default)

    def to_python(self, value):
        '''
        Converts the value to a member of enum_cls. Accepts an existing member,
        a member name (text, or bytes decoded as UTF-8), or a member's integer
        value. Raises ValueError if no member matches.
        '''
        if isinstance(value, self.enum_cls):
            return value
        try:
            if isinstance(value, text_type):
                return self.enum_cls[value]
            if isinstance(value, binary_type):
                return self.enum_cls[value.decode('UTF-8')]
            if isinstance(value, int):
                return self.enum_cls(value)
        except (KeyError, ValueError):
            # Unknown name or value - fall through to the error below.
            pass
        raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))

    def to_db_string(self, value, quote=True):
        '''
        Returns the member's name, quoted when quote is true.
        '''
        return escape(value.name, quote)

    def get_sql(self, with_default=True):
        '''
        Returns the SQL fragment describing this field, listing every member
        as 'name' = value, optionally followed by a DEFAULT clause.
        '''
        values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
        # NOTE(review): the separator is ' ,' (space before the comma), which
        # looks like a typo for ', '. Left unchanged so the generated SQL
        # stays byte-identical.
        sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
        if with_default:
            default = self.to_db_string(self.default)
            sql = '%s DEFAULT %s' % (sql, default)
        return sql

    @classmethod
    def create_ad_hoc_field(cls, db_type):
        '''
        Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
        this method returns a matching enum field.
        '''
        import re
        # Standard library in Python 3.4+; older versions need the enum34 backport.
        from enum import Enum
        # Collect (name, value) pairs in the order they appear, so the enum's
        # member order - and therefore the field's default, which is the first
        # member - matches the column description on every Python version.
        # The pattern is a raw string, allows spaces inside names and accepts
        # negative values, so such members are not silently dropped.
        members = [
            (match.group(1), int(match.group(2)))
            for match in re.finditer(r"'([\w ]+)' = (-?\d+)", db_type)
        ]
        enum_cls = Enum('AdHocEnum', members)
        field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
        return field_class(enum_cls)
class Enum8Field(BaseEnumField):
    '''Enum field with database type Enum8.'''

    db_type = 'Enum8'
class Enum16Field(BaseEnumField):
    '''Enum field with database type Enum16.'''

    db_type = 'Enum16'
2016-09-01 15:25:48 +03:00
class ArrayField(Field):
    '''
    Field holding a list of values, each handled by the given inner field.
    '''

    class_default = []

    def __init__(self, inner_field, default=None):
        '''
        inner_field is the field instance used to convert, validate and
        serialize each element. When default is None, the default is a new
        empty list.
        '''
        self.inner_field = inner_field
        if default is None:
            # Copy the class-level list so that each field gets its own
            # default; otherwise mutating one field's default would change
            # the default of every other ArrayField.
            default = list(self.class_default)
        super(ArrayField, self).__init__(default)

    def to_python(self, value):
        '''
        Converts the value to a list whose elements were converted by the
        inner field. Strings (text, or bytes decoded as UTF-8) are first parsed
        with parse_array; lists and tuples are used directly. Raises ValueError
        for any other type.
        '''
        if isinstance(value, text_type):
            value = parse_array(value)
        elif isinstance(value, binary_type):
            value = parse_array(value.decode('UTF-8'))
        elif not isinstance(value, (list, tuple)):
            raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
        return [self.inner_field.to_python(v) for v in value]

    def validate(self, value):
        '''
        Validates every element using the inner field.
        '''
        for v in value:
            self.inner_field.validate(v)

    def to_db_string(self, value, quote=True):
        '''
        Returns the list as '[item, item, ...]'. Elements are always
        serialized by the inner field with quote=True, whatever the quote
        argument is.
        '''
        array = [self.inner_field.to_db_string(v, quote=True) for v in value]
        return '[' + ', '.join(array) + ']'

    def get_sql(self, with_default=True):
        '''
        Returns 'Array(<inner type>)'. No DEFAULT clause is emitted,
        whatever the with_default argument is.
        '''
        return 'Array(%s)' % self.inner_field.get_sql(with_default=False)