From d99e776e2b9e53552de8ee8e23d17979088f97e4 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 15 Oct 2024 17:36:16 +0500 Subject: [PATCH] Rewrite sender --- .../src/telethon/_impl/client/client/net.py | 2 +- client/src/telethon/_impl/mtsender/sender.py | 95 +++++++++++-------- 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 7db6c3ea..42b5a879 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -265,7 +265,7 @@ async def invoke_request( async def step_sender(client: Client) -> None: try: assert client._sender - updates = await client._sender.step() + updates = await client._sender.get_updates() except ConnectionError: if client.connected: raise diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index afd8cfe5..46c89d9c 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -4,9 +4,10 @@ import struct import time from abc import ABC from asyncio import FIRST_COMPLETED, Event, Future, Lock +from collections import deque from collections.abc import Iterator from dataclasses import dataclass -from typing import Generic, Optional, Protocol, Type, TypeVar +from typing import AsyncIterator, Deque, Generic, Optional, Protocol, Type, TypeVar from typing_extensions import Self @@ -169,11 +170,12 @@ class Sender: _transport: Transport _mtp: Mtp _mtp_buffer: bytearray + _updates: Deque[Updates] _requests: list[Request[object]] _request_event: Event _next_ping: float _read_buffer: bytearray - _write_drain_pending: bool + _step_iterator: AsyncIterator[None] | None = None @classmethod async def connect( @@ -199,11 +201,11 @@ class Sender: _transport=transport, _mtp=mtp, _mtp_buffer=bytearray(), + _updates=deque(), _requests=[], _request_event=Event(), _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), - _write_drain_pending=False, ) async def disconnect(self) -> None: @@ -230,54 +232,68 @@ class Sender: async def _step_until_receive(self, rx: Future[bytes]) -> bytes: while True: - await self.step() + await self._do_step() if rx.done(): return rx.result() - async def step(self) -> list[Updates]: - async with self._lock: - return await self._step() + async def get_updates(self) -> list[Updates]: + await self.do_step() + return ( + [self._updates.popleft() for _ in range(ulen)] + if (ulen := len(self._updates)) + else [] + ) - async def _step(self) -> list[Updates]: + async def do_step(self) -> None: + async with self._lock: + await self._do_step() + + async def _do_step(self) -> None: + await self.step.__anext__() + + @property + def step(self) -> AsyncIterator[None]: + if self._step_iterator is None: + self._step_iterator = self._step() + return self._step_iterator + + async def _step(self) -> AsyncIterator[None]: self._try_fill_write() recv_req = asyncio.create_task(self._request_event.wait()) - recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) - send_data = asyncio.create_task(self._do_send()) - done, pending = await asyncio.wait( - (recv_req, recv_data, send_data), - timeout=self._next_ping - asyncio.get_running_loop().time(), - return_when=FIRST_COMPLETED, - ) + recv_data = asyncio.create_task(self._step_recv()) + send_data = asyncio.create_task(self._step_send()) - if pending: - for task in pending: - task.cancel() - await asyncio.wait(pending) + while True: + done, _ = await asyncio.wait( + (recv_req, recv_data, send_data), + timeout=self._next_ping - asyncio.get_running_loop().time(), + return_when=FIRST_COMPLETED, + ) - result = [] - if recv_req in done: - self._request_event.clear() - if recv_data in done: - result = self._on_net_read(recv_data.result()) - if send_data in done: - self._on_net_write() - if not done: - self._on_ping_timeout() - return result + yield - async def _do_send(self) -> None: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False - else: - # Never return - await asyncio.get_running_loop().create_future() + if recv_req in done: + self._request_event.clear() + recv_req = asyncio.create_task(self._request_event.wait()) + if recv_data in done: + recv_data = asyncio.create_task(self._step_recv()) + if send_data in done: + send_data = asyncio.create_task(self._step_send()) + if not done: + self._on_ping_timeout() + + async def _step_recv(self) -> None: + recv_data = await self._reader.read(MAXIMUM_DATA) + updates = self._on_net_read(recv_data) + self._updates.extend(updates) + + async def _step_send(self) -> None: + self._try_fill_write() + await self._writer.drain() + self._on_net_write() def _try_fill_write(self) -> None: - if self._write_drain_pending: - return - for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: @@ -293,7 +309,6 @@ class Sender: request.state.container_msg_id = container_msg_id self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True def _on_net_read(self, read_buffer: bytes) -> list[Updates]: if not read_buffer: