Always keep datetime fields in UTC internally, and convert server timezone to UTC when parsing query results.

This commit is contained in:
Itai Shirav 2017-02-07 15:25:16 +02:00
parent a73a69ef52
commit f29d737f29
4 changed files with 91 additions and 21 deletions

View File

@ -4,9 +4,12 @@ from .models import ModelBase
from .utils import escape, parse_tsv, import_submodules from .utils import escape, parse_tsv, import_submodules
from math import ceil from math import ceil
import datetime import datetime
import logging
from string import Template from string import Template
from six import PY3, string_types from six import PY3, string_types
import pytz
import logging
logger = logging.getLogger('clickhouse_orm')
Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size') Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
@ -26,6 +29,7 @@ class Database(object):
self.readonly = readonly self.readonly = readonly
if not self.readonly: if not self.readonly:
self.create_database() self.create_database()
self.server_timezone = self._get_server_timezone()
def create_database(self): def create_database(self):
self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name) self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
@ -82,7 +86,7 @@ class Database(object):
field_types = parse_tsv(next(lines)) field_types = parse_tsv(next(lines))
model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types)) model_class = model_class or ModelBase.create_ad_hoc_model(zip(field_names, field_types))
for line in lines: for line in lines:
yield model_class.from_tsv(line, field_names) yield model_class.from_tsv(line, field_names, self.server_timezone)
def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None): def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
count = self.count(model_class, conditions) count = self.count(model_class, conditions)
@ -154,3 +158,11 @@ class Database(object):
mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name()) mapping['table'] = "`%s`.`%s`" % (self.db_name, model_class.table_name())
query = Template(query).substitute(mapping) query = Template(query).substitute(mapping)
return query return query
def _get_server_timezone(self):
    '''
    Queries the server for its timezone name and returns the matching pytz timezone object.
    If the timezone cannot be determined, the error is logged and UTC is returned instead.
    '''
    try:
        r = self._send('SELECT timezone()')
        return pytz.timezone(r.text.strip())
    except (DatabaseException, pytz.UnknownTimeZoneError):
        # pytz.timezone() raises UnknownTimeZoneError for a zone name it does not
        # recognise; treat that the same way as a failed query and fall back to UTC.
        logger.exception('Cannot determine server timezone, assuming UTC')
        return pytz.utc

View File

@ -2,6 +2,7 @@ from six import string_types, text_type, binary_type
import datetime import datetime
import pytz import pytz
import time import time
from calendar import timegm
from .utils import escape, parse_array from .utils import escape, parse_array
@ -24,10 +25,11 @@ class Field(object):
self.alias = alias self.alias = alias
self.materialized = materialized self.materialized = materialized
def to_python(self, value): def to_python(self, value, timezone_in_use):
''' '''
Converts the input value into the expected Python data type, raising ValueError if the Converts the input value into the expected Python data type, raising ValueError if the
data can't be converted. Returns the converted value. Subclasses should override this. data can't be converted. Returns the converted value. Subclasses should override this.
The timezone_in_use parameter should be consulted when parsing datetime fields.
''' '''
return value return value
@ -77,7 +79,7 @@ class StringField(Field):
class_default = '' class_default = ''
db_type = 'String' db_type = 'String'
def to_python(self, value): def to_python(self, value, timezone_in_use):
if isinstance(value, text_type): if isinstance(value, text_type):
return value return value
if isinstance(value, binary_type): if isinstance(value, binary_type):
@ -92,11 +94,11 @@ class DateField(Field):
class_default = min_value class_default = min_value
db_type = 'Date' db_type = 'Date'
def to_python(self, value): def to_python(self, value, timezone_in_use):
if isinstance(value, datetime.date):
return value
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return value.date() return value.date()
if isinstance(value, datetime.date):
return value
if isinstance(value, int): if isinstance(value, int):
return DateField.class_default + datetime.timedelta(days=value) return DateField.class_default + datetime.timedelta(days=value)
if isinstance(value, string_types): if isinstance(value, string_types):
@ -117,26 +119,27 @@ class DateTimeField(Field):
class_default = datetime.datetime.fromtimestamp(0, pytz.utc) class_default = datetime.datetime.fromtimestamp(0, pytz.utc)
db_type = 'DateTime' db_type = 'DateTime'
def to_python(self, value): def to_python(self, value, timezone_in_use):
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return value return value.astimezone(pytz.utc) if value.tzinfo else value.replace(tzinfo=pytz.utc)
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
return datetime.datetime(value.year, value.month, value.day) return datetime.datetime(value.year, value.month, value.day, tzinfo=pytz.utc)
if isinstance(value, int): if isinstance(value, int):
return datetime.datetime.fromtimestamp(value, pytz.utc) return datetime.datetime.utcfromtimestamp(value).replace(tzinfo=pytz.utc)
if isinstance(value, string_types): if isinstance(value, string_types):
if value == '0000-00-00 00:00:00': if value == '0000-00-00 00:00:00':
return self.class_default return self.class_default
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') dt = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
return timezone_in_use.localize(dt).astimezone(pytz.utc)
raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value)) raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
def to_db_string(self, value, quote=True): def to_db_string(self, value, quote=True):
return escape(int(time.mktime(value.timetuple())), quote) return escape(timegm(value.utctimetuple()), quote)
class BaseIntField(Field): class BaseIntField(Field):
def to_python(self, value): def to_python(self, value, timezone_in_use):
try: try:
return int(value) return int(value)
except: except:
@ -204,7 +207,7 @@ class Int64Field(BaseIntField):
class BaseFloatField(Field): class BaseFloatField(Field):
def to_python(self, value): def to_python(self, value, timezone_in_use):
try: try:
return float(value) return float(value)
except: except:
@ -229,7 +232,7 @@ class BaseEnumField(Field):
default = list(enum_cls)[0] default = list(enum_cls)[0]
super(BaseEnumField, self).__init__(default, alias, materialized) super(BaseEnumField, self).__init__(default, alias, materialized)
def to_python(self, value): def to_python(self, value, timezone_in_use):
if isinstance(value, self.enum_cls): if isinstance(value, self.enum_cls):
return value return value
try: try:
@ -291,14 +294,14 @@ class ArrayField(Field):
self.inner_field = inner_field self.inner_field = inner_field
super(ArrayField, self).__init__(default, alias, materialized) super(ArrayField, self).__init__(default, alias, materialized)
def to_python(self, value): def to_python(self, value, timezone_in_use):
if isinstance(value, text_type): if isinstance(value, text_type):
value = parse_array(value) value = parse_array(value)
elif isinstance(value, binary_type): elif isinstance(value, binary_type):
value = parse_array(value.decode('UTF-8')) value = parse_array(value.decode('UTF-8'))
elif not isinstance(value, (list, tuple)): elif not isinstance(value, (list, tuple)):
raise ValueError('ArrayField expects list or tuple, not %s' % type(value)) raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
return [self.inner_field.to_python(v) for v in value] return [self.inner_field.to_python(v, timezone_in_use) for v in value]
def validate(self, value): def validate(self, value):
for v in value: for v in value:

