mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-05-22 13:36:15 +03:00
Add .required() for configuration option
This commit is contained in:
parent
6f22549882
commit
d039dc002e
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -98,6 +98,7 @@ cdef class ConfigurationOption(Provider):
|
|||
cdef tuple __name
|
||||
cdef object __root_ref
|
||||
cdef dict __children
|
||||
cdef bint __required
|
||||
cdef object __cache
|
||||
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@ class ConfigurationOption(Provider[Any]):
|
|||
def as_int(self) -> TypedConfigurationOption[int]: ...
|
||||
def as_float(self) -> TypedConfigurationOption[float]: ...
|
||||
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
|
||||
def required(self) -> ConfigurationOption: ...
|
||||
def update(self, value: Any) -> None: ...
|
||||
def from_ini(self, filepath: Union[Path, str]) -> None: ...
|
||||
def from_yaml(self, filepath: Union[Path, str]) -> None: ...
|
||||
|
|
|
@ -1172,10 +1172,11 @@ cdef class ConfigurationOption(Provider):
|
|||
|
||||
UNDEFINED = object()
|
||||
|
||||
def __init__(self, name, root):
|
||||
def __init__(self, name, root, required=False):
|
||||
self.__name = name
|
||||
self.__root_ref = weakref.ref(root)
|
||||
self.__children = {}
|
||||
self.__required = required
|
||||
self.__cache = self.UNDEFINED
|
||||
super().__init__()
|
||||
|
||||
|
@ -1193,7 +1194,7 @@ cdef class ConfigurationOption(Provider):
|
|||
if copied_root is None:
|
||||
copied_root = deepcopy(root, memo)
|
||||
|
||||
copied = self.__class__(copied_name, copied_root)
|
||||
copied = self.__class__(copied_name, copied_root, self.__required)
|
||||
copied.__children = deepcopy(self.__children, memo)
|
||||
|
||||
return copied
|
||||
|
@ -1229,7 +1230,7 @@ cdef class ConfigurationOption(Provider):
|
|||
return self.__cache
|
||||
|
||||
root = self.__root_ref()
|
||||
value = root.get(self._get_self_name())
|
||||
value = root.get(self._get_self_name(), self.__required)
|
||||
self.__cache = value
|
||||
return value
|
||||
|
||||
|
@ -1258,6 +1259,9 @@ cdef class ConfigurationOption(Provider):
|
|||
def as_(self, callback, *args, **kwargs):
|
||||
return TypedConfigurationOption(callback, self, *args, **kwargs)
|
||||
|
||||
def required(self):
|
||||
return self.__class__(self.__name, self.__root_ref(), required=True)
|
||||
|
||||
def override(self, value):
|
||||
if isinstance(value, Provider):
|
||||
raise Error('Configuration option can only be overridden by a value')
|
||||
|
@ -1452,12 +1456,15 @@ cdef class Configuration(Object):
|
|||
def get_name(self):
|
||||
return self.__name
|
||||
|
||||
def get(self, selector):
|
||||
def get(self, selector, required=False):
|
||||
"""Return configuration option.
|
||||
|
||||
:param selector: Selector string, e.g. "option1.option2"
|
||||
:type selector: str
|
||||
|
||||
:param required: Required flag, raise error if required option is missing
|
||||
:type required: bool
|
||||
|
||||
:return: Option value.
|
||||
:rtype: Any
|
||||
"""
|
||||
|
@ -1472,7 +1479,7 @@ cdef class Configuration(Object):
|
|||
value = value.get(key, self.UNDEFINED)
|
||||
|
||||
if value is self.UNDEFINED:
|
||||
if self.__strict:
|
||||
if self.__strict or required:
|
||||
raise Error('Undefined configuration option "{0}.{1}"'.format(self.__name, selector))
|
||||
return None
|
||||
|
||||
|
|
|
@ -21,3 +21,7 @@ config3 = providers.Configuration()
|
|||
int3: providers.Callable[int] = config3.option.as_int()
|
||||
float3: providers.Callable[float] = config3.option.as_float()
|
||||
int3_custom: providers.Callable[int] = config3.option.as_(int)
|
||||
|
||||
# Test 4: to check required() method
|
||||
config4 = providers.Configuration()
|
||||
option4: providers.ConfigurationOption = config4.option.required()
|
||||
|
|
|
@ -97,6 +97,31 @@ class ConfigTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(value, decimal.Decimal('123.123'))
|
||||
|
||||
def test_required(self):
|
||||
provider = providers.Callable(
|
||||
lambda value: value,
|
||||
self.config.a.required(),
|
||||
)
|
||||
with self.assertRaisesRegex(errors.Error, 'Undefined configuration option "config.a"'):
|
||||
provider()
|
||||
|
||||
def test_required_no_side_effect(self):
|
||||
_ = providers.Callable(
|
||||
lambda value: value,
|
||||
self.config.a.required(),
|
||||
)
|
||||
self.assertIsNone(self.config.a())
|
||||
|
||||
def test_required_as_(self):
|
||||
provider = providers.List(
|
||||
self.config.int_test.required().as_int(),
|
||||
self.config.float_test.required().as_float(),
|
||||
self.config._as_test.required().as_(decimal.Decimal),
|
||||
)
|
||||
self.config.from_dict({'int_test': '1', 'float_test': '2.0', '_as_test': '3.0'})
|
||||
|
||||
self.assertEqual(provider(), [1, 2.0, decimal.Decimal('3.0')])
|
||||
|
||||
def test_providers_value_override(self):
|
||||
a = self.config.a
|
||||
ab = self.config.a.b
|
||||
|
|
Loading…
Reference in New Issue
Block a user