mirror of
				https://github.com/catspace-dev/unicheckbot.git
				synced 2025-11-04 01:17:30 +03:00 
			
		
		
		
	throttling middleware, close #9
This commit is contained in:
		
							parent
							
								
									41cec1f044
								
							
						
					
					
						commit
						4f7fb7f1df
					
				| 
						 | 
				
			
			@ -1,6 +1,6 @@
 | 
			
		|||
from aiogram import Bot, Dispatcher, executor
 | 
			
		||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
 | 
			
		||||
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware
 | 
			
		||||
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
 | 
			
		||||
import config
 | 
			
		||||
import handlers
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -13,6 +13,7 @@ def on_startup():
 | 
			
		|||
    handlers.default.setup(dp)
 | 
			
		||||
    dp.middleware.setup(WriteCommandMetric())
 | 
			
		||||
    dp.middleware.setup(LoggingMiddleware())
 | 
			
		||||
    dp.middleware.setup(ThrottlingMiddleware())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from httpx import Response
 | 
			
		|||
from core.coretypes import ErrorPayload, ICMPCheckerResponse, ResponseStatus
 | 
			
		||||
from ..base import CheckerBaseHandler, NotEnoughArgs, LocalhostForbidden
 | 
			
		||||
from ..metrics import push_status_metric
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
icmp_help_message = """
 | 
			
		||||
❓ Производит проверку хоста по протоколу ICMP.
 | 
			
		||||
| 
						 | 
				
			
			@ -19,6 +20,7 @@ class ICMPCheckerHandler(CheckerBaseHandler):
 | 
			
		|||
    def __init__(self):
 | 
			
		||||
        super(ICMPCheckerHandler, self).__init__()
 | 
			
		||||
 | 
			
		||||
    @rate_limit
 | 
			
		||||
    async def handler(self, message: Message):
 | 
			
		||||
        try:
 | 
			
		||||
            args = await self.process_args(message.text)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,6 +4,7 @@ from httpx import Response
 | 
			
		|||
 | 
			
		||||
from tgbot.handlers.base import CheckerBaseHandler, process_args_for_host_port
 | 
			
		||||
from tgbot.handlers.metrics import push_status_metric
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
minecraft_help_message = """
 | 
			
		||||
❓ Получает статистику о Minecraft сервере
 | 
			
		||||
| 
						 | 
				
			
			@ -24,6 +25,7 @@ class MinecraftCheckerHandler(CheckerBaseHandler):
 | 
			
		|||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    @rate_limit
 | 
			
		||||
    async def handler(self, message: Message):
 | 
			
		||||
        await self.target_port_handler(message)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
from aiogram.types import Message
 | 
			
		||||
from tgbot.nodes import nodes
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
start_message = f"""
 | 
			
		||||
Привет! Добро пожаловать в @hostinfobot!\n
 | 
			
		||||
| 
						 | 
				
			
			@ -20,6 +21,6 @@ start_message = f"""
 | 
			
		|||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rate_limit
 | 
			
		||||
async def start_cmd(msg: Message):
 | 
			
		||||
    await msg.answer(start_message, parse_mode='markdown')
 | 
			
		||||
| 
						 | 
				
			
			@ -5,6 +5,7 @@ from httpx import Response
 | 
			
		|||
from tgbot.handlers.base import CheckerBaseHandler, NotEnoughArgs, InvalidPort
 | 
			
		||||
from tgbot.handlers.helpers import check_int
 | 
			
		||||
from tgbot.handlers.metrics import push_status_metric
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
tcp_help_message = """
 | 
			
		||||
❓ Производит проверку TCP порта, открыт ли он или нет
 | 
			
		||||
| 
						 | 
				
			
			@ -23,6 +24,7 @@ class TCPCheckerHandler(CheckerBaseHandler):
 | 
			
		|||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    @rate_limit
 | 
			
		||||
    async def handler(self, message: Message):
 | 
			
		||||
        await self.target_port_handler(message)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from httpx import Response
 | 
			
		|||
from core.coretypes import ResponseStatus, HTTP_EMOJI, HttpCheckerResponse, ErrorPayload
 | 
			
		||||
from ..base import CheckerBaseHandler, process_args_for_host_port
 | 
			
		||||
from ..metrics import push_status_metric
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
web_help_message = """
 | 
			
		||||
❓ Производит проверку хоста по протоколу HTTP.
 | 
			
		||||
| 
						 | 
				
			
			@ -22,6 +23,7 @@ class WebCheckerHandler(CheckerBaseHandler):
 | 
			
		|||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    @rate_limit
 | 
			
		||||
    async def handler(self, message: Message):
 | 
			
		||||
        await self.target_port_handler(message)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,8 @@
 | 
			
		|||
from aiogram.types import Message
 | 
			
		||||
import whois
 | 
			
		||||