View File

@ -3,6 +3,7 @@ from .engines import *
from .fields import Field from .fields import Field
from six import with_metaclass from six import with_metaclass
import pytz
from logging import getLogger from logging import getLogger
logger = getLogger('clickhouse_orm') logger = getLogger('clickhouse_orm')
@ -96,7 +97,7 @@ class Model(with_metaclass(ModelBase)):
''' '''
field = self.get_field(name) field = self.get_field(name)
if field: if field:
value = field.to_python(value) value = field.to_python(value, pytz.utc)
field.validate(value) field.validate(value)
super(Model, self).__setattr__(name, value) super(Model, self).__setattr__(name, value)
@ -136,7 +137,7 @@ class Model(with_metaclass(ModelBase)):
return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db_name, cls.table_name()) return 'DROP TABLE IF EXISTS `%s`.`%s`' % (db_name, cls.table_name())
@classmethod @classmethod
def from_tsv(cls, line, field_names=None): def from_tsv(cls, line, field_names=None, timezone_in_use=pytz.utc):
''' '''
Create a model instance from a tab-separated line. The line may or may not include a newline. Create a model instance from a tab-separated line. The line may or may not include a newline.
The field_names list must match the fields defined in the model, but does not have to include all of them. The field_names list must match the fields defined in the model, but does not have to include all of them.
@ -147,7 +148,8 @@ class Model(with_metaclass(ModelBase)):
values = iter(parse_tsv(line)) values = iter(parse_tsv(line))
kwargs = {} kwargs = {}
for name in field_names: for name in field_names:
kwargs[name] = next(values) field = getattr(cls, name)
kwargs[name] = field.to_python(next(values), timezone_in_use)
return cls(**kwargs) return cls(**kwargs)
def to_tsv(self, insertable_only=False): def to_tsv(self, insertable_only=False):

View File

@ -0,0 +1,53 @@
import unittest
from infi.clickhouse_orm.fields import *
from datetime import date, datetime
import pytz
class SimpleFieldsTest(unittest.TestCase):
    '''
    Checks value conversion (to_python) and validation for DateField,
    DateTimeField and UInt8Field without requiring a database connection.
    '''

    def test_date_field(self):
        f = DateField()
        # Valid values
        for value in (date(1970, 1, 1), datetime(1970, 1, 1), '1970-01-01', '0000-00-00', 0):
            self.assertEqual(f.to_python(value, pytz.utc), date(1970, 1, 1))
        # Invalid values
        for value in ('nope', '21/7/1999', 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check
        for value in (date(1900, 1, 1), date(2900, 1, 1)):
            with self.assertRaises(ValueError):
                f.validate(value)

    def test_datetime_field(self):
        f = DateTimeField()
        epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
        # Valid values - every one of these is expected to convert to the UTC epoch
        for value in (date(1970, 1, 1), datetime(1970, 1, 1), epoch,
                      epoch.astimezone(pytz.timezone('US/Eastern')),
                      epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
                      '1970-01-01 00:00:00', '0000-00-00 00:00:00', 0):
            dt = f.to_python(value, pytz.utc)
            self.assertEqual(dt.tzinfo, pytz.utc)
            self.assertEqual(dt, epoch)
            # Verify that conversion to and from db string does not change value
            dt2 = f.to_python(int(f.to_db_string(dt)), pytz.utc)
            self.assertEqual(dt, dt2)
        # Invalid values
        for value in ('nope', '21/7/1999', 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_uint8_field(self):
        f = UInt8Field()
        # Valid values
        for value in (17, '17', 17.0):
            self.assertEqual(f.to_python(value, pytz.utc), 17)
        # Invalid values
        for value in ('nope', date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)
        # Range check
        for value in (-1, 1000):
            with self.assertRaises(ValueError):
                f.validate(value)