diff --git a/apps/tgbot/tgbot/handlers/errors.py b/apps/tgbot/tgbot/handlers/errors.py new file mode 100644 index 0000000..ceaa49c --- /dev/null +++ b/apps/tgbot/tgbot/handlers/errors.py @@ -0,0 +1,10 @@ +class NotEnoughArgs(Exception): + pass + + +class InvalidPort(Exception): + pass + + +class LocalhostForbidden(Exception): + pass diff --git a/apps/tgbot/tgbot/handlers/helpers.py b/apps/tgbot/tgbot/handlers/helpers.py index 1d38c3e..f6b8f68 100644 --- a/apps/tgbot/tgbot/handlers/helpers.py +++ b/apps/tgbot/tgbot/handlers/helpers.py @@ -8,23 +8,9 @@ from aiogram.bot import Bot from tgbot.handlers.metrics import push_api_request_status from tgbot.config import NOTIFICATION_BOT_TOKEN, NOTIFICATION_USERS from traceback import format_exc -from functools import wraps -from time import time import asyncio -def timing(f): - @wraps(f) - def wrap(*args, **kw): - ts = time() - result = f(*args, **kw) - te = time() - logger.info(f"func {f.__name__} took {te - ts} sec") - return result - - return wrap - - def check_int(value) -> bool: try: int(value) @@ -34,27 +20,6 @@ def check_int(value) -> bool: return True -def validate_local(target: str) -> bool: - """ - Validates ip or FQDN is localhost - - :return True if localhost find - """ - if target == "localhost": - return True - with suppress(ValueError): - ip_addr = ip_address(target) - if any( - [ip_addr.is_loopback, - ip_addr.is_private, - ip_addr.is_multicast, - ip_addr.is_link_local, - ip_addr.is_unspecified] - ): - return True - return False - - async def send_api_request(client: AsyncClient, endpoint: str, data: dict, node: APINode): try: data['token'] = node.token diff --git a/apps/tgbot/tgbot/handlers/validators.py b/apps/tgbot/tgbot/handlers/validators.py new file mode 100644 index 0000000..160317e --- /dev/null +++ b/apps/tgbot/tgbot/handlers/validators.py @@ -0,0 +1,29 @@ +from .errors import LocalhostForbidden +from ipaddress import ip_address +from contextlib import suppress + + +class BaseValidator: + + def __init__(self): + pass + + def validate(self, target: str, **kwargs): + pass + + +class LocalhostValidator(BaseValidator): + + def validate(self, target: str, **kwargs): + if target == "localhost": + raise LocalhostForbidden + with suppress(ValueError): + ip_addr = ip_address(target) + if any( + [ip_addr.is_loopback, + ip_addr.is_private, + ip_addr.is_multicast, + ip_addr.is_link_local, + ip_addr.is_unspecified] + ): + raise LocalhostForbidden