throttling middleware, close #9

This commit is contained in:
kiriharu 2021-01-05 21:26:24 +03:00
parent 41cec1f044
commit 4f7fb7f1df
9 changed files with 97 additions and 9 deletions

View File

@ -1,6 +1,6 @@
from aiogram import Bot, Dispatcher, executor from aiogram import Bot, Dispatcher, executor
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
import config import config
import handlers import handlers
@ -10,11 +10,12 @@ dp = Dispatcher(telegram_bot, storage=storage)
def on_startup(): def on_startup():
handlers.default.setup(dp) handlers.default.setup(dp)
dp.middleware.setup(WriteCommandMetric()) dp.middleware.setup(WriteCommandMetric())
dp.middleware.setup(LoggingMiddleware()) dp.middleware.setup(LoggingMiddleware())
dp.middleware.setup(ThrottlingMiddleware())
if __name__ == '__main__': if __name__ == '__main__':
on_startup() on_startup()
executor.start_polling(dp, skip_updates=True) executor.start_polling(dp, skip_updates=True)

View File

@ -3,6 +3,7 @@ from httpx import Response
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
from ..metrics import push_status_metric from ..metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
icmp_help_message = """ icmp_help_message = """
Производит проверку хоста по протоколу ICMP. Производит проверку хоста по протоколу ICMP.
@ -19,6 +20,7 @@ class ICMPCheckerHandler(CheckerBaseHandler):
def __init__(self): def __init__(self):
super(ICMPCheckerHandler, self).__init__() super(ICMPCheckerHandler, self).__init__()
@rate_limit
async def handler(self, message: Message): async def handler(self, message: Message):
try: try:
args = await self.process_args(message.text) args = await self.process_args(message.text)

View File

@ -4,6 +4,7 @@ from httpx import Response
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
from tgbot.handlers.metrics import push_status_metric from tgbot.handlers.metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
minecraft_help_message = """ minecraft_help_message = """
Получает статистику о Minecraft сервере Получает статистику о Minecraft сервере
@ -24,6 +25,7 @@ class MinecraftCheckerHandler(CheckerBaseHandler):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@rate_limit
async def handler(self, message: Message): async def handler(self, message: Message):
await self.target_port_handler(message) await self.target_port_handler(message)

View File

@ -1,5 +1,6 @@
from aiogram.types import Message from aiogram.types import Message
from tgbot.nodes import nodes from tgbot.nodes import nodes
from tgbot.middlewares.throttling import rate_limit
start_message = f""" start_message = f"""
Привет! Добро пожаловать в @hostinfobot!\n Привет! Добро пожаловать в @hostinfobot!\n
@ -20,6 +21,6 @@ start_message = f"""
""" """
@rate_limit
async def start_cmd(msg: Message): async def start_cmd(msg: Message):
await msg.answer(start_message, parse_mode='markdown') await msg.answer(start_message, parse_mode='markdown')

View File

@ -5,6 +5,7 @@ from httpx import Response
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
from tgbot.handlers.helpers import check_int from tgbot.handlers.helpers import check_int
from tgbot.handlers.metrics import push_status_metric from tgbot.handlers.metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
tcp_help_message = """ tcp_help_message = """
Производит проверку TCP порта, открыт ли он или нет Производит проверку TCP порта, открыт ли он или нет
@ -23,6 +24,7 @@ class TCPCheckerHandler(CheckerBaseHandler):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@rate_limit
async def handler(self, message: Message): async def handler(self, message: Message):
await self.target_port_handler(message) await self.target_port_handler(message)

View File

@ -3,6 +3,7 @@ from httpx import Response
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
from ..base import CheckerBaseHandler, process_args_for_host_port from ..base import CheckerBaseHandler, process_args_for_host_port
from ..metrics import push_status_metric from ..metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
web_help_message = """ web_help_message = """
Производит проверку хоста по протоколу HTTP. Производит проверку хоста по протоколу HTTP.
@ -22,6 +23,7 @@ class WebCheckerHandler(CheckerBaseHandler):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@rate_limit
async def handler(self, message: Message): async def handler(self, message: Message):
await self.target_port_handler(message) await self.target_port_handler(message)

View File

@ -1,7 +1,8 @@
from aiogram.types import Message
import whois import whois
from aiogram.types import Message
from tgbot.handlers.helpers import validate_local from tgbot.handlers.helpers import validate_local
from tgbot.middlewares.throttling import rate_limit
whois_help_message = """ whois_help_message = """
Вернёт информацию о домене. Вернёт информацию о домене.
@ -64,7 +65,7 @@ def create_whois_message(domain: str) -> str:
message += f"\n🔐 DNSSec: {dnssec}" message += f"\n🔐 DNSSec: {dnssec}"
return message return message
@rate_limit
async def whois_cmd(msg: Message): async def whois_cmd(msg: Message):
args = msg.text.split(" ") args = msg.text.split(" ")
if len(args) == 1: if len(args) == 1:

View File

@ -1,2 +1,3 @@
from tgbot.middlewares.write_command_metric import WriteCommandMetric from tgbot.middlewares.write_command_metric import WriteCommandMetric
from tgbot.middlewares.logging import LoggingMiddleware from tgbot.middlewares.logging import LoggingMiddleware
from tgbot.middlewares.throttling import ThrottlingMiddleware

View File

@ -0,0 +1,76 @@
from aiogram import Dispatcher, types
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled
import asyncio
def rate_limit(func):
setattr(func, 'throttling_rate_limit', 2)
setattr(func, 'throttling_key', 'message')
return func
class ThrottlingMiddleware(BaseMiddleware):
"""
Simple middleware
TODO: Rewrite
From https://docs.aiogram.dev/en/latest/examples/middleware_and_antiflood.html
"""
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
self.rate_limit = limit
self.prefix = key_prefix
super().__init__()
async def on_process_message(self, message: types.Message, data: dict):
handler = current_handler.get()
dispatcher = Dispatcher.get_current()
if handler:
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
limit = self.rate_limit
key = f"{self.prefix}_message"
try:
await dispatcher.throttle(key, rate=limit)
except Throttled as t:
await self.message_throttled(message, t)
raise CancelHandler()
async def message_throttled(self, message: types.Message, throttled: Throttled):
"""
Notify user only on first exceed and notify about unlocking only on last exceed
:param message:
:param throttled:
"""
handler = current_handler.get()
dispatcher = Dispatcher.get_current()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
key = f"{self.prefix}_message"
# Calculate how many time is left till the block ends
delta = throttled.rate - throttled.delta
# Prevent flooding
if throttled.exceeded_count <= 2:
await message.reply('❗️Слишком мого запросов. '
'Подождите еще несколько секунд перед отправкой следующего сообщения.'
'\nВ целях предотвращения флуда, бот перестанет отвечать на ваши сообщения '
'на некоторое время.')
# Sleep.
await asyncio.sleep(delta)
# Check lock status
thr = await dispatcher.check_key(key)
# If current message is not last with current key - do not send message
if thr.exceeded_count == throttled.exceeded_count:
await message.reply('Unlocked.')