diff --git a/apps/tgbot/pyproject.toml b/apps/tgbot/pyproject.toml index e2893b7..fd54c57 100644 --- a/apps/tgbot/pyproject.toml +++ b/apps/tgbot/pyproject.toml @@ -17,10 +17,9 @@ tortoise-orm = "^0.16.20" aiomysql = "^0.0.21" [tool.poetry.dev-dependencies] +pytest = "^6.2.2" +flake8 = "^3.8.4" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" - -[tool.poetry.scripts] -test = 'tgbot.test:run_all' diff --git a/apps/tgbot/tgbot/handlers/base.py b/apps/tgbot/tgbot/handlers/base.py index 63ee141..c8b9538 100644 --- a/apps/tgbot/tgbot/handlers/base.py +++ b/apps/tgbot/tgbot/handlers/base.py @@ -102,14 +102,32 @@ class CheckerTargetPortHandler(CheckerBaseHandler): ) -def process_args_for_host_port(text: str, default_port: int) -> list: +def parse_host_port(text: str, default_port: int) -> Tuple[str, int]: + """Parse host:port + """ + text = text.strip() port = default_port + host = text + if ":" in text: + host, port = text.rsplit(":", 1) + elif " " in text: + host, port = text.rsplit(" ", 1) + + try: + port = int(port) + # !Important: Don't check range if port == default_port! + assert port == default_port or port in range(1, 65_536) + except (ValueError, AssertionError): + raise InvalidPort(port) + + return (host, port) + + +def process_args_for_host_port(text: str, default_port: int) -> Tuple[str, int]: + """Parse target from command + """ args = text.split(' ', 1) if len(args) != 2: raise NotEnoughArgs() - host = args[1] - if ":" in host: - host, port = host.rsplit(":", 1) - elif " " in host: - host, port = host.rsplit(" ", 1) - return [host, port] + target = args[1] + return parse_host_port(target, default_port) diff --git a/apps/tgbot/tgbot/handlers/default/tcp.py b/apps/tgbot/tgbot/handlers/default/tcp.py index f0025a8..29549dc 100644 --- a/apps/tgbot/tgbot/handlers/default/tcp.py +++ b/apps/tgbot/tgbot/handlers/default/tcp.py @@ -1,9 +1,9 @@ +from typing import Tuple from aiogram.types import Message from core.coretypes import ResponseStatus, ErrorPayload, PortResponse from httpx import Response -from tgbot.handlers.base import CheckerTargetPortHandler, NotEnoughArgs, InvalidPort -from tgbot.handlers.helpers import check_int +from tgbot.handlers.base import CheckerTargetPortHandler, NotEnoughArgs, InvalidPort, parse_host_port from tgbot.handlers.metrics import push_status_metric from tgbot.middlewares.throttling import rate_limit @@ -29,20 +29,15 @@ class TCPCheckerHandler(CheckerTargetPortHandler): async def handler(self, message: Message): await super(TCPCheckerHandler, self).handler(message) - def process_args(self, text: str) -> list: + def process_args(self, text: str) -> Tuple[str, int]: args = text.split(' ', 1) if len(args) != 2: raise NotEnoughArgs() host = args[1] - if ":" in host: - host, port = host.rsplit(":", 1) - elif " " in host: - host, port = host.split(maxsplit=1) - else: + host, port = parse_host_port(host, -1) + if port == -1: raise NotEnoughArgs() - if not check_int(port): - raise InvalidPort() - return [host, port] + return (host, port) async def prepare_message(self, res: Response): message, status = await self.message_std_vals(res) diff --git a/apps/tgbot/tgbot/test/__init__.py b/apps/tgbot/tgbot/test/__init__.py index 1ba86c5..aec7c3e 100644 --- a/apps/tgbot/tgbot/test/__init__.py +++ b/apps/tgbot/tgbot/test/__init__.py @@ -1,2 +1 @@ -def run_all(): - from . import port_parsers +from .test_port_parsers import * diff --git a/apps/tgbot/tgbot/test/port_parsers.py b/apps/tgbot/tgbot/test/port_parsers.py deleted file mode 100644 index 7f770f4..0000000 --- a/apps/tgbot/tgbot/test/port_parsers.py +++ /dev/null @@ -1,57 +0,0 @@ -import asyncio - -from ..handlers.default.tcp import TCPCheckerHandler -from ..handlers.base import process_args_for_host_port,\ - NotEnoughArgs, InvalidPort - - -try: - args = "/cmd" - process_args_for_host_port(args, 443) -except NotEnoughArgs: - pass -args = "/cmd example.com" -host, port = process_args_for_host_port(args, 443) -assert port == 443 - -args = "/cmd example.com 42" -host, port = process_args_for_host_port(args, 443) -assert port == "42" # TODO: FIX THIS SHIT - -args = "/cmd example.com:42" -host, port = process_args_for_host_port(args, 443) -assert port == "42" - -try: - args = "/cmd example.com fucktests" -except InvalidPort: - pass - -method = TCPCheckerHandler().process_args - - -async def test(): - try: - args = "/cmd" - await method(args) - args = "/cmd example.com" - await method(args) - except NotEnoughArgs: - pass - - args = "/cmd example.com 42" - host, port = await method(args) - assert port == "42" - - args = "/cmd example.com:42" - host, port = await method(args) - assert port == "42" - - try: - args = "/cmd example.com jdbnjsbndjsd" - await method(args) - except InvalidPort: - pass - - -asyncio.run(test()) diff --git a/apps/tgbot/tgbot/test/test_port_parsers.py b/apps/tgbot/tgbot/test/test_port_parsers.py new file mode 100644 index 0000000..5e3dd8c --- /dev/null +++ b/apps/tgbot/tgbot/test/test_port_parsers.py @@ -0,0 +1,75 @@ +from unittest import TestCase +import asyncio + +from ..handlers.default.tcp import TCPCheckerHandler +from ..handlers.base import process_args_for_host_port,\ + NotEnoughArgs, InvalidPort + + +class TestArgsProc(TestCase): + + def test_exceptions(self): + """Test exceptions being raised + on invalid commands + """ + cases = [ + ('/cmd', NotEnoughArgs), + ('/cmd example.com testsarenice', InvalidPort) + ] + for cmd, exc in cases: + with self.subTest(command=cmd): + self.assertRaises( + exc, + lambda: process_args_for_host_port(cmd, 443) + ) + + def test_host_port(self): + """Test that host and port are parsed correctly + """ + cases = [ + ('/cmd example.com', 'example.com', 443), + ('/cmd example.com 42', 'example.com', 42), + ('/cmd example.com:42', 'example.com', 42) + ] + + for cmd, host, port in cases: + with self.subTest(cmd=cmd, host=host, port=port): + test_host, test_port = process_args_for_host_port(cmd, 443) + self.assertEqual(test_host, host) + self.assertEqual(test_port, port) + + +class TestTCPCheckerHandler(TestCase): + def setUp(self) -> None: + self.method = TCPCheckerHandler().process_args + return super().setUp() + + def test_exceptions(self): + """Test all appropriate excpetions are raised. + """ + cases = [ + ('/cmd', NotEnoughArgs), + ('/cmd example.com', NotEnoughArgs), + ('/cmd example.com jdbnjsbndjsd', InvalidPort) + ] + + for cmd, exc in cases: + with self.subTest(cmd=cmd): + self.assertRaises( + exc, + lambda: self.method(cmd) + ) + + def test_host_port(self): + """Test that host and port are parsed correctly + """ + cases = [ + ('/cmd example.com 42', 'example.com', 42), + ('/cmd example.com:65', 'example.com', 65) + ] + + for cmd, host, port in cases: + with self.subTest(cmd=cmd, host=host, port=port): + test_host, test_port = self.method(cmd) + self.assertEqual(test_host, host) + self.assertEqual(test_port, port) diff --git a/pyproject.toml b/pyproject.toml index 2f65f47..bb7da24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,8 @@ authors = ["kiriharu "] python = "^3.8.2" [tool.poetry.dev-dependencies] +pytest = "^6.2.2" +flake8 = "^3.8.4" [build-system] requires = ["poetry-core>=1.0.0"]