diff --git a/dependency_injector/injections.py b/dependency_injector/injections.py index 747babcb..7fc800c2 100644 --- a/dependency_injector/injections.py +++ b/dependency_injector/injections.py @@ -6,6 +6,8 @@ from .utils import is_provider from .utils import ensure_is_injection from .utils import get_injectable_kwargs +from .errors import Error + class Injection(object): @@ -60,8 +62,20 @@ def inject(*args, **kwargs): injections += tuple(ensure_is_injection(injection) for injection in args) - def decorator(callback): + def decorator(callback, cls=None): """Dependency injection decorator.""" + if isinstance(callback, six.class_types): + cls = callback + try: + cls_init = six.get_unbound_function(getattr(cls, '__init__')) + except AttributeError: + raise Error( + 'Class {0} has no __init__() '.format(cls.__module__, + cls.__name__) + + 'method and could not be decorated with @inject decorator') + cls.__init__ = decorator(cls_init) + return cls + if hasattr(callback, 'injections'): callback.injections += injections return callback diff --git a/tests/test_injections.py b/tests/test_injections.py index ad17b533..a3bb0903 100644 --- a/tests/test_injections.py +++ b/tests/test_injections.py @@ -151,3 +151,47 @@ class InjectTests(unittest.TestCase): def test_decorate_with_not_injection(self): """Test `inject()` decorator with not an injection instance.""" self.assertRaises(di.Error, di.inject, object) + + def test_decorate_class_method(self): + """Test `inject()` decorator with class method.""" + class Test(object): + + """Test class.""" + + @di.inject(arg1=123) + @di.inject(arg2=456) + def some_method(self, arg1, arg2): + """Some test method.""" + return arg1, arg2 + + test_object = Test() + arg1, arg2 = test_object.some_method() + + self.assertEquals(arg1, 123) + self.assertEquals(arg2, 456) + + def test_decorate_class_with_init(self): + """Test `inject()` decorator that decorate class with __init__.""" + @di.inject(arg1=123) + @di.inject(arg2=456) + class Test(object): + + """Test class.""" + + def __init__(self, arg1, arg2): + """Init.""" + self.arg1 = arg1 + self.arg2 = arg2 + + test_object = Test() + + self.assertEquals(test_object.arg1, 123) + self.assertEquals(test_object.arg2, 456) + + def test_decorate_class_without_init(self): + """Test `inject()` decorator that decorate class without __init__.""" + with self.assertRaises(di.Error): + @di.inject(arg1=123) + class Test(object): + + """Test class."""