From 87ee685c8bac32a3ff1d9a2ca2d71a4ef5f5c8c8 Mon Sep 17 00:00:00 2001
From: Itai Shirav
Date: Mon, 4 Jul 2016 17:01:51 +0300
Subject: [PATCH] migrations support

---
Reviewer notes:
  - Database.migrate() now checks `up_to` before applying a migration, so a
    call whose next unapplied migration is already past the limit applies
    nothing; test_migrations covers this by repeating a call.
  - Tests use assertEqual instead of the deprecated assertEquals alias.
  - Blob ids on the `index` lines are carried over from the original patch.

 src/infi/clickhouse_orm/database.py     |  37 +++++++-
 src/infi/clickhouse_orm/migrations.py   | 110 ++++++++++++++++++++++++
 src/infi/clickhouse_orm/utils.py        |  12 +++
 tests/__init__.py                       |   1 +
 tests/sample_migrations/0001_initial.py |   6 ++
 tests/sample_migrations/0002.py         |   6 ++
 tests/sample_migrations/0003.py         |   6 ++
 tests/sample_migrations/0004.py         |   6 ++
 tests/sample_migrations/0005.py         |   6 ++
 tests/sample_migrations/__init__.py     |   0
 tests/test_migrations.py                |  91 ++++++++++++++++++++
 11 files changed, 280 insertions(+), 1 deletion(-)
 create mode 100644 src/infi/clickhouse_orm/migrations.py
 create mode 100644 tests/sample_migrations/0001_initial.py
 create mode 100644 tests/sample_migrations/0002.py
 create mode 100644 tests/sample_migrations/0003.py
 create mode 100644 tests/sample_migrations/0004.py
 create mode 100644 tests/sample_migrations/0005.py
 create mode 100644 tests/sample_migrations/__init__.py
 create mode 100644 tests/test_migrations.py

diff --git a/src/infi/clickhouse_orm/database.py b/src/infi/clickhouse_orm/database.py
index fd85044..1761469 100644
--- a/src/infi/clickhouse_orm/database.py
+++ b/src/infi/clickhouse_orm/database.py
@@ -1,8 +1,10 @@
 import requests
 from collections import namedtuple
 from models import ModelBase
-from utils import escape, parse_tsv
+from utils import escape, parse_tsv, import_submodules
 from math import ceil
+import datetime
+import logging
 
 
 Page = namedtuple('Page', 'objects number_of_objects pages_total number page_size')
@@ -81,6 +83,39 @@ class Database(object):
             page_size=page_size
         )
 
+    def migrate(self, migrations_package_name, up_to=9999):
+        '''
+        Applies the not-yet-applied migration modules of the given package,
+        in sorted name order, and records each one in the history table.
+        The first 4 characters of a module name are parsed as its number;
+        modules numbered above `up_to` are left unapplied.
+        '''
+        from migrations import MigrationHistory
+        logger = logging.getLogger('migrations')
+        applied_migrations = self._get_applied_migrations(migrations_package_name)
+        modules = import_submodules(migrations_package_name)
+        unapplied_migrations = set(modules.keys()) - applied_migrations
+        for name in sorted(unapplied_migrations):
+            # Check the limit before applying, so that repeating a call with
+            # the same `up_to` does not apply one extra migration.
+            if int(name[:4]) > up_to:
+                break
+            logger.info('Applying migration %s...', name)
+            for operation in modules[name].operations:
+                operation.apply(self)
+            self.insert([MigrationHistory(package_name=migrations_package_name, module_name=name, applied=datetime.date.today())])
+
+    def _get_applied_migrations(self, migrations_package_name):
+        '''
+        Returns the set of module names recorded in the history table for
+        the given package. Calls create_table for the history model first.
+        '''
+        from migrations import MigrationHistory
+        self.create_table(MigrationHistory)
+        # NOTE(review): the package name is interpolated into the query as-is
+        query = "SELECT module_name from `%s`.`%s` WHERE package_name = '%s'" % (self.db_name, MigrationHistory.table_name(), migrations_package_name)
+        return set(obj.module_name for obj in self.select(query))
+
     def _send(self, data, settings=None):
         params = self._build_params(settings)
         r = requests.post(self.db_url, params=params, data=data, stream=True)
