Add .required() for configuration option

This commit is contained in:
Roman Mogylatov 2021-01-16 07:55:09 -05:00
parent 6f22549882
commit d039dc002e
7 changed files with 5341 additions and 5046 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -98,6 +98,7 @@ cdef class ConfigurationOption(Provider):
cdef tuple __name
cdef object __root_ref
cdef dict __children
cdef bint __required
cdef object __cache

View File

@ -146,6 +146,7 @@ class ConfigurationOption(Provider[Any]):
def as_int(self) -> TypedConfigurationOption[int]: ...
def as_float(self) -> TypedConfigurationOption[float]: ...
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
def required(self) -> ConfigurationOption: ...
def update(self, value: Any) -> None: ...
def from_ini(self, filepath: Union[Path, str]) -> None: ...
def from_yaml(self, filepath: Union[Path, str]) -> None: ...

View File

@ -1172,10 +1172,11 @@ cdef class ConfigurationOption(Provider):
UNDEFINED = object()
def __init__(self, name, root):
def __init__(self, name, root, required=False):
self.__name = name
self.__root_ref = weakref.ref(root)
self.__children = {}
self.__required = required
self.__cache = self.UNDEFINED
super().__init__()
@ -1193,7 +1194,7 @@ cdef class ConfigurationOption(Provider):
if copied_root is None:
copied_root = deepcopy(root, memo)
copied = self.__class__(copied_name, copied_root)
copied = self.__class__(copied_name, copied_root, self.__required)
copied.__children = deepcopy(self.__children, memo)
return copied
@ -1229,7 +1230,7 @@ cdef class ConfigurationOption(Provider):
return self.__cache
root = self.__root_ref()
value = root.get(self._get_self_name())
value = root.get(self._get_self_name(), self.__required)
self.__cache = value
return value
@ -1258,6 +1259,9 @@ cdef class ConfigurationOption(Provider):
def as_(self, callback, *args, **kwargs):
return TypedConfigurationOption(callback, self, *args, **kwargs)
def required(self):
return self.__class__(self.__name, self.__root_ref(), required=True)
def override(self, value):
if isinstance(value, Provider):
raise Error('Configuration option can only be overridden by a value')
@ -1452,12 +1456,15 @@ cdef class Configuration(Object):
def get_name(self):
return self.__name
def get(self, selector):
def get(self, selector, required=False):
"""Return configuration option.
:param selector: Selector string, e.g. "option1.option2"
:type selector: str
:param required: Required flag, raise error if required option is missing
:type required: bool
:return: Option value.
:rtype: Any
"""
@ -1472,7 +1479,7 @@ cdef class Configuration(Object):
value = value.get(key, self.UNDEFINED)
if value is self.UNDEFINED:
if self.__strict:
if self.__strict or required:
raise Error('Undefined configuration option "{0}.{1}"'.format(self.__name, selector))
return None

View File

@ -21,3 +21,7 @@ config3 = providers.Configuration()
int3: providers.Callable[int] = config3.option.as_int()
float3: providers.Callable[float] = config3.option.as_float()
int3_custom: providers.Callable[int] = config3.option.as_(int)
# Test 4: to check required() method
config4 = providers.Configuration()
option4: providers.ConfigurationOption = config4.option.required()

View File

@ -97,6 +97,31 @@ class ConfigTests(unittest.TestCase):
self.assertEqual(value, decimal.Decimal('123.123'))
def test_required(self):
provider = providers.Callable(
lambda value: value,
self.config.a.required(),
)
with self.assertRaisesRegex(errors.Error, 'Undefined configuration option "config.a"'):
provider()
def test_required_no_side_effect(self):
_ = providers.Callable(
lambda value: value,
self.config.a.required(),
)
self.assertIsNone(self.config.a())
def test_required_as_(self):
provider = providers.List(
self.config.int_test.required().as_int(),
self.config.float_test.required().as_float(),
self.config._as_test.required().as_(decimal.Decimal),
)
self.config.from_dict({'int_test': '1', 'float_test': '2.0', '_as_test': '3.0'})
self.assertEqual(provider(), [1, 2.0, decimal.Decimal('3.0')])
def test_providers_value_override(self):
a = self.config.a
ab = self.config.a.b