From f15af530fa4a3b18a2983f3dde36f0c3175ec27a Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov <109198731+Jahongir-Qurbonov@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:00:00 +0500 Subject: [PATCH] Refactor sender to avoid cancelling tasks (#4497) --- .../src/telethon/_impl/client/client/auth.py | 7 +- client/src/telethon/_impl/mtsender/sender.py | 129 ++++++++---------- 2 files changed, 61 insertions(+), 75 deletions(-) diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 0a33d35c..eb5bc2ec 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -61,10 +61,9 @@ async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: client._config, client._session.dcs, DataCenter(id=dc_id) ) - async with client._sender.lock: - old_sender = client._sender - client._sender = sender - await old_sender.disconnect() + old_sender = client._sender + client._sender = sender + await old_sender.disconnect() async def bot_sign_in(self: Client, token: str) -> User: diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 57703082..c5da137c 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,7 +3,7 @@ import logging import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future, Lock +from asyncio import Event, Future from collections.abc import Iterator from dataclasses import dataclass from typing import Generic, Optional, Protocol, Type, TypeVar @@ -162,20 +162,19 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str - lock: Lock _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter + _reading: bool + _writing: bool + _step_done: Event _transport: Transport _mtp: Mtp _mtp_buffer: bytearray _updates: list[Updates] _requests: list[Request[object]] - _request_event: Event _next_ping: float _read_buffer: bytearray - _write_drain_pending: bool - _step_counter: int @classmethod async def connect( @@ -194,29 +193,28 @@ class Sender: return cls( dc_id=dc_id, addr=addr, - lock=Lock(), _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, + _reading=False, + _writing=False, + _step_done=Event(), _transport=transport, _mtp=mtp, _mtp_buffer=bytearray(), _updates=[], _requests=[], - _request_event=Event(), _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), - _write_drain_pending=False, - _step_counter=0, ) async def disconnect(self) -> None: self._writer.close() await self._writer.wait_closed() + self._step_done.set() def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: rx = self._enqueue_body(bytes(request)) - self._request_event.set() return rx async def invoke(self, request: RemoteCall[Return]) -> bytes: @@ -239,56 +237,47 @@ class Sender: return rx.result() async def step(self) -> None: - ticket_number = self._step_counter + if not self._writing: + self._writing = True + await self._do_write() + self._writing = False - async with self.lock: - if self._step_counter == ticket_number: - # We're the one to drive IO. - self._step_counter += 1 - await self._step() - # else: different task drove IO. + if not self._reading: + self._reading = True + await self._do_read() + self._reading = False + + if not self._step_done.is_set(): + await self._step_done.wait() def pop_updates(self) -> list[Updates]: updates = self._updates[:] self._updates.clear() return updates - async def _step(self) -> None: - self._try_fill_write() + async def _do_read(self) -> None: + self._step_done.clear() - recv_req = asyncio.create_task(self._request_event.wait()) - recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) - send_data = asyncio.create_task(self._do_send()) - done, pending = await asyncio.wait( - (recv_req, recv_data, send_data), - timeout=self._next_ping - asyncio.get_running_loop().time(), - return_when=FIRST_COMPLETED, - ) - - if pending: - for task in pending: - task.cancel() - await asyncio.wait(pending) - - if recv_req in done: - self._request_event.clear() - if recv_data in done: - self._on_net_read(recv_data.result()) - if send_data in done: - self._on_net_write() - if not done: - self._on_ping_timeout() - - async def _do_send(self) -> None: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False + timeout = self._next_ping - asyncio.get_running_loop().time() + try: + async with asyncio.timeout(timeout): + recv_data = await self._reader.read(MAXIMUM_DATA) + except TimeoutError: + pass else: - # Never return - await asyncio.get_running_loop().create_future() + self._on_net_read(recv_data) + finally: + self._try_timeout_ping() + self._step_done.set() - def _try_fill_write(self) -> None: - if self._write_drain_pending: + async def _do_write(self) -> None: + self._step_done.clear() + await self._try_fill_write() + self._try_timeout_ping() + self._step_done.set() + + async def _try_fill_write(self) -> None: + if not self._requests: return for request in self._requests: @@ -301,12 +290,27 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result - for request in self._requests: - if isinstance(request.state, Serialized): - request.state.container_msg_id = container_msg_id self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True + await self._writer.drain() + + for request in self._requests: + if isinstance(request.state, Serialized): + request.state = Sent(request.state.msg_id, container_msg_id) + + def _try_timeout_ping(self) -> None: + current_time = asyncio.get_running_loop().time() + + if current_time >= self._next_ping: + ping_id = generate_random_id() + self._enqueue_body( + bytes( + ping_delay_disconnect( + ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT + ) + ) + ) + self._next_ping = current_time + PING_DELAY def _on_net_read(self, read_buffer: bytes) -> None: if not read_buffer: @@ -324,22 +328,6 @@ class Sender: del self._read_buffer[:n] self._process_mtp_buffer() - def _on_net_write(self) -> None: - for req in self._requests: - if isinstance(req.state, Serialized): - req.state = Sent(req.state.msg_id, req.state.container_msg_id) - - def _on_ping_timeout(self) -> None: - ping_id = generate_random_id() - self._enqueue_body( - bytes( - ping_delay_disconnect( - ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT - ) - ) - ) - self._next_ping = asyncio.get_running_loop().time() + PING_DELAY - def _process_mtp_buffer(self) -> None: results = self._mtp.deserialize(self._mtp_buffer) @@ -513,5 +501,4 @@ async def generate_auth_key(sender: Sender) -> Sender: first_salt = finished.first_salt sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) - sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY return sender