2020-09-01 23:04:48 +03:00
|
|
|
"""`Factory` provider - passing injections to the underlying providers example."""
|
2020-08-06 23:33:06 +03:00
|
|
|
|
2020-09-03 23:46:03 +03:00
|
|
|
from dependency_injector import containers, providers
|
2020-08-06 23:33:06 +03:00
|
|
|
|
|
|
|
|
|
|
|
class Regularizer:
    """Innermost dependency of the example: keeps the ``alpha`` it is built with."""

    alpha: float  # assigned per instance in __init__

    def __init__(self, alpha: float) -> None:
        """Store ``alpha`` on the instance unchanged (no validation is done).

        Args:
            alpha: Value exposed afterwards as ``self.alpha``.
        """
        self.alpha = alpha
|
|
|
|
|
|
|
|
|
|
|
|
class Loss:
    """Example dependency that holds an injected ``Regularizer``."""

    regularizer: Regularizer  # assigned per instance in __init__

    def __init__(self, regularizer: Regularizer) -> None:
        """Keep a reference to ``regularizer`` as ``self.regularizer``.

        Args:
            regularizer: The object to store; it is not copied or inspected.
        """
        self.regularizer = regularizer
|
|
|
|
|
|
|
|
|
|
|
|
class ClassificationTask:
    """Example dependency that holds an injected ``Loss``."""

    loss: Loss  # assigned per instance in __init__

    def __init__(self, loss: Loss) -> None:
        """Keep a reference to ``loss`` as ``self.loss``.

        Args:
            loss: The object to store; it is not copied or inspected.
        """
        self.loss = loss
|
|
|
|
|
|
|
|
|
|
|
|
class Algorithm:
    """Top-level object of the example: holds an injected ``ClassificationTask``."""

    task: ClassificationTask  # assigned per instance in __init__

    def __init__(self, task: ClassificationTask) -> None:
        """Keep a reference to ``task`` as ``self.task``.

        Args:
            task: The object to store; it is not copied or inspected.
        """
        self.task = task
|
|
|
|
|
|
|
|
|
2020-09-03 23:46:03 +03:00
|
|
|
class Container(containers.DeclarativeContainer):
    """Declare ``algorithm_factory`` as a chain of nested ``Factory`` providers.

    The chain mirrors the object graph
    ``Algorithm.task -> ClassificationTask.loss -> Loss.regularizer -> Regularizer``.
    ``Regularizer`` gets no arguments here, so its ``alpha`` must be supplied
    at call time through a double-underscore path such as
    ``task__loss__regularizer__alpha``.
    """

    algorithm_factory = providers.Factory(
        Algorithm,
        task=providers.Factory(
            ClassificationTask,
            loss=providers.Factory(
                Loss,
                # No arguments bound: ``alpha`` is injected by the caller.
                regularizer=providers.Factory(Regularizer),
            ),
        ),
    )
|
2020-08-06 23:33:06 +03:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    container = Container()

    # Each call addresses the innermost Regularizer through the provider
    # chain (task -> loss -> regularizer -> alpha) and checks the value
    # arrived on the built object graph.
    for expected_alpha in (0.5, 0.7):
        algorithm = container.algorithm_factory(
            task__loss__regularizer__alpha=expected_alpha,
        )
        assert algorithm.task.loss.regularizer.alpha == expected_alpha
|