mirror of
https://github.com/catspace-dev/unicheckbot.git
synced 2024-11-24 23:03:45 +03:00
throttling middleware, close #9
This commit is contained in:
parent
41cec1f044
commit
4f7fb7f1df
|
@ -1,6 +1,6 @@
|
||||||
from aiogram import Bot, Dispatcher, executor
|
from aiogram import Bot, Dispatcher, executor
|
||||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware
|
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
|
||||||
import config
|
import config
|
||||||
import handlers
|
import handlers
|
||||||
|
|
||||||
|
@ -10,11 +10,12 @@ dp = Dispatcher(telegram_bot, storage=storage)
|
||||||
|
|
||||||
|
|
||||||
def on_startup():
|
def on_startup():
|
||||||
handlers.default.setup(dp)
|
handlers.default.setup(dp)
|
||||||
dp.middleware.setup(WriteCommandMetric())
|
dp.middleware.setup(WriteCommandMetric())
|
||||||
dp.middleware.setup(LoggingMiddleware())
|
dp.middleware.setup(LoggingMiddleware())
|
||||||
|
dp.middleware.setup(ThrottlingMiddleware())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
on_startup()
|
on_startup()
|
||||||
executor.start_polling(dp, skip_updates=True)
|
executor.start_polling(dp, skip_updates=True)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from httpx import Response
|
||||||
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
|
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
|
||||||
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
|
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
|
||||||
from ..metrics import push_status_metric
|
from ..metrics import push_status_metric
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
icmp_help_message = """
|
icmp_help_message = """
|
||||||
❓ Производит проверку хоста по протоколу ICMP.
|
❓ Производит проверку хоста по протоколу ICMP.
|
||||||
|
@ -19,6 +20,7 @@ class ICMPCheckerHandler(CheckerBaseHandler):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(ICMPCheckerHandler, self).__init__()
|
super(ICMPCheckerHandler, self).__init__()
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def handler(self, message: Message):
|
async def handler(self, message: Message):
|
||||||
try:
|
try:
|
||||||
args = await self.process_args(message.text)
|
args = await self.process_args(message.text)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from httpx import Response
|
||||||
|
|
||||||
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
|
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
|
||||||
from tgbot.handlers.metrics import push_status_metric
|
from tgbot.handlers.metrics import push_status_metric
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
minecraft_help_message = """
|
minecraft_help_message = """
|
||||||
❓ Получает статистику о Minecraft сервере
|
❓ Получает статистику о Minecraft сервере
|
||||||
|
@ -24,6 +25,7 @@ class MinecraftCheckerHandler(CheckerBaseHandler):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def handler(self, message: Message):
|
async def handler(self, message: Message):
|
||||||
await self.target_port_handler(message)
|
await self.target_port_handler(message)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from aiogram.types import Message
|
from aiogram.types import Message
|
||||||
from tgbot.nodes import nodes
|
from tgbot.nodes import nodes
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
start_message = f"""
|
start_message = f"""
|
||||||
Привет! Добро пожаловать в @hostinfobot!\n
|
Привет! Добро пожаловать в @hostinfobot!\n
|
||||||
|
@ -20,6 +21,6 @@ start_message = f"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def start_cmd(msg: Message):
|
async def start_cmd(msg: Message):
|
||||||
await msg.answer(start_message, parse_mode='markdown')
|
await msg.answer(start_message, parse_mode='markdown')
|
|
@ -5,6 +5,7 @@ from httpx import Response
|
||||||
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
|
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
|
||||||
from tgbot.handlers.helpers import check_int
|
from tgbot.handlers.helpers import check_int
|
||||||
from tgbot.handlers.metrics import push_status_metric
|
from tgbot.handlers.metrics import push_status_metric
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
tcp_help_message = """
|
tcp_help_message = """
|
||||||
❓ Производит проверку TCP порта, открыт ли он или нет
|
❓ Производит проверку TCP порта, открыт ли он или нет
|
||||||
|
@ -23,6 +24,7 @@ class TCPCheckerHandler(CheckerBaseHandler):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def handler(self, message: Message):
|
async def handler(self, message: Message):
|
||||||
await self.target_port_handler(message)
|
await self.target_port_handler(message)
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from httpx import Response
|
||||||
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
|
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
|
||||||
from ..base import CheckerBaseHandler, process_args_for_host_port
|
from ..base import CheckerBaseHandler, process_args_for_host_port
|
||||||
from ..metrics import push_status_metric
|
from ..metrics import push_status_metric
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
web_help_message = """
|
web_help_message = """
|
||||||
❓ Производит проверку хоста по протоколу HTTP.
|
❓ Производит проверку хоста по протоколу HTTP.
|
||||||
|
@ -22,6 +23,7 @@ class WebCheckerHandler(CheckerBaseHandler):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def handler(self, message: Message):
|
async def handler(self, message: Message):
|
||||||
await self.target_port_handler(message)
|
await self.target_port_handler(message)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from aiogram.types import Message
|
|
||||||
import whois
|
import whois
|
||||||
|
from aiogram.types import Message
|
||||||
|
|
||||||
from tgbot.handlers.helpers import validate_local
|
from tgbot.handlers.helpers import validate_local
|
||||||
|
from tgbot.middlewares.throttling import rate_limit
|
||||||
|
|
||||||
whois_help_message = """
|
whois_help_message = """
|
||||||
❓ Вернёт информацию о домене.
|
❓ Вернёт информацию о домене.
|
||||||
|
@ -64,7 +65,7 @@ def create_whois_message(domain: str) -> str:
|
||||||
message += f"\n🔐 DNSSec: {dnssec}"
|
message += f"\n🔐 DNSSec: {dnssec}"
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
@rate_limit
|
||||||
async def whois_cmd(msg: Message):
|
async def whois_cmd(msg: Message):
|
||||||
args = msg.text.split(" ")
|
args = msg.text.split(" ")
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
from tgbot.middlewares.write_command_metric import WriteCommandMetric
|
from tgbot.middlewares.write_command_metric import WriteCommandMetric
|
||||||
from tgbot.middlewares.logging import LoggingMiddleware
|
from tgbot.middlewares.logging import LoggingMiddleware
|
||||||
|
from tgbot.middlewares.throttling import ThrottlingMiddleware
|
||||||
|
|
76
apps/tgbot/tgbot/middlewares/throttling.py
Normal file
76
apps/tgbot/tgbot/middlewares/throttling.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
from aiogram import Dispatcher, types
|
||||||
|
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
|
||||||
|
from aiogram.dispatcher.handler import CancelHandler, current_handler
|
||||||
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
|
from aiogram.utils.exceptions import Throttled
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit(func):
|
||||||
|
setattr(func, 'throttling_rate_limit', 2)
|
||||||
|
setattr(func, 'throttling_key', 'message')
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class ThrottlingMiddleware(BaseMiddleware):
|
||||||
|
"""
|
||||||
|
Simple middleware
|
||||||
|
TODO: Rewrite
|
||||||
|
From https://docs.aiogram.dev/en/latest/examples/middleware_and_antiflood.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
|
||||||
|
self.rate_limit = limit
|
||||||
|
self.prefix = key_prefix
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def on_process_message(self, message: types.Message, data: dict):
|
||||||
|
handler = current_handler.get()
|
||||||
|
|
||||||
|
dispatcher = Dispatcher.get_current()
|
||||||
|
if handler:
|
||||||
|
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||||
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
|
else:
|
||||||
|
limit = self.rate_limit
|
||||||
|
key = f"{self.prefix}_message"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await dispatcher.throttle(key, rate=limit)
|
||||||
|
except Throttled as t:
|
||||||
|
await self.message_throttled(message, t)
|
||||||
|
raise CancelHandler()
|
||||||
|
|
||||||
|
async def message_throttled(self, message: types.Message, throttled: Throttled):
|
||||||
|
"""
|
||||||
|
Notify user only on first exceed and notify about unlocking only on last exceed
|
||||||
|
|
||||||
|
:param message:
|
||||||
|
:param throttled:
|
||||||
|
"""
|
||||||
|
handler = current_handler.get()
|
||||||
|
dispatcher = Dispatcher.get_current()
|
||||||
|
if handler:
|
||||||
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
|
else:
|
||||||
|
key = f"{self.prefix}_message"
|
||||||
|
|
||||||
|
# Calculate how many time is left till the block ends
|
||||||
|
delta = throttled.rate - throttled.delta
|
||||||
|
|
||||||
|
# Prevent flooding
|
||||||
|
if throttled.exceeded_count <= 2:
|
||||||
|
await message.reply('❗️Слишком мого запросов. '
|
||||||
|
'Подождите еще несколько секунд перед отправкой следующего сообщения.'
|
||||||
|
'\nВ целях предотвращения флуда, бот перестанет отвечать на ваши сообщения '
|
||||||
|
'на некоторое время.')
|
||||||
|
|
||||||
|
# Sleep.
|
||||||
|
await asyncio.sleep(delta)
|
||||||
|
|
||||||
|
# Check lock status
|
||||||
|
thr = await dispatcher.check_key(key)
|
||||||
|
|
||||||
|
# If current message is not last with current key - do not send message
|
||||||
|
if thr.exceeded_count == throttled.exceeded_count:
|
||||||
|
await message.reply('Unlocked.')
|
Loading…
Reference in New Issue
Block a user