diff --git a/src/django_clickhouse/migrations.py b/src/django_clickhouse/migrations.py index 626960e..9f60289 100644 --- a/src/django_clickhouse/migrations.py +++ b/src/django_clickhouse/migrations.py @@ -2,6 +2,7 @@ Migrating database """ import datetime +from typing import Optional from django.db import DEFAULT_DB_ALIAS as DJANGO_DEFAULT_DB_ALIAS from django.db.models.signals import post_migrate @@ -10,7 +11,7 @@ from infi.clickhouse_orm.migrations import * from infi.clickhouse_orm.utils import import_submodules from .configuration import config -from .database import connections +from .database import connections, Database from .utils import lazy_class_import, module_exists @@ -20,10 +21,11 @@ class Migration: """ operations = [] - def apply(self, db_alias): # type: (str) -> None + def apply(self, db_alias, database=None): # type: (str, Optional[Database]) -> None """ Applies migration to given database :param db_alias: Database alias to apply migration to + :param database: Sometimes I want to pass db object directly for testing purposes :return: None """ db_router = lazy_class_import(config.DATABASE_ROUTER)() @@ -33,32 +35,33 @@ class Migration: hints = getattr(op, 'hints', {}) if db_router.allow_migrate(db_alias, self.__module__, model=model_class, **hints): - op.apply(connections[db_alias]) + database = database or connections[db_alias] + op.apply(database) -def migrate_app(app_label, db_alias, up_to=9999): - # type: (str, str, int) -> None +def migrate_app(app_label, db_alias, up_to=9999, database=None): + # type: (str, str, int, Optional[Database]) -> None """ Migrates given django app :param app_label: App label to migrate :param db_alias: Database alias to migrate :param up_to: Migration number to migrate to + :param database: Sometimes I want to pass db object directly for testing purposes :return: None """ - db = connections[db_alias] migrations_package = "%s.%s" % (app_label, config.MIGRATIONS_PACKAGE) if module_exists(migrations_package): - applied_migrations = db._get_applied_migrations(migrations_package) + applied_migrations = database._get_applied_migrations(migrations_package) modules = import_submodules(migrations_package) unapplied_migrations = set(modules.keys()) - applied_migrations for name in sorted(unapplied_migrations): migration = modules[name].Migration() - migration.apply(db_alias) + migration.apply(db_alias, database=database) - db.insert([ + database.insert([ MigrationHistory(package_name=migrations_package, module_name=name, applied=datetime.date.today()) ])