mirror of
https://github.com/catspace-dev/unicheckbot.git
synced 2024-11-21 21:46:32 +03:00
throttling middleware, close #9
This commit is contained in:
parent
41cec1f044
commit
4f7fb7f1df
|
@ -1,6 +1,6 @@
|
|||
from aiogram import Bot, Dispatcher, executor
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware
|
||||
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
|
||||
import config
|
||||
import handlers
|
||||
|
||||
|
@ -13,6 +13,7 @@ def on_startup():
|
|||
handlers.default.setup(dp)
|
||||
dp.middleware.setup(WriteCommandMetric())
|
||||
dp.middleware.setup(LoggingMiddleware())
|
||||
dp.middleware.setup(ThrottlingMiddleware())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -3,6 +3,7 @@ from httpx import Response
|
|||
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
|
||||
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
|
||||
from ..metrics import push_status_metric
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
icmp_help_message = """
|
||||
❓ Производит проверку хоста по протоколу ICMP.
|
||||
|
@ -19,6 +20,7 @@ class ICMPCheckerHandler(CheckerBaseHandler):
|
|||
def __init__(self):
|
||||
super(ICMPCheckerHandler, self).__init__()
|
||||
|
||||
@rate_limit
|
||||
async def handler(self, message: Message):
|
||||
try:
|
||||
args = await self.process_args(message.text)
|
||||
|
|
|
@ -4,6 +4,7 @@ from httpx import Response
|
|||
|
||||
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
|
||||
from tgbot.handlers.metrics import push_status_metric
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
minecraft_help_message = """
|
||||
❓ Получает статистику о Minecraft сервере
|
||||
|
@ -24,6 +25,7 @@ class MinecraftCheckerHandler(CheckerBaseHandler):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@rate_limit
|
||||
async def handler(self, message: Message):
|
||||
await self.target_port_handler(message)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from aiogram.types import Message
|
||||
from tgbot.nodes import nodes
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
start_message = f"""
|
||||
Привет! Добро пожаловать в @hostinfobot!\n
|
||||
|
@ -20,6 +21,6 @@ start_message = f"""
|
|||
|
||||
"""
|
||||
|
||||
|
||||
@rate_limit
|
||||
async def start_cmd(msg: Message):
|
||||
await msg.answer(start_message, parse_mode='markdown')
|
|
@ -5,6 +5,7 @@ from httpx import Response
|
|||
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
|
||||
from tgbot.handlers.helpers import check_int
|
||||
from tgbot.handlers.metrics import push_status_metric
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
tcp_help_message = """
|
||||
❓ Производит проверку TCP порта, открыт ли он или нет
|
||||
|
@ -23,6 +24,7 @@ class TCPCheckerHandler(CheckerBaseHandler):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@rate_limit
|
||||
async def handler(self, message: Message):
|
||||
await self.target_port_handler(message)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from httpx import Response
|
|||
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
|
||||
from ..base import CheckerBaseHandler, process_args_for_host_port
|
||||
from ..metrics import push_status_metric
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
web_help_message = """
|
||||
❓ Производит проверку хоста по протоколу HTTP.
|
||||
|
@ -22,6 +23,7 @@ class WebCheckerHandler(CheckerBaseHandler):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@rate_limit
|
||||
async def handler(self, message: Message):
|
||||
await self.target_port_handler(message)
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from aiogram.types import Message
|
||||
import whois
|
||||
from aiogram.types import Message
|
||||
|
||||
from tgbot.handlers.helpers import validate_local
|
||||
from tgbot.middlewares.throttling import rate_limit
|
||||
|
||||
whois_help_message = """
|
||||
❓ Вернёт информацию о домене.
|
||||
|
@ -64,7 +65,7 @@ def create_whois_message(domain: str) -> str:
|
|||
message += f"\n🔐 DNSSec: {dnssec}"
|
||||
return message
|
||||
|
||||
|
||||
@rate_limit
|
||||
async def whois_cmd(msg: Message):
|
||||
args = msg.text.split(" ")
|
||||
if len(args) == 1:
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
from tgbot.middlewares.write_command_metric import WriteCommandMetric
|
||||
from tgbot.middlewares.logging import LoggingMiddleware
|
||||
from tgbot.middlewares.throttling import ThrottlingMiddleware
|
||||
|
|
76
apps/tgbot/tgbot/middlewares/throttling.py
Normal file
76
apps/tgbot/tgbot/middlewares/throttling.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
from aiogram import Dispatcher, types
|
||||
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
|
||||
from aiogram.dispatcher.handler import CancelHandler, current_handler
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
from aiogram.utils.exceptions import Throttled
|
||||
import asyncio
|
||||
|
||||
|
||||
def rate_limit(func):
|
||||
setattr(func, 'throttling_rate_limit', 2)
|
||||
setattr(func, 'throttling_key', 'message')
|
||||
return func
|
||||
|
||||
|
||||
class ThrottlingMiddleware(BaseMiddleware):
|
||||
"""
|
||||
Simple middleware
|
||||
TODO: Rewrite
|
||||
From https://docs.aiogram.dev/en/latest/examples/middleware_and_antiflood.html
|
||||
"""
|
||||
|
||||
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
|
||||
self.rate_limit = limit
|
||||
self.prefix = key_prefix
|
||||
super().__init__()
|
||||
|
||||
async def on_process_message(self, message: types.Message, data: dict):
|
||||
handler = current_handler.get()
|
||||
|
||||
dispatcher = Dispatcher.get_current()
|
||||
if handler:
|
||||
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
limit = self.rate_limit
|
||||
key = f"{self.prefix}_message"
|
||||
|
||||
try:
|
||||
await dispatcher.throttle(key, rate=limit)
|
||||
except Throttled as t:
|
||||
await self.message_throttled(message, t)
|
||||
raise CancelHandler()
|
||||
|
||||
async def message_throttled(self, message: types.Message, throttled: Throttled):
|
||||
"""
|
||||
Notify user only on first exceed and notify about unlocking only on last exceed
|
||||
|
||||
:param message:
|
||||
:param throttled:
|
||||
"""
|
||||
handler = current_handler.get()
|
||||
dispatcher = Dispatcher.get_current()
|
||||
if handler:
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
key = f"{self.prefix}_message"
|
||||
|
||||
# Calculate how many time is left till the block ends
|
||||
delta = throttled.rate - throttled.delta
|
||||
|
||||
# Prevent flooding
|
||||
if throttled.exceeded_count <= 2:
|
||||
await message.reply('❗️Слишком мого запросов. '
|
||||
'Подождите еще несколько секунд перед отправкой следующего сообщения.'
|
||||
'\nВ целях предотвращения флуда, бот перестанет отвечать на ваши сообщения '
|
||||
'на некоторое время.')
|
||||
|
||||
# Sleep.
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
# Check lock status
|
||||
thr = await dispatcher.check_key(key)
|
||||
|
||||
# If current message is not last with current key - do not send message
|
||||
if thr.exceeded_count == throttled.exceeded_count:
|
||||
await message.reply('Unlocked.')
|
Loading…
Reference in New Issue
Block a user