diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 220a0081..757771e7 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -61,7 +61,7 @@ async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: client._config, client._session.dcs, DataCenter(id=dc_id) ) - async with client._sender._lock: + async with client._sender._step_lock: old_sender = client._sender client._sender = sender await old_sender.disconnect() diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 5cbfd32e..d0cfde4f 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -4,13 +4,15 @@ import struct import sys import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future, Lock +from asyncio import FIRST_COMPLETED, Event, Future, Lock, Task from collections.abc import AsyncGenerator, Iterator from dataclasses import dataclass from typing import ( Generic, + Literal, Optional, Protocol, + Set, Type, TypeVar, ) @@ -178,7 +180,6 @@ class Sender: dc_id: int addr: str _logger: logging.Logger - _lock: Lock _reader: AsyncReader _writer: AsyncWriter _transport: Transport @@ -189,6 +190,7 @@ class Sender: _request_event: Event _next_ping: float _read_buffer: bytearray + _step_lock: Lock _step_generator: AsyncGenerator[None, None] | None = None @classmethod @@ -209,7 +211,6 @@ class Sender: dc_id=dc_id, addr=addr, _logger=base_logger.getChild("mtsender"), - _lock=Lock(), _reader=reader, _writer=writer, _transport=transport, @@ -220,6 +221,7 @@ class Sender: _request_event=Event(), _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), + _step_lock=Lock(), ) async def disconnect(self) -> None: @@ -229,7 +231,6 @@ class Sender: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: rx = self._enqueue_body(bytes(request)) - self._request_event.set() return rx async def invoke(self, request: RemoteCall[Return]) -> bytes: @@ -243,11 +244,14 @@ class Sender: def _enqueue_body(self, body: bytes) -> Future[bytes]: oneshot = asyncio.get_running_loop().create_future() self._requests.append(Request(body=body, state=NotSerialized(), result=oneshot)) + self._request_event.set() return oneshot async def _step_until_receive(self, rx: Future[bytes]) -> bytes: while True: - await self._do_step() + step = asyncio.create_task(self.do_step()) + await asyncio.wait((step, rx), return_when=FIRST_COMPLETED) + if rx.done(): return rx.result() @@ -258,11 +262,8 @@ class Sender: return result async def do_step(self) -> None: - async with self._lock: - await self._do_step() - - async def _do_step(self) -> None: - await anext(self.step) + async with self._step_lock: + await anext(self.step) @property def step(self) -> AsyncGenerator[None, None]: @@ -275,42 +276,52 @@ class Sender: recv_data = asyncio.create_task(self._step_recv()) send_data = asyncio.create_task(self._step_send()) + pending: Set[Task[Literal[True] | None]] = set() + try: while True: - done, pending = await asyncio.wait( - (recv_req, recv_data, send_data), + _, pending = await asyncio.wait( + (recv_data, send_data), timeout=self._next_ping - asyncio.get_running_loop().time(), return_when=FIRST_COMPLETED, - ) + ) # pyright: ignore [reportAssignmentType] yield - if recv_req in done: - self._request_event.clear() + if recv_req.done(): recv_req = asyncio.create_task(self._request_event.wait()) - if recv_data in done: + if recv_data.done(): recv_data = asyncio.create_task(self._step_recv()) - if send_data in done: + if send_data.done(): send_data = asyncio.create_task(self._step_send()) - if not done: - self._on_ping_timeout() finally: - if pending: - for task in pending: - task.cancel() - await asyncio.wait(pending) + await self._try_cancel_tasks(pending) + + async def _try_cancel_tasks(self, pending: set[Task]) -> None: + if pending: + for task in pending: + task.cancel() + await asyncio.wait(pending) async def _step_recv(self) -> None: - recv_data = await self._reader.read(MAXIMUM_DATA) - result = self._on_net_read(recv_data) - self._updates.extend(result) + try: + async with asyncio.timeout(PING_DELAY): + recv_data = await self._reader.read(MAXIMUM_DATA) + result = self._on_net_read(recv_data) + self._updates.extend(result) + except TimeoutError: + self._on_ping_timeout() async def _step_send(self) -> None: - self._try_fill_write() + await self._request_event.wait() + await self._try_fill_write() await self._writer.drain() self._on_net_write() - def _try_fill_write(self) -> None: + if not self._requests: + self._request_event.clear() + + async def _try_fill_write(self) -> None: for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: