Fix garbage-collected and non cancel-safe tasks

This commit is contained in:
Jahongir Qurbonov 2024-10-17 12:20:25 +05:00
parent 106e7bd8bb
commit 0b8fbda667
2 changed files with 59 additions and 26 deletions

View File

@ -36,6 +36,7 @@ from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .utils import cancel_tasks, store_task
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
@ -181,6 +182,7 @@ class Sender:
_recv_lock: Lock
_send_lock: Lock
_step_event: Event
_write_drain_pending: bool
@classmethod
async def connect(
@ -212,9 +214,12 @@ class Sender:
_recv_lock=Lock(),
_send_lock=Lock(),
_step_event=Event(),
_write_drain_pending=False,
)
async def disconnect(self) -> None:
await cancel_tasks()
self._writer.close()
await self._writer.wait_closed()
@ -250,47 +255,58 @@ class Sender:
async def step(self) -> None:
self._step_event.clear()
await self._try_fill_write()
if not self._recv_lock.locked():
asyncio.create_task(self.step_recv())
recv_task = asyncio.create_task(self._do_recv())
recv_task.add_done_callback(self._recv_callback)
store_task(recv_task)
if not self._send_lock.locked():
asyncio.create_task(self.step_send())
send_task = asyncio.create_task(self._do_send())
send_task.add_done_callback(self._send_callback)
store_task(send_task)
await self._step_event.wait()
async def step_recv(self) -> None:
async def _do_recv(self) -> bytes:
async with self._recv_lock:
try:
await self._step_recv()
finally:
self._step_event.set()
async with asyncio.timeout(PING_DELAY):
return await self._reader.read(MAXIMUM_DATA)
except TimeoutError:
self._on_ping_timeout()
raise
async def step_send(self) -> None:
async def _do_send(self) -> None:
async with self._send_lock:
try:
await self._step_send()
finally:
self._step_event.set()
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
await self._request_event.wait()
async def _step_recv(self) -> None:
def _recv_callback(self, fut: Future[bytes]) -> None:
try:
async with asyncio.timeout(PING_DELAY):
recv_data = await self._reader.read(MAXIMUM_DATA)
result = self._on_net_read(recv_data)
self._updates.extend(result)
except TimeoutError:
self._on_ping_timeout()
if fut.done():
buffer = fut.result()
updates = self._on_net_read(buffer)
self._updates.extend(updates)
finally:
self._step_event.set()
async def _step_send(self) -> None:
await self._request_event.wait()
await self._try_fill_write()
await self._writer.drain()
self._on_net_write()
if not self._requests:
self._request_event.clear()
def _send_callback(self, fut: Future[None]) -> None:
try:
if fut.done():
self._on_net_write()
if not self._requests:
self._request_event.clear()
finally:
self._step_event.set()
async def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests:
if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None:
@ -306,6 +322,7 @@ class Sender:
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> list[Updates]:
if not read_buffer:

View File

@ -0,0 +1,16 @@
import asyncio
from asyncio import Task
from typing import Set
_background_tasks: Set[Task] = set()
def store_task(task: Task) -> None:
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def cancel_tasks() -> None:
for task in _background_tasks:
task.cancel()
await asyncio.wait(_background_tasks)