mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-26 02:33:45 +03:00
Always keep datetime fields in UTC internally, and convert server timezone to UTC when parsing query results.
This commit is contained in:
parent
a73a69ef52
commit
f29d737f29
|
@ -4,9 +4,12 @@ from .models import ModelBase
|
||||||
from .utils import escape, parse_tsv, import_submodules
|
from .utils import escape, parse_tsv, import_submodules
|
||||||
from math import ceil
|
from math import ceil
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
|
||||||
from string import Template
|
from string import Template
|
||||||
from six import PY3, string_types
|
from six import PY3, string_types
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger('clickhouse_orm')
|
||||||
|
|
||||||
|
|
||||||
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
|
||||||
|
@ -26,6 +29,7 @@ class Database(object):
|
||||||
self.readonly = readonly
|
self.readonly = readonly
|
||||||
if not self.readonly:
|
if not self.readonly:
|
||||||
self.create_database()
|
self.create_database()
|
||||||
|
self.server_timezone = self._get_server_timezone()
|
||||||
|
|
||||||
def create_database(self):
|
def create_database(self):
|
||||||
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
|
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
|
||||||
|
@ -82,7 +86,7 @@ class Database(object):
|
||||||
field_types = parse_tsv(next(lines))
|
field_types = parse_tsv(next(lines))
|
||||||
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
|
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
|
||||||
for line in lines:
|
for line in lines:
|
||||||
yield model_class.from_tsv(line, field_names)
|
yield model_class.from_tsv(line, field_names, self.server_timezone)
|
||||||
|
|
||||||
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
|
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
|
||||||
count = self.count(model_class, conditions)
|
count = self.count(model_class, conditions)
|
||||||
|
@ -154,3 +158,11 @@ class Database(object):
|
||||||
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
|
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
|
||||||
query = Template(query).substitute(mapping)
|
query = Template(query).substitute(mapping)
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
def _get_server_timezone(self):
|
||||||
|
try:
|
||||||
|
r = self._send('SELECT timezone()')
|
||||||
|
return pytz.timezone(r.text.strip())
|
||||||
|
except DatabaseException:
|
||||||
|
logger.exception('Cannot determine server timezone, assuming UTC')
|
||||||
|
return pytz.utc
|
||||||
|
|
|
@ -2,6 +2,7 @@ from six import string_types, text_type, binary_type
|
||||||
import datetime
|
import datetime
|
||||||
import pytz
|
import pytz
|
||||||
import time
|
import time
|
||||||
|
from calendar import timegm
|
||||||
|
|
||||||
from .utils import escape, parse_array
|
from .utils import escape, parse_array
|
||||||
|
|
||||||
|
@ -24,10 +25,11 @@ class Field(object):
|
||||||
self.alias = alias
|
self.alias = alias
|
||||||
self.materialized = materialized
|
self.materialized = materialized
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
'''
|
'''
|
||||||
Converts the input value into the expected Python data type, raising ValueError if the
|
Converts the input value into the expected Python data type, raising ValueError if the
|
||||||
data can't be converted. Returns the converted value. Subclasses should override this.
|
data can't be converted. Returns the converted value. Subclasses should override this.
|
||||||
|
The timezone_in_use parameter should be consulted when parsing datetime fields.
|
||||||
'''
|
'''
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -77,7 +79,7 @@ class StringField(Field):
|
||||||
class_default = ''
|
class_default = ''
|
||||||
db_type = 'String'
|
db_type = 'String'
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, text_type):
|
if isinstance(value, text_type):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, binary_type):
|
if isinstance(value, binary_type):
|
||||||
|
@ -92,11 +94,11 @@ class DateField(Field):
|
||||||
class_default = min_value
|
class_default = min_value
|
||||||
db_type = 'Date'
|
db_type = 'Date'
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, datetime.date):
|
|
||||||
return value
|
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
return value.date()
|
return value.date()
|
||||||
|
if isinstance(value, datetime.date):
|
||||||
|
return value
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return DateField.class_default + datetime.timedelta(days=value)
|
return DateField.class_default + datetime.timedelta(days=value)
|
||||||
if isinstance(value, string_types):
|
if isinstance(value, string_types):
|
||||||
|
@ -117,26 +119,27 @@ class DateTimeField(Field):
|
||||||
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
|
||||||
db_type = 'DateTime'
|
db_type = 'DateTime'
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
return value
|
return value.astimezone(pytz.utc) if value.tzinfo else value.replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, datetime.date):
|
if isinstance(value, datetime.date):
|
||||||
return datetime.datetime(value.year, value.month, value.day)
|
return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
|
||||||
if isinstance(value, string_types):
|
if isinstance(value, string_types):
|
||||||
if value == '0000-00-00 00:00:00':
|
if value == '0000-00-00 00:00:00':
|
||||||
return self.class_default
|
return self.class_default
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
dt = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
|
||||||
|
return timezone_in_use.localize(dt).astimezone(pytz.utc)
|
||||||
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def to_db_string(self, value, quote=True):
|
def to_db_string(self, value, quote=True):
|
||||||
return escape(int(time.mktime(value.timetuple())), quote)
|
return escape(timegm(value.utctimetuple()), quote)
|
||||||
|
|
||||||
|
|
||||||
class BaseIntField(Field):
|
class BaseIntField(Field):
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
except:
|
except:
|
||||||
|
@ -204,7 +207,7 @@ class Int64Field(BaseIntField):
|
||||||
|
|
||||||
class BaseFloatField(Field):
|
class BaseFloatField(Field):
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
try:
|
try:
|
||||||
return float(value)
|
return float(value)
|
||||||
except:
|
except:
|
||||||
|
@ -229,7 +232,7 @@ class BaseEnumField(Field):
|
||||||
default = list(enum_cls)[0]
|
default = list(enum_cls)[0]
|
||||||
super(BaseEnumField, self).__init__(default, alias, materialized)
|
super(BaseEnumField, self).__init__(default, alias, materialized)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, self.enum_cls):
|
if isinstance(value, self.enum_cls):
|
||||||
return value
|
return value
|
||||||
try:
|
try:
|
||||||
|
@ -291,14 +294,14 @@ class ArrayField(Field):
|
||||||
self.inner_field = inner_field
|
self.inner_field = inner_field
|
||||||
super(ArrayField, self).__init__(default, alias, materialized)
|
super(ArrayField, self).__init__(default, alias, materialized)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value, timezone_in_use):
|
||||||
if isinstance(value, text_type):
|
if isinstance(value, text_type):
|
||||||
value = parse_array(value)
|
value = parse_array(value)
|
||||||
elif isinstance(value, binary_type):
|
elif isinstance(value, binary_type):
|
||||||
value = parse_array(value.decode('UTF-8'))
|
value = parse_array(value.decode('UTF-8'))
|
||||||
elif not isinstance(value, (list, tuple)):
|
elif not isinstance(value, (list, tuple)):
|
||||||
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
|
raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
|
||||||
return [self.inner_field.to_python(v) for v in value]
|
return [self.inner_field.to_python(v, timezone_in_use) for v in value]
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
for v in value:
|
for v in value:
|
||||||
|
|
|
@ -3,6 +3,7 @@ from .engines import *
|
||||||
from .fields import Field
|
from .fields import Field
|
||||||
|
|
||||||
from six import with_metaclass
|
from six import with_metaclass
|
||||||
|
import pytz
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
logger = getLogger('clickhouse_orm')
|
logger = getLogger('clickhouse_orm')
|
||||||
|
@ -96,7 +97,7 @@ class Model(with_metaclass(ModelBase)):
|
||||||
'''
|
'''
|
||||||
field = self.get_field(name)
|
field = self.get_field(name)
|
||||||
if field:
|
if field:
|
||||||
value = field.to_python(value)
|
value = field.to_python(value, pytz.utc)
|
||||||
field.validate(value)
|
field.validate(value)
|
||||||
super(Model, self).__setattr__(name, value)
|
super(Model, self).__setattr__(name, value)
|
||||||
|
|
||||||
|
@ -136,7 +137,7 @@ class Model(with_metaclass(ModelBase)):
|
||||||
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db_name, cls.table_name())
|
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db_name, cls.table_name())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tsv(cls, line, field_names=None):
|
def from_tsv(cls, line, field_names=None, timezone_in_use=pytz.utc):
|
||||||
'''
|
'''
|
||||||
Create a model instance from a tab-separated line. The line may or may not include a newline.
|
Create a model instance from a tab-separated line. The line may or may not include a newline.
|
||||||
The field_names list must match the fields defined in the model, but does not have to include all of them.
|
The field_names list must match the fields defined in the model, but does not have to include all of them.
|
||||||
|
@ -147,7 +148,8 @@ class Model(with_metaclass(ModelBase)):
|
||||||
values = iter(parse_tsv(line))
|
values = iter(parse_tsv(line))
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for name in field_names:
|
for name in field_names:
|
||||||
kwargs[name] = next(values)
|
field = getattr(cls, name)
|
||||||
|
kwargs[name] = field.to_python(next(values), timezone_in_use)
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
def to_tsv(self, insertable_only=False):
|
def to_tsv(self, insertable_only=False):
|
||||||
|
|
53
tests/test_simple_fields.py
Normal file
53
tests/test_simple_fields.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import unittest
|
||||||
|
from infi.clickhouse_orm.fields import *
|
||||||
|
from datetime import date, datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleFieldsTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_date_field(self):
|
||||||
|
f = DateField()
|
||||||
|
# Valid values
|
||||||
|
for value in (date(1970, 1, 1), datetime(1970, 1, 1), '1970-01-01', '0000-00-00', 0):
|
||||||
|
self.assertEquals(f.to_python(value, pytz.utc), date(1970, 1, 1))
|
||||||
|
# Invalid values
|
||||||
|
for value in ('nope', '21/7/1999', 0.5):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
f.to_python(value, pytz.utc)
|
||||||
|
# Range check
|
||||||
|
for value in (date(1900, 1, 1), date(2900, 1, 1)):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
f.validate(value)
|
||||||
|
|
||||||
|
def test_datetime_field(self):
|
||||||
|
f = DateTimeField()
|
||||||
|
epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
|
||||||
|
# Valid values
|
||||||
|
for value in (date(1970, 1, 1), datetime(1970, 1, 1), epoch,
|
||||||
|
epoch.astimezone(pytz.timezone('US/Eastern')), epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
|
||||||
|
'1970-01-01 00:00:00', '0000-00-00 00:00:00', 0):
|
||||||
|
dt = f.to_python(value, pytz.utc)
|
||||||
|
self.assertEquals(dt.tzinfo, pytz.utc)
|
||||||
|
self.assertEquals(dt, epoch)
|
||||||
|
# Verify that conversion to and from db string does not change value
|
||||||
|
dt2 = f.to_python(int(f.to_db_string(dt)), pytz.utc)
|
||||||
|
self.assertEquals(dt, dt2)
|
||||||
|
# Invalid values
|
||||||
|
for value in ('nope', '21/7/1999', 0.5):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
f.to_python(value, pytz.utc)
|
||||||
|
|
||||||
|
def test_uint8_field(self):
|
||||||
|
f = UInt8Field()
|
||||||
|
# Valid values
|
||||||
|
for value in (17, '17', 17.0):
|
||||||
|
self.assertEquals(f.to_python(value, pytz.utc), 17)
|
||||||
|
# Invalid values
|
||||||
|
for value in ('nope', date.today()):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
f.to_python(value, pytz.utc)
|
||||||
|
# Range check
|
||||||
|
for value in (-1, 1000):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
f.validate(value)
|
Loading…
Reference in New Issue
Block a user