diff --git a/src/infi/clickhouse_orm/migrations.py b/src/infi/clickhouse_orm/migrations.py
new file mode 100644
index 0000000..770a36e
--- /dev/null
+++ b/src/infi/clickhouse_orm/migrations.py
@@ -0,0 +1,110 @@
+from models import Model
+from fields import DateField, StringField
+from engines import MergeTree
+from utils import escape
+
+from itertools import izip
+
+import logging
+logger = logging.getLogger('migrations')
+
+
+class Operation(object):
+    '''
+    Base class for migration operations.
+    '''
+
+    def apply(self, database):
+        raise NotImplementedError()
+
+
+class CreateTable(Operation):
+    '''
+    A migration operation that creates a table for a given model class.
+    '''
+
+    def __init__(self, model_class):
+        self.model_class = model_class
+
+    def apply(self, database):
+        logger.info(' Create table %s', self.model_class.table_name())
+        database.create_table(self.model_class)
+
+
+class AlterTable(Operation):
+    '''
+    A migration operation that compares the table of a given model class to
+    the model's fields, and alters the table to match the model. The operation can:
+      - add new columns
+      - drop obsolete columns
+      - modify column types
+    Default values are not altered by this operation.
+    '''
+
+    def __init__(self, model_class):
+        self.model_class = model_class
+
+    def _get_table_fields(self, database):
+        query = "DESC `%s`.`%s`" % (database.db_name, self.model_class.table_name())
+        return [(row.name, row.type) for row in database.select(query)]
+
+    def _alter_table(self, database, cmd):
+        cmd = "ALTER TABLE `%s`.`%s` %s" % (database.db_name, self.model_class.table_name(), cmd)
+        logger.debug(cmd)
+        database._send(cmd)
+
+    def apply(self, database):
+        logger.info(' Alter table %s', self.model_class.table_name())
+        table_fields = dict(self._get_table_fields(database))
+        # Identify fields that were deleted from the model
+        deleted_fields = set(table_fields.keys()) - set(name for name, field in self.model_class._fields)
+        for name in deleted_fields:
+            logger.info(' Drop column %s', name)
+            self._alter_table(database, 'DROP COLUMN %s' % name)
+            del table_fields[name]
+        # Identify fields that were added to the model
+        prev_name = None
+        for name, field in self.model_class._fields:
+            if name not in table_fields:
+                logger.info(' Add column %s', name)
+                assert prev_name, 'Cannot add a column to the beginning of the table'
+                default = field.get_db_prep_value(field.default)
+                cmd = 'ADD COLUMN %s %s DEFAULT %s AFTER %s' % (name, field.db_type, escape(default), prev_name)
+                self._alter_table(database, cmd)
+            prev_name = name
+        # Identify fields whose type was changed
+        model_fields = [(name, field.db_type) for name, field in self.model_class._fields]
+        for model_field, table_field in izip(model_fields, self._get_table_fields(database)):
+            assert model_field[0] == table_field[0], 'Model fields and table columns in disagreement'
+            if model_field[1] != table_field[1]:
+                logger.info(' Change type of column %s from %s to %s', table_field[0], table_field[1], model_field[1])
+                self._alter_table(database, 'MODIFY COLUMN %s %s' % model_field)
+
+
+class DropTable(Operation):
+    '''
+    A migration operation that drops the table of a given model class.
+    '''
+
+    def __init__(self, model_class):
+        self.model_class = model_class
+
+    def apply(self, database):
+        logger.info(' Drop table %s', self.model_class.__name__)
+        database.drop_table(self.model_class)
+
+
+class MigrationHistory(Model):
+    '''
+    A model for storing which migrations were already applied to the containing database.
+    '''
+
+    package_name = StringField()
+    module_name = StringField()
+    applied = DateField()
+
+    engine = MergeTree('applied', ('package_name', 'module_name'))
+
+    @classmethod
+    def table_name(cls):
+        return 'infi_clickhouse_orm_migrations'
diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py
index 62c61e2..0ac8ca4 100644
--- a/src/infi/clickhouse_orm/utils.py
+++ b/src/infi/clickhouse_orm/utils.py
@@ -26,3 +26,15 @@ def parse_tsv(line):
     if line[-1] == '\n':
         line = line[:-1]
     return [unescape(value) for value in line.split('\t')]
