nicer tests and host:port parsing

This commit is contained in:
Crystal Melting Dot 2021-02-20 18:24:39 +03:00
parent f81f1953d1
commit 82409c669b
7 changed files with 111 additions and 80 deletions

View File

@ -17,10 +17,9 @@ tortoise-orm = "^0.16.20"
aiomysql = "^0.0.21" aiomysql = "^0.0.21"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^6.2.2"
flake8 = "^3.8.4"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts]
test = 'tgbot.test:run_all'

View File

@ -102,14 +102,32 @@ class CheckerTargetPortHandler(CheckerBaseHandler):
) )
def process_args_for_host_port(text: str, default_port: int) -> list: def parse_host_port(text: str, default_port: int) -> Tuple[str, int]:
"""Parse host:port
"""
text = text.strip()
port = default_port port = default_port
host = text
if ":" in text:
host, port = text.rsplit(":", 1)
elif " " in text:
host, port = text.rsplit(" ", 1)
try:
port = int(port)
# !Important: Don't check range if port == default_port!
assert port == default_port or port in range(1, 65_536)
except (ValueError, AssertionError):
raise InvalidPort(port)
return (host, port)
def process_args_for_host_port(text: str, default_port: int) -> Tuple[str, int]:
"""Parse target from command
"""
args = text.split(' ', 1) args = text.split(' ', 1)
if len(args) != 2: if len(args) != 2:
raise NotEnoughArgs() raise NotEnoughArgs()
host = args[1] target = args[1]
if ":" in host: return parse_host_port(target, default_port)
host, port = host.rsplit(":", 1)
elif " " in host:
host, port = host.rsplit(" ", 1)
return [host, port]

View File

@ -1,9 +1,9 @@
from typing import Tuple
from aiogram.types import Message from aiogram.types import Message
from core.coretypes import ResponseStatus, ErrorPayload, PortResponse from core.coretypes import ResponseStatus, ErrorPayload, PortResponse
from httpx import Response from httpx import Response
from tgbot.handlers.base import CheckerTargetPortHandler, NotEnoughArgs, InvalidPort from tgbot.handlers.base import CheckerTargetPortHandler, NotEnoughArgs, InvalidPort, parse_host_port
from tgbot.handlers.helpers import check_int
from tgbot.handlers.metrics import push_status_metric from tgbot.handlers.metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit from tgbot.middlewares.throttling import rate_limit
@ -29,20 +29,15 @@ class TCPCheckerHandler(CheckerTargetPortHandler):
async def handler(self, message: Message): async def handler(self, message: Message):
await super(TCPCheckerHandler, self).handler(message) await super(TCPCheckerHandler, self).handler(message)
def process_args(self, text: str) -> list: def process_args(self, text: str) -> Tuple[str, int]:
args = text.split(' ', 1) args = text.split(' ', 1)
if len(args) != 2: if len(args) != 2:
raise NotEnoughArgs() raise NotEnoughArgs()
host = args[1] host = args[1]
if ":" in host: host, port = parse_host_port(host, -1)
host, port = host.rsplit(":", 1) if port == -1:
elif " " in host:
host, port = host.split(maxsplit=1)
else:
raise NotEnoughArgs() raise NotEnoughArgs()
if not check_int(port): return (host, port)
raise InvalidPort()
return [host, port]
async def prepare_message(self, res: Response): async def prepare_message(self, res: Response):
message, status = await self.message_std_vals(res) message, status = await self.message_std_vals(res)

View File

@ -1,2 +1 @@
def run_all(): from .test_port_parsers import *
from . import port_parsers

View File

@ -1,57 +0,0 @@
import asyncio
from ..handlers.default.tcp import TCPCheckerHandler
from ..handlers.base import process_args_for_host_port,\
NotEnoughArgs, InvalidPort
try:
args = "/cmd"
process_args_for_host_port(args, 443)
except NotEnoughArgs:
pass
args = "/cmd example.com"
host, port = process_args_for_host_port(args, 443)
assert port == 443
args = "/cmd example.com 42"
host, port = process_args_for_host_port(args, 443)
assert port == "42" # TODO: FIX THIS SHIT
args = "/cmd example.com:42"
host, port = process_args_for_host_port(args, 443)
assert port == "42"
try:
args = "/cmd example.com fucktests"
except InvalidPort:
pass
method = TCPCheckerHandler().process_args
async def test():
try:
args = "/cmd"
await method(args)
args = "/cmd example.com"
await method(args)
except NotEnoughArgs:
pass
args = "/cmd example.com 42"
host, port = await method(args)
assert port == "42"
args = "/cmd example.com:42"
host, port = await method(args)
assert port == "42"
try:
args = "/cmd example.com jdbnjsbndjsd"
await method(args)
except InvalidPort:
pass
asyncio.run(test())

View File

@ -0,0 +1,75 @@
from unittest import TestCase
import asyncio
from ..handlers.default.tcp import TCPCheckerHandler
from ..handlers.base import process_args_for_host_port,\
NotEnoughArgs, InvalidPort
class TestArgsProc(TestCase):
def test_exceptions(self):
"""Test exceptions being raised
on invalid commands
"""
cases = [
('/cmd', NotEnoughArgs),
('/cmd example.com testsarenice', InvalidPort)
]
for cmd, exc in cases:
with self.subTest(command=cmd):
self.assertRaises(
exc,
lambda: process_args_for_host_port(cmd, 443)
)
def test_host_port(self):
"""Test that host and port are parsed correctly
"""
cases = [
('/cmd example.com', 'example.com', 443),
('/cmd example.com 42', 'example.com', 42),
('/cmd example.com:42', 'example.com', 42)
]
for cmd, host, port in cases:
with self.subTest(cmd=cmd, host=host, port=port):
test_host, test_port = process_args_for_host_port(cmd, 443)
self.assertEqual(test_host, host)
self.assertEqual(test_port, port)
class TestTCPCheckerHandler(TestCase):
def setUp(self) -> None:
self.method = TCPCheckerHandler().process_args
return super().setUp()
def test_exceptions(self):
"""Test all appropriate excpetions are raised.
"""
cases = [
('/cmd', NotEnoughArgs),
('/cmd example.com', NotEnoughArgs),
('/cmd example.com jdbnjsbndjsd', InvalidPort)
]
for cmd, exc in cases:
with self.subTest(cmd=cmd):
self.assertRaises(
exc,
lambda: self.method(cmd)
)
def test_host_port(self):
"""Test that host and port are parsed correctly
"""
cases = [
('/cmd example.com 42', 'example.com', 42),
('/cmd example.com:65', 'example.com', 65)
]
for cmd, host, port in cases:
with self.subTest(cmd=cmd, host=host, port=port):
test_host, test_port = self.method(cmd)
self.assertEqual(test_host, host)
self.assertEqual(test_port, port)

View File

@ -8,6 +8,8 @@ authors = ["kiriharu <kiriharu@yandex.ru>"]
python = "^3.8.2" python = "^3.8.2"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^6.2.2"
flake8 = "^3.8.4"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]