From 0b8fbda667f54f567833711fb67524a40f311afd Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Thu, 17 Oct 2024 12:20:25 +0500 Subject: [PATCH] Fix garbage-collected and non cancel-safe tasks --- client/src/telethon/_impl/mtsender/sender.py | 69 ++++++++++++-------- client/src/telethon/_impl/mtsender/utils.py | 16 +++++ 2 files changed, 59 insertions(+), 26 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/utils.py diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index fcf57179..0bad00ac 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -36,6 +36,7 @@ from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages +from .utils import cancel_tasks, store_task MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) @@ -181,6 +182,7 @@ class Sender: _recv_lock: Lock _send_lock: Lock _step_event: Event + _write_drain_pending: bool @classmethod async def connect( @@ -212,9 +214,12 @@ class Sender: _recv_lock=Lock(), _send_lock=Lock(), _step_event=Event(), + _write_drain_pending=False, ) async def disconnect(self) -> None: + await cancel_tasks() + self._writer.close() await self._writer.wait_closed() @@ -250,47 +255,58 @@ class Sender: async def step(self) -> None: self._step_event.clear() + await self._try_fill_write() if not self._recv_lock.locked(): - asyncio.create_task(self.step_recv()) + recv_task = asyncio.create_task(self._do_recv()) + recv_task.add_done_callback(self._recv_callback) + store_task(recv_task) if not self._send_lock.locked(): - asyncio.create_task(self.step_send()) + send_task = asyncio.create_task(self._do_send()) + send_task.add_done_callback(self._send_callback) + store_task(send_task) await self._step_event.wait() - async def step_recv(self) -> None: + async def _do_recv(self) -> bytes: async with self._recv_lock: try: - await self._step_recv() - finally: - self._step_event.set() + async with asyncio.timeout(PING_DELAY): + return await self._reader.read(MAXIMUM_DATA) + except TimeoutError: + self._on_ping_timeout() + raise - async def step_send(self) -> None: + async def _do_send(self) -> None: async with self._send_lock: - try: - await self._step_send() - finally: - self._step_event.set() + if self._write_drain_pending: + await self._writer.drain() + self._write_drain_pending = False + else: + await self._request_event.wait() - async def _step_recv(self) -> None: + def _recv_callback(self, fut: Future[bytes]) -> None: try: - async with asyncio.timeout(PING_DELAY): - recv_data = await self._reader.read(MAXIMUM_DATA) - result = self._on_net_read(recv_data) - self._updates.extend(result) - except TimeoutError: - self._on_ping_timeout() + if fut.done(): + buffer = fut.result() + updates = self._on_net_read(buffer) + self._updates.extend(updates) + finally: + self._step_event.set() - async def _step_send(self) -> None: - await self._request_event.wait() - await self._try_fill_write() - await self._writer.drain() - self._on_net_write() - - if not self._requests: - self._request_event.clear() + def _send_callback(self, fut: Future[None]) -> None: + try: + if fut.done(): + self._on_net_write() + if not self._requests: + self._request_event.clear() + finally: + self._step_event.set() async def _try_fill_write(self) -> None: + if self._write_drain_pending: + return + for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: @@ -306,6 +322,7 @@ class Sender: request.state.container_msg_id = container_msg_id self._transport.pack(mtp_buffer, self._writer.write) + self._write_drain_pending = True def _on_net_read(self, read_buffer: bytes) -> list[Updates]: if not read_buffer: diff --git a/client/src/telethon/_impl/mtsender/utils.py b/client/src/telethon/_impl/mtsender/utils.py new file mode 100644 index 00000000..5c69e661 --- /dev/null +++ b/client/src/telethon/_impl/mtsender/utils.py @@ -0,0 +1,16 @@ +import asyncio +from asyncio import Task +from typing import Set + +_background_tasks: Set[Task] = set() + + +def store_task(task: Task) -> None: + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + +async def cancel_tasks() -> None: + for task in _background_tasks: + task.cancel() + await asyncio.wait(_background_tasks)