mirror of
https://github.com/Infinidat/infi.clickhouse_orm.git
synced 2024-11-22 09:06:41 +03:00
model field conversion on assignment
This commit is contained in:
parent
25e85adc0d
commit
b08f1b3688
|
@ -40,7 +40,7 @@ class Database(object):
|
||||||
self._send(gen())
|
self._send(gen())
|
||||||
|
|
||||||
def count(self, model_class, conditions=None):
|
def count(self, model_class, conditions=None):
|
||||||
query = 'SELECT uniq(height) FROM %s.%s' % (self.db_name, model_class.table_name())
|
query = 'SELECT count() FROM %s.%s' % (self.db_name, model_class.table_name())
|
||||||
if conditions:
|
if conditions:
|
||||||
query += ' WHERE ' + conditions
|
query += ' WHERE ' + conditions
|
||||||
r = self._send(query)
|
r = self._send(query)
|
||||||
|
|
|
@ -38,7 +38,7 @@ class StringField(Field):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value.decode('UTF-8')
|
return value.decode('UTF-8')
|
||||||
raise ValueError('Invalid value for %s: %r', self.__class__.__name__, value)
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def get_db_prep_value(self, value):
|
||||||
if isinstance(value, unicode):
|
if isinstance(value, unicode):
|
||||||
|
@ -58,7 +58,7 @@ class DateField(Field):
|
||||||
return DateField.class_default + datetime.timedelta(days=value)
|
return DateField.class_default + datetime.timedelta(days=value)
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
|
||||||
raise ValueError('Invalid value for %s: %r', self.__class__.__name__, value)
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def get_db_prep_value(self, value):
|
||||||
return value.isoformat()
|
return value.isoformat()
|
||||||
|
@ -78,7 +78,7 @@ class DateTimeField(Field):
|
||||||
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
return datetime.datetime.fromtimestamp(value, pytz.utc)
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
return datetime.datetime.strptime(value, '%Y-%m-%d %H-%M-%S')
|
return datetime.datetime.strptime(value, '%Y-%m-%d %H-%M-%S')
|
||||||
raise ValueError('Invalid value for %s: %r', self.__class__.__name__, value)
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
def get_db_prep_value(self, value):
|
def get_db_prep_value(self, value):
|
||||||
return int(time.mktime(value.timetuple()))
|
return int(time.mktime(value.timetuple()))
|
||||||
|
@ -91,7 +91,7 @@ class BaseIntField(Field):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
return int(value)
|
return int(value)
|
||||||
raise ValueError('Invalid value for %s: %r', self.__class__.__name__, value)
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
|
|
||||||
class UInt8Field(BaseIntField):
|
class UInt8Field(BaseIntField):
|
||||||
|
@ -139,9 +139,9 @@ class BaseFloatField(Field):
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
if isinstance(value, float):
|
if isinstance(value, float):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring) or isinstance(value, int):
|
||||||
return float(value)
|
return float(value)
|
||||||
raise ValueError('Invalid value for %s: %r', self.__class__.__name__, value)
|
raise ValueError('Invalid value for %s: %r' % (self.__class__.__name__, value))
|
||||||
|
|
||||||
|
|
||||||
class Float32Field(BaseFloatField):
|
class Float32Field(BaseFloatField):
|
||||||
|
|
|
@ -4,10 +4,12 @@ from engines import *
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(type):
|
class ModelBase(type):
|
||||||
|
'''
|
||||||
|
A metaclass for ORM models. It adds the _fields list to model classes.
|
||||||
|
'''
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
|
new_cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
|
||||||
#print name, bases, attrs
|
|
||||||
# Build a list of fields, in the order they were listed in the class
|
# Build a list of fields, in the order they were listed in the class
|
||||||
fields = [item for item in attrs.items() if isinstance(item[1], Field)]
|
fields = [item for item in attrs.items() if isinstance(item[1], Field)]
|
||||||
fields.sort(key=lambda item: item[1].creation_counter)
|
fields.sort(key=lambda item: item[1].creation_counter)
|
||||||
|
@ -16,16 +18,34 @@ class ModelBase(type):
|
||||||
|
|
||||||
|
|
||||||
class Model(object):
|
class Model(object):
|
||||||
|
'''
|
||||||
|
A base class for ORM models.
|
||||||
|
'''
|
||||||
|
|
||||||
__metaclass__ = ModelBase
|
__metaclass__ = ModelBase
|
||||||
engine = None
|
engine = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
'''
|
||||||
|
Creates a model instance, using keyword arguments as field values.
|
||||||
|
Since values are immediately converted to their Pythonic type,
|
||||||
|
invalid values will cause a ValueError to be raised.
|
||||||
|
'''
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
for name, field in self._fields:
|
for name, field in self._fields:
|
||||||
val = kwargs.get(name, field.default)
|
val = kwargs.get(name, field.default)
|
||||||
setattr(self, name, val)
|
setattr(self, name, val)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
field = self.get_field(name)
|
||||||
|
if field:
|
||||||
|
value = field.to_python(value)
|
||||||
|
super(Model, self).__setattr__(name, value)
|
||||||
|
|
||||||
|
def get_field(self, name):
|
||||||
|
field = getattr(self.__class__, name, None)
|
||||||
|
return field if isinstance(field, Field) else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def table_name(cls):
|
def table_name(cls):
|
||||||
return cls.__name__.lower()
|
return cls.__name__.lower()
|
||||||
|
@ -54,8 +74,9 @@ class Model(object):
|
||||||
values = iter(parse_tsv(line))
|
values = iter(parse_tsv(line))
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for name, field in cls._fields:
|
for name, field in cls._fields:
|
||||||
kwargs[name] = field.to_python(values.next())
|
kwargs[name] = values.next()
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
# TODO verify that the number of values matches the number of fields
|
||||||
|
|
||||||
def to_tsv(self):
|
def to_tsv(self):
|
||||||
'''
|
'''
|
||||||
|
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
|
@ -6,7 +6,7 @@ from infi.clickhouse_orm.fields import *
|
||||||
from infi.clickhouse_orm.engines import *
|
from infi.clickhouse_orm.engines import *
|
||||||
|
|
||||||
|
|
||||||
class ORMTestCase(unittest.TestCase):
|
class DatabaseTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.database = Database('test_db')
|
self.database = Database('test_db')
|
54
tests/test_models.py
Normal file
54
tests/test_models.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
import unittest
|
||||||
|
import datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
from infi.clickhouse_orm.models import Model
|
||||||
|
from infi.clickhouse_orm.fields import *
|
||||||
|
from infi.clickhouse_orm.engines import *
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
# Check that all fields have their defaults
|
||||||
|
instance = SimpleModel()
|
||||||
|
self.assertEquals(instance.date_field, datetime.date(1970, 1, 1))
|
||||||
|
self.assertEquals(instance.datetime_field, datetime.datetime(1970, 1, 1, tzinfo=pytz.utc))
|
||||||
|
self.assertEquals(instance.str_field, 'dozo')
|
||||||
|
self.assertEquals(instance.int_field, 17)
|
||||||
|
self.assertEquals(instance.float_field, 0)
|
||||||
|
|
||||||
|
def test_assignment(self):
|
||||||
|
# Check that all fields are assigned during construction
|
||||||
|
kwargs = dict(
|
||||||
|
date_field=datetime.date(1973, 12, 6),
|
||||||
|
datetime_field=datetime.datetime(2000, 5, 24, 10, 22, tzinfo=pytz.utc),
|
||||||
|
str_field='aloha',
|
||||||
|
int_field=-50,
|
||||||
|
float_field=3.14
|
||||||
|
)
|
||||||
|
instance = SimpleModel(**kwargs)
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
self.assertEquals(kwargs[name], getattr(instance, name))
|
||||||
|
|
||||||
|
def test_string_conversion(self):
|
||||||
|
# Check field conversion from string during construction
|
||||||
|
instance = SimpleModel(date_field='1973-12-06', int_field='100', float_field='7')
|
||||||
|
self.assertEquals(instance.date_field, datetime.date(1973, 12, 6))
|
||||||
|
self.assertEquals(instance.int_field, 100)
|
||||||
|
self.assertEquals(instance.float_field, 7)
|
||||||
|
# Check field conversion from string during assignment
|
||||||
|
instance.int_field = '99'
|
||||||
|
self.assertEquals(instance.int_field, 99)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(Model):
|
||||||
|
|
||||||
|
date_field = DateField()
|
||||||
|
datetime_field = DateTimeField()
|
||||||
|
str_field = StringField(default='dozo')
|
||||||
|
int_field = Int32Field(default=17)
|
||||||
|
float_field = Float32Field()
|
||||||
|
|
||||||
|
engine = MergeTree('date_field', ('int_field', 'date_field'))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user