from aiogram.types import Message
 | 
			
		||||
 | 
			
		||||
from tgbot.handlers.helpers import validate_local
 | 
			
		||||
from tgbot.middlewares.throttling import rate_limit
 | 
			
		||||
 | 
			
		||||
whois_help_message = """
 | 
			
		||||
❓ Вернёт информацию о домене.
 | 
			
		||||
| 
						 | 
				
			
			@ -64,7 +65,7 @@ def create_whois_message(domain: str) -> str:
 | 
			
		|||
        message += f"\n🔐 DNSSec: {dnssec}"
 | 
			
		||||
    return message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rate_limit
 | 
			
		||||
async def whois_cmd(msg: Message):
 | 
			
		||||
    args = msg.text.split(" ")
 | 
			
		||||
    if len(args) == 1:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,2 +1,3 @@
 | 
			
		|||
from tgbot.middlewares.write_command_metric import WriteCommandMetric
 | 
			
		||||
from tgbot.middlewares.logging import LoggingMiddleware
 | 
			
		||||
from tgbot.middlewares.throttling import ThrottlingMiddleware
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										76
									
								
								apps/tgbot/tgbot/middlewares/throttling.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								apps/tgbot/tgbot/middlewares/throttling.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,76 @@
 | 
			
		|||
from aiogram import Dispatcher, types
 | 
			
		||||
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
 | 
			
		||||
from aiogram.dispatcher.handler import CancelHandler, current_handler
 | 
			
		||||
from aiogram.dispatcher.middlewares import BaseMiddleware
 | 
			
		||||
from aiogram.utils.exceptions import Throttled
 | 
			
		||||
import asyncio
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rate_limit(func):
 | 
			
		||||
    setattr(func, 'throttling_rate_limit', 2)
 | 
			
		||||
    setattr(func, 'throttling_key', 'message')
 | 
			
		||||
    return func
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ThrottlingMiddleware(BaseMiddleware):
 | 
			
		||||
    """
 | 
			
		||||
    Simple middleware
 | 
			
		||||
    TODO: Rewrite
 | 
			
		||||
    From https://docs.aiogram.dev/en/latest/examples/middleware_and_antiflood.html
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
 | 
			
		||||
        self.rate_limit = limit
 | 
			
		||||
        self.prefix = key_prefix
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    async def on_process_message(self, message: types.Message, data: dict):
 | 
			
		||||
        handler = current_handler.get()
 | 
			
		||||
 | 
			
		||||
        dispatcher = Dispatcher.get_current()
 | 
			
		||||
        if handler:
 | 
			
		||||
            limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
 | 
			
		||||
            key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
 | 
			
		||||
        else:
 | 
			
		||||
            limit = self.rate_limit
 | 
			
		||||
            key = f"{self.prefix}_message"
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            await dispatcher.throttle(key, rate=limit)
 | 
			
		||||
        except Throttled as t:
 | 
			
		||||
            await self.message_throttled(message, t)
 | 
			
		||||
            raise CancelHandler()
 | 
			
		||||
 | 
			
		||||
    async def message_throttled(self, message: types.Message, throttled: Throttled):
 | 
			
		||||
        """
 | 
			
		||||
        Notify user only on first exceed and notify about unlocking only on last exceed
 | 
			
		||||
 | 
			
		||||
        :param message:
 | 
			
		||||
        :param throttled:
 | 
			
		||||
        """
 | 
			
		||||
        handler = current_handler.get()
 | 
			
		||||
        dispatcher = Dispatcher.get_current()
 | 
			
		||||
        if handler:
 | 
			
		||||
            key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
 | 
			
		||||
        else:
 | 
			
		||||
            key = f"{self.prefix}_message"
 | 
			
		||||
 | 
			
		||||
        # Calculate how many time is left till the block ends
 | 
			
		||||
        delta = throttled.rate - throttled.delta
 | 
			
		||||
 | 
			
		||||
        # Prevent flooding
 | 
			
		||||
        if throttled.exceeded_count <= 2:
 | 
			
		||||
            await message.reply('❗️Слишком мого запросов. '
 | 
			
		||||
                                'Подождите еще несколько секунд перед отправкой следующего сообщения.'
 | 
			
		||||
                                '\nВ целях предотвращения флуда, бот перестанет отвечать на ваши сообщения '
 | 
			
		||||
                                'на некоторое время.')
 | 
			
		||||
 | 
			
		||||
        # Sleep.
 | 
			
		||||
        await asyncio.sleep(delta)
 | 
			
		||||
 | 
			
		||||
        # Check lock status
 | 
			
		||||
        thr = await dispatcher.check_key(key)
 | 
			
		||||
 | 
			
		||||
        # If current message is not last with current key - do not send message
 | 
			
		||||
        if thr.exceeded_count == throttled.exceeded_count:
 | 
			
		||||
            await message.reply('Unlocked.')
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user