From c138e17f4a23d93d87d4bc82c04e92430e7a40f8 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 22 Oct 2024 15:45:44 +0500 Subject: [PATCH] Refactor sender --- client/src/telethon/_impl/mtsender/sender.py | 88 +++++++------------- 1 file changed, 28 insertions(+), 60 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 57703082..b8d3ca8a 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,7 +3,7 @@ import logging import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future, Lock +from asyncio import Event, Future, Lock from collections.abc import Iterator from dataclasses import dataclass from typing import Generic, Optional, Protocol, Type, TypeVar @@ -166,15 +166,15 @@ class Sender: _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter + _reading: bool + _writing: bool _transport: Transport _mtp: Mtp _mtp_buffer: bytearray _updates: list[Updates] _requests: list[Request[object]] - _request_event: Event - _next_ping: float + _response_event: Event _read_buffer: bytearray - _write_drain_pending: bool _step_counter: int @classmethod @@ -196,6 +196,8 @@ class Sender: addr=addr, lock=Lock(), _logger=base_logger.getChild("mtsender"), + _reading=False, + _writing=False, _reader=reader, _writer=writer, _transport=transport, @@ -203,10 +205,8 @@ class Sender: _mtp_buffer=bytearray(), _updates=[], _requests=[], - _request_event=Event(), - _next_ping=asyncio.get_running_loop().time() + PING_DELAY, + _response_event=Event(), _read_buffer=bytearray(), - _write_drain_pending=False, _step_counter=0, ) @@ -216,7 +216,6 @@ class Sender: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: rx = self._enqueue_body(bytes(request)) - self._request_event.set() return rx async def invoke(self, request: RemoteCall[Return]) -> bytes: @@ -239,58 +238,34 @@ class Sender: return rx.result() async def step(self) -> None: - ticket_number = self._step_counter + if not self._writing: + self._writing = True + self._try_fill_write() + self._writing = False - async with self.lock: - if self._step_counter == ticket_number: - # We're the one to drive IO. - self._step_counter += 1 - await self._step() - # else: different task drove IO. + if not self._reading: + self._reading = True + self._response_event.clear() + await self._try_read() + self._reading = False + else: + await self._response_event.wait() def pop_updates(self) -> list[Updates]: updates = self._updates[:] self._updates.clear() return updates - async def _step(self) -> None: - self._try_fill_write() - - recv_req = asyncio.create_task(self._request_event.wait()) - recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) - send_data = asyncio.create_task(self._do_send()) - done, pending = await asyncio.wait( - (recv_req, recv_data, send_data), - timeout=self._next_ping - asyncio.get_running_loop().time(), - return_when=FIRST_COMPLETED, - ) - - if pending: - for task in pending: - task.cancel() - await asyncio.wait(pending) - - if recv_req in done: - self._request_event.clear() - if recv_data in done: - self._on_net_read(recv_data.result()) - if send_data in done: - self._on_net_write() - if not done: + async def _try_read(self) -> None: + try: + async with asyncio.timeout(PING_DELAY): + recv_data = await self._reader.read(MAXIMUM_DATA) + except TimeoutError: self._on_ping_timeout() - - async def _do_send(self) -> None: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False else: - # Never return - await asyncio.get_running_loop().create_future() + self._on_net_read(recv_data) def _try_fill_write(self) -> None: - if self._write_drain_pending: - return - for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: @@ -301,12 +276,11 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result - for request in self._requests: - if isinstance(request.state, Serialized): - request.state.container_msg_id = container_msg_id self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True + for request in self._requests: + if isinstance(request.state, Serialized): + request.state = Sent(request.state.msg_id, container_msg_id) def _on_net_read(self, read_buffer: bytes) -> None: if not read_buffer: @@ -323,11 +297,7 @@ class Sender: else: del self._read_buffer[:n] self._process_mtp_buffer() - - def _on_net_write(self) -> None: - for req in self._requests: - if isinstance(req.state, Serialized): - req.state = Sent(req.state.msg_id, req.state.container_msg_id) + self._response_event.set() def _on_ping_timeout(self) -> None: ping_id = generate_random_id() @@ -338,7 +308,6 @@ class Sender: ) ) ) - self._next_ping = asyncio.get_running_loop().time() + PING_DELAY def _process_mtp_buffer(self) -> None: results = self._mtp.deserialize(self._mtp_buffer) @@ -513,5 +482,4 @@ async def generate_auth_key(sender: Sender) -> Sender: first_salt = finished.first_salt sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) - sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY return sender