throttling middleware, close #9

This commit is contained in:
kiriharu 2021-01-05 21:26:24 +03:00
parent 41cec1f044
commit 4f7fb7f1df
9 changed files with 97 additions and 9 deletions

View File

@ -1,6 +1,6 @@
from aiogram import Bot, Dispatcher, executor
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
import config
import handlers
@ -13,6 +13,7 @@ def on_startup():
handlers.default.setup(dp)
dp.middleware.setup(WriteCommandMetric())
dp.middleware.setup(LoggingMiddleware())
dp.middleware.setup(ThrottlingMiddleware())
if __name__ == '__main__':

View File

@ -3,6 +3,7 @@ from httpx import Response
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
from ..metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
icmp_help_message = """
Производит проверку хоста по протоколу ICMP.
@ -19,6 +20,7 @@ class ICMPCheckerHandler(CheckerBaseHandler):
def __init__(self):
super(ICMPCheckerHandler, self).__init__()
@rate_limit
async def handler(self, message: Message):
try:
args = await self.process_args(message.text)

View File

@ -4,6 +4,7 @@ from httpx import Response
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
from tgbot.handlers.metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
minecraft_help_message = """
Получает статистику о Minecraft сервере
@ -24,6 +25,7 @@ class MinecraftCheckerHandler(CheckerBaseHandler):
def __init__(self):
super().__init__()
@rate_limit
async def handler(self, message: Message):
await self.target_port_handler(message)

View File

@ -1,5 +1,6 @@
from aiogram.types import Message
from tgbot.nodes import nodes
from tgbot.middlewares.throttling import rate_limit
start_message = f"""
Привет! Добро пожаловать в @hostinfobot!\n
@ -20,6 +21,6 @@ start_message = f"""
"""
@rate_limit
async def start_cmd(msg: Message):
await msg.answer(start_message, parse_mode='markdown')

View File

@ -5,6 +5,7 @@ from httpx import Response
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
from tgbot.handlers.helpers import check_int
from tgbot.handlers.metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
tcp_help_message = """
Производит проверку TCP порта, открыт ли он или нет
@ -23,6 +24,7 @@ class TCPCheckerHandler(CheckerBaseHandler):
def __init__(self):
super().__init__()
@rate_limit
async def handler(self, message: Message):
await self.target_port_handler(message)

View File

@ -3,6 +3,7 @@ from httpx import Response
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
from ..base import CheckerBaseHandler, process_args_for_host_port
from ..metrics import push_status_metric
from tgbot.middlewares.throttling import rate_limit
web_help_message = """
Производит проверку хоста по протоколу HTTP.
@ -22,6 +23,7 @@ class WebCheckerHandler(CheckerBaseHandler):
def __init__(self):
super().__init__()
@rate_limit
async def handler(self, message: Message):
await self.target_port_handler(message)

View File

@ -1,7 +1,8 @@
from aiogram.types import Message
import whois
from aiogram.types import Message
from tgbot.handlers.helpers import validate_local
from tgbot.middlewares.throttling import rate_limit
whois_help_message = """
Вернёт информацию о домене.
@ -64,7 +65,7 @@ def create_whois_message(domain: str) -> str:
message += f"\n🔐 DNSSec: {dnssec}"
return message
@rate_limit
async def whois_cmd(msg: Message):
args = msg.text.split(" ")
if len(args) == 1:

View File

@ -1,2 +1,3 @@
from tgbot.middlewares.write_command_metric import WriteCommandMetric
from tgbot.middlewares.logging import LoggingMiddleware
from tgbot.middlewares.throttling import ThrottlingMiddleware

View File

@ -0,0 +1,76 @@
from aiogram import Dispatcher, types
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled
import asyncio
def rate_limit(func):
setattr(func, 'throttling_rate_limit', 2)
setattr(func, 'throttling_key', 'message')
return func
class ThrottlingMiddleware(BaseMiddleware):
"""
Simple middleware
TODO: Rewrite
From https://docs.aiogram.dev/en/latest/examples/middleware_and_antiflood.html
"""
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
self.rate_limit = limit
self.prefix = key_prefix
super().__init__()
async def on_process_message(self, message: types.Message, data: dict):
handler = current_handler.get()
dispatcher = Dispatcher.get_current()
if handler:
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
limit = self.rate_limit
key = f"{self.prefix}_message"
try:
await dispatcher.throttle(key, rate=limit)
except Throttled as t:
await self.message_throttled(message, t)
raise CancelHandler()
async def message_throttled(self, message: types.Message, throttled: Throttled):
"""
Notify user only on first exceed and notify about unlocking only on last exceed
:param message:
:param throttled:
"""
handler = current_handler.get()
dispatcher = Dispatcher.get_current()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
key = f"{self.prefix}_message"
# Calculate how many time is left till the block ends
delta = throttled.rate - throttled.delta
# Prevent flooding
if throttled.exceeded_count <= 2:
await message.reply('❗️Слишком мого запросов. '
'Подождите еще несколько секунд перед отправкой следующего сообщения.'
'\nВ целях предотвращения флуда, бот перестанет отвечать на ваши сообщения '
'на некоторое время.')
# Sleep.
await asyncio.sleep(delta)
# Check lock status
thr = await dispatcher.check_key(key)
# If current message is not last with current key - do not send message
if thr.exceeded_count == throttled.exceeded_count:
await message.reply('Unlocked.')