+
+
+def import_submodules(package_name):
+    """
+    Import all submodules of a package and return them in a dict keyed by module name.
+    """
+    import importlib, pkgutil
+    package = importlib.import_module(package_name)
+    return {
+        name: importlib.import_module(package_name + '.' + name)
+        for _, name, _ in pkgutil.iter_modules(package.__path__)
+    }
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..5284146 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+__import__("pkg_resources").declare_namespace(__name__)
diff --git a/tests/sample_migrations/0001_initial.py b/tests/sample_migrations/0001_initial.py
new file mode 100644
index 0000000..a289d86
--- /dev/null
+++ b/tests/sample_migrations/0001_initial.py
@@ -0,0 +1,6 @@
+from infi.clickhouse_orm import migrations
+from ..test_migrations import *
+
+operations = [
+    migrations.CreateTable(Model1)
+]
\ No newline at end of file
diff --git a/tests/sample_migrations/0002.py b/tests/sample_migrations/0002.py
new file mode 100644
index 0000000..6e4e0d9
--- /dev/null
+++ b/tests/sample_migrations/0002.py
@@ -0,0 +1,6 @@
+from infi.clickhouse_orm import migrations
+from ..test_migrations import *
+
+operations = [
+    migrations.DropTable(Model1)
+]
\ No newline at end of file
diff --git a/tests/sample_migrations/0003.py b/tests/sample_migrations/0003.py
new file mode 100644
index 0000000..a289d86
--- /dev/null
+++ b/tests/sample_migrations/0003.py
@@ -0,0 +1,6 @@
+from infi.clickhouse_orm import migrations
+from ..test_migrations import *
+
+operations = [
+    migrations.CreateTable(Model1)
+]
\ No newline at end of file
diff --git a/tests/sample_migrations/0004.py b/tests/sample_migrations/0004.py
new file mode 100644
index 0000000..6d10205
--- /dev/null
+++ b/tests/sample_migrations/0004.py
@@ -0,0 +1,6 @@
+from infi.clickhouse_orm import migrations
+from ..test_migrations import *
+
+operations = [
+    migrations.AlterTable(Model2)
+]
\ No newline at end of file
diff --git a/tests/sample_migrations/0005.py b/tests/sample_migrations/0005.py
new file mode 100644
index 0000000..f2633ef
--- /dev/null
+++ b/tests/sample_migrations/0005.py
@@ -0,0 +1,6 @@
+from infi.clickhouse_orm import migrations
+from ..test_migrations import *
+
+operations = [
+    migrations.AlterTable(Model3)
+]
\ No newline at end of file
diff --git a/tests/sample_migrations/__init__.py b/tests/sample_migrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_migrations.py b/tests/test_migrations.py
new file mode 100644
index 0000000..627e3f6
--- /dev/null
+++ b/tests/test_migrations.py
@@ -0,0 +1,91 @@
+import unittest
+
+from infi.clickhouse_orm.database import Database
+from infi.clickhouse_orm.models import Model
+from infi.clickhouse_orm.fields import *
+from infi.clickhouse_orm.engines import *
+from infi.clickhouse_orm.migrations import MigrationHistory
+
+# Add tests to path so that migrations will be importable
+import sys, os
+sys.path.append(os.path.dirname(__file__))
+
+import logging
+logging.basicConfig(level=logging.DEBUG, format='%(message)s')
+logging.getLogger("requests").setLevel(logging.WARNING)
+
+
+class MigrationsTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.database = Database('test-db')
+        self.database.drop_table(MigrationHistory)
+
+    def tableExists(self, model_class):
+        query = "EXISTS TABLE `%s`.`%s`" % (self.database.db_name, model_class.table_name())
+        return next(self.database.select(query)).result == 1
+
+    def getTableFields(self, model_class):
+        query = "DESC `%s`.`%s`" % (self.database.db_name, model_class.table_name())
+        return [(row.name, row.type) for row in self.database.select(query)]
+
+    def test_migrations(self):
+        self.database.migrate('tests.sample_migrations', 1)
+        self.assertTrue(self.tableExists(Model1))
+        # Repeating a call with the same limit must not apply further migrations
+        self.database.migrate('tests.sample_migrations', 1)
+        self.assertTrue(self.tableExists(Model1))
+        self.database.migrate('tests.sample_migrations', 2)
+        self.assertFalse(self.tableExists(Model1))
+        self.database.migrate('tests.sample_migrations', 3)
+        self.assertTrue(self.tableExists(Model1))
+        self.assertEqual(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
+        self.database.migrate('tests.sample_migrations', 4)
+        self.assertEqual(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
+        self.database.migrate('tests.sample_migrations', 5)
+        self.assertEqual(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
+
+
+# Several different models with the same table name, to simulate a table that changes over time
+
+class Model1(Model):
+
+    date = DateField()
+    f1 = Int32Field()
+    f2 = StringField()
+
+    engine = MergeTree('date', ('date',))
+
+    @classmethod
+    def table_name(cls):
+        return 'mig'
+
+
+class Model2(Model):
+
+    date = DateField()
+    f1 = Int32Field()
+    f3 = Float32Field()
+    f2 = StringField()
+    f4 = StringField()
+
+    engine = MergeTree('date', ('date',))
+
+    @classmethod
+    def table_name(cls):
+        return 'mig'
+
+
+class Model3(Model):
+
+    date = DateField()
+    f1 = Int64Field() # changed from Int32
+    f3 = Float64Field() # changed from Float32
+    f4 = StringField()
+
+    engine = MergeTree('date', ('date',))
+
+    @classmethod
+    def table_name(cls):
+        return 'mig'
+