diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index e5fba4c8..c9b85e31 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -167,7 +167,7 @@ class Sender: _writer: AsyncWriter _reading: bool _writing: bool - _read_done: Event + _step_done: Event _transport: Transport _mtp: Mtp _mtp_buffer: bytearray @@ -198,7 +198,7 @@ class Sender: _writer=writer, _reading=False, _writing=False, - _read_done=Event(), + _step_done=Event(), _transport=transport, _mtp=mtp, _mtp_buffer=bytearray(), @@ -211,7 +211,7 @@ class Sender: async def disconnect(self) -> None: self._writer.close() await self._writer.wait_closed() - self._read_done.set() + self._step_done.set() def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: rx = self._enqueue_body(bytes(request)) @@ -239,23 +239,24 @@ class Sender: async def step(self) -> None: if not self._writing: self._writing = True - self._try_fill_write() + await self._do_write() self._writing = False if not self._reading: self._reading = True - await self._try_read() + await self._do_read() self._reading = False - else: - await self._read_done.wait() + + if not self._step_done.is_set(): + await self._step_done.wait() def pop_updates(self) -> list[Updates]: updates = self._updates[:] self._updates.clear() return updates - async def _try_read(self) -> None: - self._read_done.clear() + async def _do_read(self) -> None: + self._step_done.clear() try: async with asyncio.timeout(PING_DELAY): @@ -266,9 +267,24 @@ class Sender: self._on_net_read(recv_data) finally: self._try_timeout_ping() - self._read_done.set() + self._step_done.set() + + async def _do_write(self) -> None: + self._step_done.clear() + + try: + async with asyncio.timeout(PING_DELAY): + await self._try_fill_write() + except TimeoutError: + pass + finally: + self._try_timeout_ping() + self._step_done.set() + + async def _try_fill_write(self) -> None: + if not self._requests: + return - def _try_fill_write(self) -> None: for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: @@ -285,6 +301,8 @@ class Sender: if isinstance(request.state, Serialized): request.state = Sent(request.state.msg_id, container_msg_id) + await self._writer.drain() + def _try_timeout_ping(self) -> None: current_time = asyncio.get_running_loop().time()