mirror of
				https://github.com/Infinidat/infi.clickhouse_orm.git
				synced 2025-11-04 01:37:34 +03:00 
			
		
		
		
	
						commit
						e37a4cebb1
					
				
							
								
								
									
										58
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								README.rst
									
									
									
									
									
								
							| 
						 | 
					@ -34,6 +34,20 @@ It is possible to provide a default value for a field, instead of its "natural"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
See below for the supported field types and table engines.
 | 
					See below for the supported field types and table engines.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Table Names
 | 
				
			||||||
 | 
					***********
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The table name used for the model is its class name, converted to lowercase. To override the default name,
 | 
				
			||||||
 | 
					implement the ``table_name`` method::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class Person(models.Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @classmethod
 | 
				
			||||||
 | 
					        def table_name(cls):
 | 
				
			||||||
 | 
					            return 'people'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Using Models
 | 
					Using Models
 | 
				
			||||||
------------
 | 
					------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -151,7 +165,7 @@ The ``paginate`` method returns a ``namedtuple`` containing the following fields
 | 
				
			||||||
- ``objects`` - the list of objects in this page
 | 
					- ``objects`` - the list of objects in this page
 | 
				
			||||||
- ``number_of_objects`` - total number of objects in all pages
 | 
					- ``number_of_objects`` - total number of objects in all pages
 | 
				
			||||||
- ``pages_total`` - total number of pages
 | 
					- ``pages_total`` - total number of pages
 | 
				
			||||||
- ``number`` - the page number
 | 
					- ``number`` - the page number, starting from 1; the special value -1 may be used to retrieve the last page
 | 
				
			||||||
- ``page_size`` - the number of objects per page
 | 
					- ``page_size`` - the number of objects per page
 | 
				
			||||||
 | 
					
 | 
				
			||||||
You can optionally pass conditions to the query::
 | 
					You can optionally pass conditions to the query::
 | 
				
			||||||
| 
						 | 
					@ -191,8 +205,50 @@ UInt32Field    UInt32      int                Range 0 to 4294967295
 | 
				
			||||||
UInt64Field    UInt64      int/long           Range 0 to 18446744073709551615
 | 
					UInt64Field    UInt64      int/long           Range 0 to 18446744073709551615
 | 
				
			||||||
Float32Field   Float32     float
 | 
					Float32Field   Float32     float
 | 
				
			||||||
Float64Field   Float64     float
 | 
					Float64Field   Float64     float
 | 
				
			||||||
 | 
					Enum8Field     Enum8       Enum               See below
 | 
				
			||||||
 | 
					Enum16Field    Enum16      Enum               See below
 | 
				
			||||||
 | 
					ArrayField     Array       list               See below
 | 
				
			||||||
=============  ========    =================  ===================================================
 | 
					=============  ========    =================  ===================================================
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Working with enum fields
 | 
				
			||||||
 | 
					************************
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					``Enum8Field`` and ``Enum16Field`` provide support for working with ClickHouse enum columns. They accept
 | 
				
			||||||
 | 
					strings or integers as values, and convert them to the matching Pythonic Enum member.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Python 3.4 and higher supports Enums natively. When using previous Python versions you 
 | 
				
			||||||
 | 
					need to install the `enum34` library.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Example of a model with an enum field::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Gender = Enum('Gender', 'male female unspecified')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class Person(models.Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        first_name = fields.StringField()
 | 
				
			||||||
 | 
					        last_name = fields.StringField()
 | 
				
			||||||
 | 
					        birthday = fields.DateField()
 | 
				
			||||||
 | 
					        gender = fields.Enum32Field(Gender)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        engine = engines.MergeTree('birthday', ('first_name', 'last_name', 'birthday'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    suzy = Person(first_name='Suzy', last_name='Jones', gender=Gender.female)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Working with array fields
 | 
				
			||||||
 | 
					*************************
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					You can create array fields containing any data type, for example::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class SensorData(models.Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        date = fields.DateField()
 | 
				
			||||||
 | 
					        temperatures = fields.ArrayField(fields.Float32Field())
 | 
				
			||||||
 | 
					        humidity_levels = fields.ArrayField(fields.UInt8Field())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        engine = engines.MergeTree('date', ('date',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    data = SensorData(date=date.today(), temperatures=[25.5, 31.2, 28.7], humidity_levels=[41, 39, 66])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Table Engines
 | 
					Table Engines
 | 
				
			||||||
-------------
 | 
					-------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,6 +45,8 @@ recipe = infi.recipe.console_scripts
 | 
				
			||||||
eggs = ${project:name}
 | 
					eggs = ${project:name}
 | 
				
			||||||
	ipython
 | 
						ipython
 | 
				
			||||||
	nose
 | 
						nose
 | 
				
			||||||
 | 
						coverage
 | 
				
			||||||
 | 
						enum34
 | 
				
			||||||
	infi.unittest
 | 
						infi.unittest
 | 
				
			||||||
	infi.traceback
 | 
						infi.traceback
 | 
				
			||||||
	zc.buildout
 | 
						zc.buildout
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,12 +18,20 @@ class DatabaseException(Exception):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Database(object):
 | 
					class Database(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, db_name, db_url='http://localhost:8123/', username=None, password=None):
 | 
					    def __init__(self, db_name, db_url='http://localhost:8123/', username=None, password=None, readonly=False):
 | 
				
			||||||
        self.db_name = db_name
 | 
					        self.db_name = db_name
 | 
				
			||||||
        self.db_url = db_url
 | 
					        self.db_url = db_url
 | 
				
			||||||
        self.username = username
 | 
					        self.username = username
 | 
				
			||||||
        self.password = password
 | 
					        self.password = password
 | 
				
			||||||
        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % db_name)
 | 
					        self.readonly = readonly
 | 
				
			||||||
 | 
					        if not self.readonly:
 | 
				
			||||||
 | 
					            self.create_database()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def create_database(self):
 | 
				
			||||||
 | 
					        self._send('CREATE DATABASE IF NOT EXISTS `%s`' % self.db_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def drop_database(self):
 | 
				
			||||||
 | 
					        self._send('DROP DATABASE `%s`' % self.db_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_table(self, model_class):
 | 
					    def create_table(self, model_class):
 | 
				
			||||||
        # TODO check that model has an engine
 | 
					        # TODO check that model has an engine
 | 
				
			||||||
| 
						 | 
					@ -32,10 +40,7 @@ class Database(object):
 | 
				
			||||||
    def drop_table(self, model_class):
 | 
					    def drop_table(self, model_class):
 | 
				
			||||||
        self._send(model_class.drop_table_sql(self.db_name))
 | 
					        self._send(model_class.drop_table_sql(self.db_name))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def drop_database(self):
 | 
					    def insert(self, model_instances, batch_size=1000):
 | 
				
			||||||
        self._send('DROP DATABASE `%s`' % self.db_name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def insert(self, model_instances):
 | 
					 | 
				
			||||||
        from six import next
 | 
					        from six import next
 | 
				
			||||||
        i = iter(model_instances)
 | 
					        i = iter(model_instances)
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -45,11 +50,19 @@ class Database(object):
 | 
				
			||||||
        model_class = first_instance.__class__
 | 
					        model_class = first_instance.__class__
 | 
				
			||||||
        def gen():
 | 
					        def gen():
 | 
				
			||||||
            yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
 | 
					            yield self._substitute('INSERT INTO $table FORMAT TabSeparated\n', model_class).encode('utf-8')
 | 
				
			||||||
            yield first_instance.to_tsv().encode('utf-8')
 | 
					            yield (first_instance.to_tsv() + '\n').encode('utf-8')
 | 
				
			||||||
            yield '\n'.encode('utf-8')
 | 
					            # Collect lines in batches of batch_size
 | 
				
			||||||
 | 
					            batch = []
 | 
				
			||||||
            for instance in i:
 | 
					            for instance in i:
 | 
				
			||||||
                yield instance.to_tsv().encode('utf-8')
 | 
					                batch.append(instance.to_tsv())
 | 
				
			||||||
                yield '\n'.encode('utf-8')
 | 
					                if len(batch) >= batch_size:
 | 
				
			||||||
 | 
					                    # Return the current batch of lines
 | 
				
			||||||
 | 
					                    yield ('\n'.join(batch) + '\n').encode('utf-8')
 | 
				
			||||||
 | 
					                    # Start a new batch
 | 
				
			||||||
 | 
					                    batch = []
 | 
				
			||||||
 | 
					            # Return any remaining lines in partial batch
 | 
				
			||||||
 | 
					            if batch:
 | 
				
			||||||
 | 
					                yield ('\n'.join(batch) + '\n').encode('utf-8')
 | 
				
			||||||
        self._send(gen())
 | 
					        self._send(gen())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def count(self, model_class, conditions=None):
 | 
					    def count(self, model_class, conditions=None):
 | 
				
			||||||
| 
						 | 
					@ -74,6 +87,10 @@ class Database(object):
 | 
				
			||||||
    def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
 | 
					    def paginate(self, model_class, order_by, page_num=1, page_size=100, conditions=None, settings=None):
 | 
				
			||||||
        count = self.count(model_class, conditions)
 | 
					        count = self.count(model_class, conditions)
 | 
				
			||||||
        pages_total = int(ceil(count / float(page_size)))
 | 
					        pages_total = int(ceil(count / float(page_size)))
 | 
				
			||||||
 | 
					        if page_num == -1:
 | 
				
			||||||
 | 
					            page_num = pages_total
 | 
				
			||||||
 | 
					        elif page_num < 1:
 | 
				
			||||||
 | 
					            raise ValueError('Invalid page number: %d' % page_num)
 | 
				
			||||||
        offset = (page_num - 1) * page_size
 | 
					        offset = (page_num - 1) * page_size
 | 
				
			||||||
        query = 'SELECT * FROM $table'
 | 
					        query = 'SELECT * FROM $table'
 | 
				
			||||||
        if conditions:
 | 
					        if conditions:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,8 @@ import datetime
 | 
				
			||||||
import pytz
 | 
					import pytz
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .utils import escape, parse_array
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Field(object):
 | 
					class Field(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +15,7 @@ class Field(object):
 | 
				
			||||||
    def __init__(self, default=None):
 | 
					    def __init__(self, default=None):
 | 
				
			||||||
        self.creation_counter = Field.creation_counter
 | 
					        self.creation_counter = Field.creation_counter
 | 
				
			||||||
        Field.creation_counter += 1
 | 
					        Field.creation_counter += 1
 | 
				
			||||||
        self.default = default or self.class_default
 | 
					        self.default = self.class_default if default is None else default
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_python(self, value):
 | 
					    def to_python(self, value):
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
| 
						 | 
					@ -36,11 +38,22 @@ class Field(object):
 | 
				
			||||||
        if value < min_value or value > max_value:
 | 
					        if value < min_value or value > max_value:
 | 
				
			||||||
            raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
 | 
					            raise ValueError('%s out of range - %s is not between %s and %s' % (self.__class__.__name__, value, min_value, max_value))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_db_prep_value(self, value):
 | 
					    def to_db_string(self, value, quote=True):
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        Returns the field's value prepared for interacting with the database.
 | 
					        Returns the field's value prepared for writing to the database.
 | 
				
			||||||
 | 
					        When quote is true, strings are surrounded by single quotes.
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        return value
 | 
					        return escape(value, quote)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_sql(self, with_default=True):
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        Returns an SQL expression describing the field (e.g. for CREATE TABLE).
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        if with_default:
 | 
				
			||||||
 | 
					            default = self.to_db_string(self.default)
 | 
				
			||||||
 | 
					            return '%s DEFAULT %s' % (self.db_type, default)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.db_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StringField(Field):
 | 
					class StringField(Field):
 | 
				
			||||||
| 
						 | 
					@ -66,6 +79,8 @@ class DateField(Field):
 | 
				
			||||||
    def to_python(self, value):
 | 
					    def to_python(self, value):
 | 
				
			||||||
        if isinstance(value, datetime.date):
 | 
					        if isinstance(value, datetime.date):
 | 
				
			||||||
            return value
 | 
					            return value
 | 
				
			||||||
 | 
					        if isinstance(value, datetime.datetime):
 | 
				
			||||||
 | 
					            return value.date()
 | 
				
			||||||
        if isinstance(value, int):
 | 
					        if isinstance(value, int):
 | 
				
			||||||
            return DateField.class_default + datetime.timedelta(days=value)
 | 
					            return DateField.class_default + datetime.timedelta(days=value)
 | 
				
			||||||
        if isinstance(value, string_types):
 | 
					        if isinstance(value, string_types):
 | 
				
			||||||
| 
						 | 
					@ -77,8 +92,8 @@ class DateField(Field):
 | 
				
			||||||
    def validate(self, value):
 | 
					    def validate(self, value):
 | 
				
			||||||
        self._range_check(value, DateField.min_value, DateField.max_value)
 | 
					        self._range_check(value, DateField.min_value, DateField.max_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_db_prep_value(self, value):
 | 
					    def to_db_string(self, value, quote=True):
 | 
				
			||||||
        return value.isoformat()
 | 
					        return escape(value.isoformat(), quote)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DateTimeField(Field):
 | 
					class DateTimeField(Field):
 | 
				
			||||||
| 
						 | 
					@ -94,11 +109,13 @@ class DateTimeField(Field):
 | 
				
			||||||
        if isinstance(value, int):
 | 
					        if isinstance(value, int):
 | 
				
			||||||
            return datetime.datetime.fromtimestamp(value, pytz.utc)
 | 
					            return datetime.datetime.fromtimestamp(value, pytz.utc)
 | 
				
			||||||
        if isinstance(value, string_types):
 | 
					        if isinstance(value, string_types):
 | 
				
			||||||
 | 
					            if value == '0000-00-00 00:00:00':
 | 
				
			||||||
 | 
					                return self.class_default
 | 
				
			||||||
            return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
 | 
					            return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
 | 
				
			||||||
        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
 | 
					        raise ValueError('Invalid value for %s - %r' % (self.__class__.__name__, value))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_db_prep_value(self, value):
 | 
					    def to_db_string(self, value, quote=True):
 | 
				
			||||||
        return int(time.mktime(value.timetuple()))
 | 
					        return escape(int(time.mktime(value.timetuple())), quote)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseIntField(Field):
 | 
					class BaseIntField(Field):
 | 
				
			||||||
| 
						 | 
					@ -187,3 +204,94 @@ class Float64Field(BaseFloatField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    db_type = 'Float64'
 | 
					    db_type = 'Float64'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BaseEnumField(Field):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, enum_cls, default=None):
 | 
				
			||||||
 | 
					        self.enum_cls = enum_cls
 | 
				
			||||||
 | 
					        if default is None:
 | 
				
			||||||
 | 
					            default = list(enum_cls)[0]
 | 
				
			||||||
 | 
					        super(BaseEnumField, self).__init__(default)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_python(self, value):
 | 
				
			||||||
 | 
					        if isinstance(value, self.enum_cls):
 | 
				
			||||||
 | 
					            return value
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            if isinstance(value, text_type):
 | 
				
			||||||
 | 
					                return self.enum_cls[value]
 | 
				
			||||||
 | 
					            if isinstance(value, binary_type):
 | 
				
			||||||
 | 
					                return self.enum_cls[value.decode('UTF-8')]
 | 
				
			||||||
 | 
					            if isinstance(value, int):
 | 
				
			||||||
 | 
					                return self.enum_cls(value)
 | 
				
			||||||
 | 
					        except (KeyError, ValueError):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        raise ValueError('Invalid value for %s: %r' % (self.enum_cls.__name__, value))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_db_string(self, value, quote=True):
 | 
				
			||||||
 | 
					        return escape(value.name, quote)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_sql(self, with_default=True):
 | 
				
			||||||
 | 
					        values = ['%s = %d' % (escape(item.name), item.value) for item in self.enum_cls]
 | 
				
			||||||
 | 
					        sql = '%s(%s)' % (self.db_type, ' ,'.join(values))
 | 
				
			||||||
 | 
					        if with_default:
 | 
				
			||||||
 | 
					            default = self.to_db_string(self.default)
 | 
				
			||||||
 | 
					            sql = '%s DEFAULT %s' % (sql, default)
 | 
				
			||||||
 | 
					        return sql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def create_ad_hoc_field(cls, db_type):
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        Give an SQL column description such as "Enum8('apple' = 1, 'banana' = 2, 'orange' = 3)"
 | 
				
			||||||
 | 
					        this method returns a matching enum field.
 | 
				
			||||||
 | 
					        '''
 | 
				
			||||||
 | 
					        import re
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            Enum # exists in Python 3.4+
 | 
				
			||||||
 | 
					        except NameError:
 | 
				
			||||||
 | 
					            from enum import Enum # use the enum34 library instead
 | 
				
			||||||
 | 
					        members = {}
 | 
				
			||||||
 | 
					        for match in re.finditer("'(\w+)' = (\d+)", db_type):
 | 
				
			||||||
 | 
					            members[match.group(1)] = int(match.group(2))
 | 
				
			||||||
 | 
					        enum_cls = Enum('AdHocEnum', members)
 | 
				
			||||||
 | 
					        field_class = Enum8Field if db_type.startswith('Enum8') else Enum16Field
 | 
				
			||||||
 | 
					        return field_class(enum_cls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Enum8Field(BaseEnumField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    db_type = 'Enum8'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Enum16Field(BaseEnumField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    db_type = 'Enum16'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ArrayField(Field):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class_default = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, inner_field, default=None):
 | 
				
			||||||
 | 
					        self.inner_field = inner_field
 | 
				
			||||||
 | 
					        super(ArrayField, self).__init__(default)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_python(self, value):
 | 
				
			||||||
 | 
					        if isinstance(value, text_type):
 | 
				
			||||||
 | 
					            value = parse_array(value)
 | 
				
			||||||
 | 
					        elif isinstance(value, binary_type):
 | 
				
			||||||
 | 
					            value = parse_array(value.decode('UTF-8'))
 | 
				
			||||||
 | 
					        elif not isinstance(value, (list, tuple)):
 | 
				
			||||||
 | 
					            raise ValueError('ArrayField expects list or tuple, not %s' % type(value))
 | 
				
			||||||
 | 
					        return [self.inner_field.to_python(v) for v in value]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate(self, value):
 | 
				
			||||||
 | 
					        for v in value:
 | 
				
			||||||
 | 
					            self.inner_field.validate(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_db_string(self, value, quote=True):
 | 
				
			||||||
 | 
					        array = [self.inner_field.to_db_string(v, quote=True) for v in value]
 | 
				
			||||||
 | 
					        return '[' + ', '.join(array) + ']'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_sql(self, with_default=True):
 | 
				
			||||||
 | 
					        from .utils import escape
 | 
				
			||||||
 | 
					        return 'Array(%s)' % self.inner_field.get_sql(with_default=False)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -68,12 +68,11 @@ class AlterTable(Operation):
 | 
				
			||||||
            if name not in table_fields:
 | 
					            if name not in table_fields:
 | 
				
			||||||
                logger.info('        Add column %s', name)
 | 
					                logger.info('        Add column %s', name)
 | 
				
			||||||
                assert prev_name, 'Cannot add a column to the beginning of the table'
 | 
					                assert prev_name, 'Cannot add a column to the beginning of the table'
 | 
				
			||||||
                default = field.get_db_prep_value(field.default)
 | 
					                cmd = 'ADD COLUMN %s %s AFTER %s' % (name, field.get_sql(), prev_name)
 | 
				
			||||||
                cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name)
 | 
					 | 
				
			||||||
                self._alter_table(database, cmd)
 | 
					                self._alter_table(database, cmd)
 | 
				
			||||||
            prev_name = name
 | 
					            prev_name = name
 | 
				
			||||||
        # Identify fields whose type was changed
 | 
					        # Identify fields whose type was changed
 | 
				
			||||||
        model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
 | 
					        model_fields = [(name, field.get_sql(with_default=False)) for name, field in self.model_class._fields]
 | 
				
			||||||
        for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
 | 
					        for model_field, table_field in zip(model_fields, self._get_table_fields(database)):
 | 
				
			||||||
            assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
 | 
					            assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
 | 
				
			||||||
            if model_field[1] != table_field[1]:
 | 
					            if model_field[1] != table_field[1]:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,6 +4,9 @@ from .fields import Field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from six import with_metaclass
 | 
					from six import with_metaclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from logging import getLogger
 | 
				
			||||||
 | 
					logger = getLogger('clickhouse_orm')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelBase(type):
 | 
					class ModelBase(type):
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
| 
						 | 
					@ -28,7 +31,6 @@ class ModelBase(type):
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def create_ad_hoc_model(cls, fields):
 | 
					    def create_ad_hoc_model(cls, fields):
 | 
				
			||||||
        # fields is a list of tuples (name, db_type)
 | 
					        # fields is a list of tuples (name, db_type)
 | 
				
			||||||
        import infi.clickhouse_orm.fields as orm_fields
 | 
					 | 
				
			||||||
        # Check if model exists in cache
 | 
					        # Check if model exists in cache
 | 
				
			||||||
        fields = list(fields)
 | 
					        fields = list(fields)
 | 
				
			||||||
        cache_key = str(fields)
 | 
					        cache_key = str(fields)
 | 
				
			||||||
| 
						 | 
					@ -37,15 +39,28 @@ class ModelBase(type):
 | 
				
			||||||
        # Create an ad hoc model class
 | 
					        # Create an ad hoc model class
 | 
				
			||||||
        attrs = {}
 | 
					        attrs = {}
 | 
				
			||||||
        for name, db_type in fields:
 | 
					        for name, db_type in fields:
 | 
				
			||||||
            field_class = db_type + 'Field'
 | 
					            attrs[name] = cls.create_ad_hoc_field(db_type)
 | 
				
			||||||
            if not hasattr(orm_fields, field_class):
 | 
					 | 
				
			||||||
                raise NotImplementedError('No field class for %s' % db_type)
 | 
					 | 
				
			||||||
            attrs[name] = getattr(orm_fields, field_class)()
 | 
					 | 
				
			||||||
        model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
 | 
					        model_class = cls.__new__(cls, 'AdHocModel', (Model,), attrs)
 | 
				
			||||||
        # Add the model class to the cache
 | 
					        # Add the model class to the cache
 | 
				
			||||||
        cls.ad_hoc_model_cache[cache_key] = model_class
 | 
					        cls.ad_hoc_model_cache[cache_key] = model_class
 | 
				
			||||||
        return model_class
 | 
					        return model_class
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def create_ad_hoc_field(cls, db_type):
 | 
				
			||||||
 | 
					        import infi.clickhouse_orm.fields as orm_fields
 | 
				
			||||||
 | 
					        # Enums
 | 
				
			||||||
 | 
					        if db_type.startswith('Enum'):
 | 
				
			||||||
 | 
					            return orm_fields.BaseEnumField.create_ad_hoc_field(db_type)
 | 
				
			||||||
 | 
					        # Arrays
 | 
				
			||||||
 | 
					        if db_type.startswith('Array'):
 | 
				
			||||||
 | 
					            inner_field = cls.create_ad_hoc_field(db_type[6 : -1])
 | 
				
			||||||
 | 
					            return orm_fields.ArrayField(inner_field)
 | 
				
			||||||
 | 
					        # Simple fields
 | 
				
			||||||
 | 
					        name = db_type + 'Field'
 | 
				
			||||||
 | 
					        if not hasattr(orm_fields, name):
 | 
				
			||||||
 | 
					            raise NotImplementedError('No field class for %s' % db_type)
 | 
				
			||||||
 | 
					        return getattr(orm_fields, name)()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Model(with_metaclass(ModelBase)):
 | 
					class Model(with_metaclass(ModelBase)):
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
| 
						 | 
					@ -107,8 +122,7 @@ class Model(with_metaclass(ModelBase)):
 | 
				
			||||||
        parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
 | 
					        parts = ['CREATE TABLE IF NOT EXISTS `%s`.`%s` (' % (db_name, cls.table_name())]
 | 
				
			||||||
        cols = []
 | 
					        cols = []
 | 
				
			||||||
        for name, field in cls._fields:
 | 
					        for name, field in cls._fields:
 | 
				
			||||||
            default = field.get_db_prep_value(field.default)
 | 
					            cols.append('    %s %s' % (name, field.get_sql()))
 | 
				
			||||||
            cols.append('    %s %s DEFAULT %s' % (name, field.db_type, escape(default)))
 | 
					 | 
				
			||||||
        parts.append(',\n'.join(cols))
 | 
					        parts.append(',\n'.join(cols))
 | 
				
			||||||
        parts.append(')')
 | 
					        parts.append(')')
 | 
				
			||||||
        parts.append('ENGINE = ' + cls.engine.create_table_sql())
 | 
					        parts.append('ENGINE = ' + cls.engine.create_table_sql())
 | 
				
			||||||
| 
						 | 
					@ -140,8 +154,5 @@ class Model(with_metaclass(ModelBase)):
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        Returns the instance's column values as a tab-separated line. A newline is not included.
 | 
					        Returns the instance's column values as a tab-separated line. A newline is not included.
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        parts = []
 | 
					        data = self.__dict__
 | 
				
			||||||
        for name, field in self._fields:
 | 
					        return '\t'.join(field.to_db_string(data[name], quote=False) for name, field in self._fields)
 | 
				
			||||||
            value = field.get_db_prep_value(field.to_python(getattr(self, name)))
 | 
					 | 
				
			||||||
            parts.append(escape(value, quote=False))
 | 
					 | 
				
			||||||
        return '\t'.join(parts)
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,6 @@
 | 
				
			||||||
from six import string_types, binary_type, text_type, PY3
 | 
					from six import string_types, binary_type, text_type, PY3
 | 
				
			||||||
import codecs
 | 
					import codecs
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SPECIAL_CHARS = {
 | 
					SPECIAL_CHARS = {
 | 
				
			||||||
| 
						 | 
					@ -13,11 +14,20 @@ SPECIAL_CHARS = {
 | 
				
			||||||
    "'"  : "\\'"
 | 
					    "'"  : "\\'"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def escape(value, quote=True):
 | 
					def escape(value, quote=True):
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    If the value is a string, escapes any special characters and optionally
 | 
				
			||||||
 | 
					    surrounds it with single quotes. If the value is not a string (e.g. a number), 
 | 
				
			||||||
 | 
					    converts it to one.
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
    if isinstance(value, string_types):
 | 
					    if isinstance(value, string_types):
 | 
				
			||||||
        chars = (SPECIAL_CHARS.get(c, c) for c in value)
 | 
					        if SPECIAL_CHARS_REGEX.search(value):
 | 
				
			||||||
        value = "'" + "".join(chars) + "'" if quote else "".join(chars)
 | 
					            value = "".join(SPECIAL_CHARS.get(c, c) for c in value)
 | 
				
			||||||
 | 
					        if quote:
 | 
				
			||||||
 | 
					            value = "'" + value + "'"
 | 
				
			||||||
    return text_type(value)
 | 
					    return text_type(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,6 +43,40 @@ def parse_tsv(line):
 | 
				
			||||||
    return [unescape(value) for value in line.split('\t')]
 | 
					    return [unescape(value) for value in line.split('\t')]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_array(array_string):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Parse an array string as returned by clickhouse. For example:
 | 
				
			||||||
 | 
					        "['hello', 'world']" ==> ["hello", "world"]
 | 
				
			||||||
 | 
					        "[1,2,3]"            ==> [1, 2, 3]
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Sanity check
 | 
				
			||||||
 | 
					    if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']':
 | 
				
			||||||
 | 
					        raise ValueError('Invalid array string: "%s"' % array_string)
 | 
				
			||||||
 | 
					    # Drop opening brace
 | 
				
			||||||
 | 
					    array_string = array_string[1:] 
 | 
				
			||||||
 | 
					    # Go over the string, lopping off each value at the beginning until nothing is left
 | 
				
			||||||
 | 
					    values = []
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        if array_string == ']':
 | 
				
			||||||
 | 
					            # End of array
 | 
				
			||||||
 | 
					            return values
 | 
				
			||||||
 | 
					        elif array_string[0] in ', ':
 | 
				
			||||||
 | 
					            # In between values
 | 
				
			||||||
 | 
					            array_string = array_string[1:] 
 | 
				
			||||||
 | 
					        elif array_string[0] == "'":
 | 
				
			||||||
 | 
					            # Start of quoted value, find its end
 | 
				
			||||||
 | 
					            match = re.search(r"[^\\]'", array_string)
 | 
				
			||||||
 | 
					            if match is None:
 | 
				
			||||||
 | 
					                raise ValueError('Missing closing quote: "%s"' % array_string)
 | 
				
			||||||
 | 
					            values.append(array_string[1 : match.start() + 1])
 | 
				
			||||||
 | 
					            array_string = array_string[match.end():]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Start of non-quoted value, find its end
 | 
				
			||||||
 | 
					            match = re.search(r",|\]", array_string)
 | 
				
			||||||
 | 
					            values.append(array_string[0 : match.start()])
 | 
				
			||||||
 | 
					            array_string = array_string[match.end() - 1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def import_submodules(package_name):
 | 
					def import_submodules(package_name):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Import all submodules of a module.
 | 
					    Import all submodules of a module.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										6
									
								
								tests/sample_migrations/0006.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								tests/sample_migrations/0006.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,6 @@
 | 
				
			||||||
 | 
					from infi.clickhouse_orm import migrations
 | 
				
			||||||
 | 
					from ..test_migrations import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					operations = [
 | 
				
			||||||
 | 
					    migrations.CreateTable(EnumModel1)
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										6
									
								
								tests/sample_migrations/0007.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								tests/sample_migrations/0007.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,6 @@
 | 
				
			||||||
 | 
					from infi.clickhouse_orm import migrations
 | 
				
			||||||
 | 
					from ..test_migrations import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					operations = [
 | 
				
			||||||
 | 
					    migrations.AlterTable(EnumModel2)
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										73
									
								
								tests/test_array_fields.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								tests/test_array_fields.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,73 @@
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					from datetime import date
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.database import Database
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.models import Model
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.fields import *
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.engines import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ArrayFieldsTest(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.database = Database('test-db')
 | 
				
			||||||
 | 
					        self.database.create_table(ModelWithArrays)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        self.database.drop_database()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_insert_and_select(self):
 | 
				
			||||||
 | 
					        instance = ModelWithArrays(
 | 
				
			||||||
 | 
					            date_field='2016-08-30', 
 | 
				
			||||||
 | 
					            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'], 
 | 
				
			||||||
 | 
					            arr_date=['2010-01-01']
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.database.insert([instance])
 | 
				
			||||||
 | 
					        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
 | 
				
			||||||
 | 
					        for model_cls in (ModelWithArrays, None):
 | 
				
			||||||
 | 
					            results = list(self.database.select(query, model_cls))
 | 
				
			||||||
 | 
					            self.assertEquals(len(results), 1)
 | 
				
			||||||
 | 
					            self.assertEquals(results[0].arr_str, instance.arr_str)
 | 
				
			||||||
 | 
					            self.assertEquals(results[0].arr_int, instance.arr_int)
 | 
				
			||||||
 | 
					            self.assertEquals(results[0].arr_date, instance.arr_date)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_conversion(self):
 | 
				
			||||||
 | 
					        instance = ModelWithArrays(
 | 
				
			||||||
 | 
					            arr_int=('1', '2', '3'),
 | 
				
			||||||
 | 
					            arr_date=['2010-01-01']
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEquals(instance.arr_str, [])
 | 
				
			||||||
 | 
					        self.assertEquals(instance.arr_int, [1, 2, 3])
 | 
				
			||||||
 | 
					        self.assertEquals(instance.arr_date, [date(2010, 1, 1)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_assignment_error(self):
 | 
				
			||||||
 | 
					        instance = ModelWithArrays()
 | 
				
			||||||
 | 
					        for value in (7, 'x', [date.today()], ['aaa'], [None]):
 | 
				
			||||||
 | 
					            with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					                instance.arr_int = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_parse_array(self):
 | 
				
			||||||
 | 
					        from infi.clickhouse_orm.utils import parse_array, unescape
 | 
				
			||||||
 | 
					        self.assertEquals(parse_array("[]"), [])
 | 
				
			||||||
 | 
					        self.assertEquals(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
 | 
				
			||||||
 | 
					        self.assertEquals(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
 | 
				
			||||||
 | 
					        self.assertEquals(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
 | 
				
			||||||
 | 
					        for s in ("", 
 | 
				
			||||||
 | 
					                  "[", 
 | 
				
			||||||
 | 
					                  "]", 
 | 
				
			||||||
 | 
					                  "[1, 2", 
 | 
				
			||||||
 | 
					                  "3, 4]", 
 | 
				
			||||||
 | 
					                  "['aaa', 'aaa]"):
 | 
				
			||||||
 | 
					            with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					                parse_array(s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelWithArrays(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date_field = DateField()
 | 
				
			||||||
 | 
					    arr_str = ArrayField(StringField())
 | 
				
			||||||
 | 
					    arr_int = ArrayField(Int32Field())
 | 
				
			||||||
 | 
					    arr_date = ArrayField(DateField())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    engine = MergeTree('date_field', ('date_field',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,6 +93,23 @@ class DatabaseTestCase(unittest.TestCase):
 | 
				
			||||||
            # Verify that all instances were returned
 | 
					            # Verify that all instances were returned
 | 
				
			||||||
            self.assertEquals(len(instances), len(data))
 | 
					            self.assertEquals(len(instances), len(data))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_pagination_last_page(self):
 | 
				
			||||||
 | 
					        self._insert_and_check(self._sample_data(), len(data))
 | 
				
			||||||
 | 
					        # Try different page sizes
 | 
				
			||||||
 | 
					        for page_size in (1, 2, 7, 10, 30, 100, 150):
 | 
				
			||||||
 | 
					            # Ask for the last page in two different ways and verify equality
 | 
				
			||||||
 | 
					            page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
 | 
				
			||||||
 | 
					            page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
 | 
				
			||||||
 | 
					            self.assertEquals(page_a[1:], page_b[1:])
 | 
				
			||||||
 | 
					            self.assertEquals([obj.to_tsv() for obj in page_a.objects], 
 | 
				
			||||||
 | 
					                              [obj.to_tsv() for obj in page_b.objects])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_pagination_invalid_page(self):
 | 
				
			||||||
 | 
					        self._insert_and_check(self._sample_data(), len(data))
 | 
				
			||||||
 | 
					        for page_num in (0, -2, -100):
 | 
				
			||||||
 | 
					            with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					                self.database.paginate(Person, 'first_name, last_name', page_num, 100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_special_chars(self):
 | 
					    def test_special_chars(self):
 | 
				
			||||||
        s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
 | 
					        s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
 | 
				
			||||||
        p = Person(first_name=s)
 | 
					        p = Person(first_name=s)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										87
									
								
								tests/test_enum_fields.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								tests/test_enum_fields.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,87 @@
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.database import Database
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.models import Model
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.fields import *
 | 
				
			||||||
 | 
					from infi.clickhouse_orm.engines import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    Enum # exists in Python 3.4+
 | 
				
			||||||
 | 
					except NameError:
 | 
				
			||||||
 | 
					    from enum import Enum # use the enum34 library instead
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EnumFieldsTest(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.database = Database('test-db')
 | 
				
			||||||
 | 
					        self.database.create_table(ModelWithEnum)
 | 
				
			||||||
 | 
					        self.database.create_table(ModelWithEnumArray)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        self.database.drop_database()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_insert_and_select(self):
 | 
				
			||||||
 | 
					        self.database.insert([
 | 
				
			||||||
 | 
					            ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
 | 
				
			||||||
 | 
					            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					        query = 'SELECT * from $table ORDER BY date_field'
 | 
				
			||||||
 | 
					        results = list(self.database.select(query, ModelWithEnum))
 | 
				
			||||||
 | 
					        self.assertEquals(len(results), 2)
 | 
				
			||||||
 | 
					        self.assertEquals(results[0].enum_field, Fruit.apple)
 | 
				
			||||||
 | 
					        self.assertEquals(results[1].enum_field, Fruit.orange)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ad_hoc_model(self):
 | 
				
			||||||
 | 
					        self.database.insert([
 | 
				
			||||||
 | 
					            ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
 | 
				
			||||||
 | 
					            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					        query = 'SELECT * from $db.modelwithenum ORDER BY date_field'
 | 
				
			||||||
 | 
					        results = list(self.database.select(query))
 | 
				
			||||||
 | 
					        self.assertEquals(len(results), 2)
 | 
				
			||||||
 | 
					        self.assertEquals(results[0].enum_field.name, Fruit.apple.name)
 | 
				
			||||||
 | 
					        self.assertEquals(results[0].enum_field.value, Fruit.apple.value)
 | 
				
			||||||
 | 
					        self.assertEquals(results[1].enum_field.name, Fruit.orange.name)
 | 
				
			||||||
 | 
					        self.assertEquals(results[1].enum_field.value, Fruit.orange.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_conversion(self):
 | 
				
			||||||
 | 
					        self.assertEquals(ModelWithEnum(enum_field=3).enum_field, Fruit.orange)
 | 
				
			||||||
 | 
					        self.assertEquals(ModelWithEnum(enum_field='apple').enum_field, Fruit.apple)
 | 
				
			||||||
 | 
					        self.assertEquals(ModelWithEnum(enum_field=Fruit.banana).enum_field, Fruit.banana)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_assignment_error(self):
 | 
				
			||||||
 | 
					        for value in (0, 17, 'pear', '', None, 99.9):
 | 
				
			||||||
 | 
					            with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					                ModelWithEnum(enum_field=value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_default_value(self):
 | 
				
			||||||
 | 
					        instance = ModelWithEnum()
 | 
				
			||||||
 | 
					        self.assertEquals(instance.enum_field, Fruit.apple)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_enum_array(self):
 | 
				
			||||||
 | 
					        instance = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
 | 
				
			||||||
 | 
					        self.database.insert([instance])
 | 
				
			||||||
 | 
					        query = 'SELECT * from $table ORDER BY date_field'
 | 
				
			||||||
 | 
					        results = list(self.database.select(query, ModelWithEnumArray))
 | 
				
			||||||
 | 
					        self.assertEquals(len(results), 1)
 | 
				
			||||||
 | 
					        self.assertEquals(results[0].enum_array, instance.enum_array)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Fruit = Enum('Fruit', u'apple banana orange')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelWithEnum(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date_field = DateField()
 | 
				
			||||||
 | 
					    enum_field = Enum8Field(Fruit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    engine = MergeTree('date_field', ('date_field',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelWithEnumArray(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date_field = DateField()
 | 
				
			||||||
 | 
					    enum_array = ArrayField(Enum16Field(Fruit))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    engine = MergeTree('date_field', ('date_field',))
 | 
				
			||||||
| 
						 | 
					@ -10,6 +10,11 @@ from infi.clickhouse_orm.migrations import MigrationHistory
 | 
				
			||||||
import sys, os
 | 
					import sys, os
 | 
				
			||||||
sys.path.append(os.path.dirname(__file__))
 | 
					sys.path.append(os.path.dirname(__file__))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    Enum # exists in Python 3.4+
 | 
				
			||||||
 | 
					except NameError:
 | 
				
			||||||
 | 
					    from enum import Enum # use the enum34 library instead
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
 | 
					logging.basicConfig(level=logging.DEBUG, format='%(message)s')
 | 
				
			||||||
logging.getLogger("requests").setLevel(logging.WARNING)
 | 
					logging.getLogger("requests").setLevel(logging.WARNING)
 | 
				
			||||||
| 
						 | 
					@ -21,6 +26,9 @@ class MigrationsTestCase(unittest.TestCase):
 | 
				
			||||||
        self.database = Database('test-db')
 | 
					        self.database = Database('test-db')
 | 
				
			||||||
        self.database.drop_table(MigrationHistory)
 | 
					        self.database.drop_table(MigrationHistory)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        self.database.drop_database()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def tableExists(self, model_class):
 | 
					    def tableExists(self, model_class):
 | 
				
			||||||
        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
 | 
					        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
 | 
				
			||||||
        return next(self.database.select(query)).result == 1
 | 
					        return next(self.database.select(query)).result == 1
 | 
				
			||||||
| 
						 | 
					@ -30,18 +38,28 @@ class MigrationsTestCase(unittest.TestCase):
 | 
				
			||||||
        return [(row.name, row.type) for row in self.database.select(query)]
 | 
					        return [(row.name, row.type) for row in self.database.select(query)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_migrations(self):
 | 
					    def test_migrations(self):
 | 
				
			||||||
 | 
					        # Creation and deletion of table
 | 
				
			||||||
        self.database.migrate('tests.sample_migrations', 1)
 | 
					        self.database.migrate('tests.sample_migrations', 1)
 | 
				
			||||||
        self.assertTrue(self.tableExists(Model1))
 | 
					        self.assertTrue(self.tableExists(Model1))
 | 
				
			||||||
        self.database.migrate('tests.sample_migrations', 2)
 | 
					        self.database.migrate('tests.sample_migrations', 2)
 | 
				
			||||||
        self.assertFalse(self.tableExists(Model1))
 | 
					        self.assertFalse(self.tableExists(Model1))
 | 
				
			||||||
        self.database.migrate('tests.sample_migrations', 3)
 | 
					        self.database.migrate('tests.sample_migrations', 3)
 | 
				
			||||||
        self.assertTrue(self.tableExists(Model1))
 | 
					        self.assertTrue(self.tableExists(Model1))
 | 
				
			||||||
 | 
					        # Adding, removing and altering simple fields
 | 
				
			||||||
        self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
 | 
					        self.assertEquals(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
 | 
				
			||||||
        self.database.migrate('tests.sample_migrations', 4)
 | 
					        self.database.migrate('tests.sample_migrations', 4)
 | 
				
			||||||
        self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
 | 
					        self.assertEquals(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
 | 
				
			||||||
        self.database.migrate('tests.sample_migrations', 5)
 | 
					        self.database.migrate('tests.sample_migrations', 5)
 | 
				
			||||||
        self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
 | 
					        self.assertEquals(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
 | 
				
			||||||
 | 
					        # Altering enum fields
 | 
				
			||||||
 | 
					        self.database.migrate('tests.sample_migrations', 6)
 | 
				
			||||||
 | 
					        self.assertTrue(self.tableExists(EnumModel1))
 | 
				
			||||||
 | 
					        self.assertEquals(self.getTableFields(EnumModel1), 
 | 
				
			||||||
 | 
					                          [('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
 | 
				
			||||||
 | 
					        self.database.migrate('tests.sample_migrations', 7)
 | 
				
			||||||
 | 
					        self.assertTrue(self.tableExists(EnumModel1))
 | 
				
			||||||
 | 
					        self.assertEquals(self.getTableFields(EnumModel2), 
 | 
				
			||||||
 | 
					                          [('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Several different models with the same table name, to simulate a table that changes over time
 | 
					# Several different models with the same table name, to simulate a table that changes over time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -86,3 +104,26 @@ class Model3(Model):
 | 
				
			||||||
    def table_name(cls):
 | 
					    def table_name(cls):
 | 
				
			||||||
        return 'mig'
 | 
					        return 'mig'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EnumModel1(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date = DateField()
 | 
				
			||||||
 | 
					    f1 = Enum8Field(Enum('SomeEnum1', 'dog cat cow'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    engine = MergeTree('date', ('date',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def table_name(cls):
 | 
				
			||||||
 | 
					        return 'enum_mig'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EnumModel2(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date = DateField()
 | 
				
			||||||
 | 
					    f1 = Enum16Field(Enum('SomeEnum2', 'dog cat horse pig')) # changed type and values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    engine = MergeTree('date', ('date',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def table_name(cls):
 | 
				
			||||||
 | 
					        return 'enum_mig'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user