diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 42b5a879..aaae9b0a 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -167,8 +167,8 @@ async def connect_sender( return sender, session_dcs -async def connect(self: Client) -> None: - if self._sender: +async def connect(self: Client, reconnect: bool = False) -> None: + if self._sender and not reconnect: return if session := await self._storage.load(): @@ -181,7 +181,7 @@ async def connect(self: Client) -> None: self._config, self._session.dcs, datacenter ) - if self._message_box.is_empty() and self._session.user: + if reconnect or (self._message_box.is_empty() and self._session.user): try: await self(functions.updates.get_state()) except RpcError as e: @@ -197,7 +197,8 @@ async def connect(self: Client) -> None: id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username ) - self._dispatcher = asyncio.create_task(dispatcher(self)) + if not self._dispatcher or self._dispatcher.done(): + self._dispatcher = asyncio.create_task(dispatcher(self)) async def disconnect(self: Client) -> None: @@ -265,7 +266,10 @@ async def invoke_request( async def step_sender(client: Client) -> None: try: assert client._sender - updates = await client._sender.get_updates() + updates = await client._sender.step_updates() + except ConnectionResetError: + await connect(client, reconnect=True) + return except ConnectionError: if client.connected: raise diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 0bad00ac..9690da9c 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,16 +3,10 @@ import logging import struct import time from abc import ABC -from asyncio import Event, Future, Lock +from asyncio import FIRST_COMPLETED, Event, Future, Lock, Task from collections.abc import Iterator from dataclasses import dataclass -from typing import ( - Generic, - Optional, - Protocol, - Type, - TypeVar, -) +from typing import Generic, Optional, Protocol, Type, TypeVar from typing_extensions import Self @@ -36,7 +30,6 @@ from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages -from .utils import cancel_tasks, store_task MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) @@ -169,6 +162,7 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str + lock: Lock _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -179,10 +173,9 @@ class Sender: _requests: list[Request[object]] _request_event: Event _read_buffer: bytearray - _recv_lock: Lock - _send_lock: Lock - _step_event: Event _write_drain_pending: bool + _recv_task: Optional[Task[bytes]] = None + _send_task: Optional[Task[None]] = None @classmethod async def connect( @@ -198,9 +191,10 @@ class Sender: ip, port = addr.split(":") reader, writer = await connector(ip, int(port)) - return cls( + sender = cls( dc_id=dc_id, addr=addr, + lock=Lock(), _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -211,14 +205,31 @@ class Sender: _requests=[], _request_event=Event(), _read_buffer=bytearray(), - _recv_lock=Lock(), - _send_lock=Lock(), - _step_event=Event(), _write_drain_pending=False, ) + sender._recv_task = asyncio.create_task(sender._do_recv()) + sender._send_task = asyncio.create_task(sender._do_send()) + return sender + async def disconnect(self) -> None: - await cancel_tasks() + assert self._recv_task and self._send_task + + ( + recv_task, + send_task, + self._recv_task, + self._send_task, + ) = ( + self._recv_task, + self._send_task, + None, + None, + ) + + recv_task.cancel() + send_task.cancel() + await asyncio.wait((recv_task, send_task)) self._writer.close() await self._writer.wait_closed() @@ -247,63 +258,47 @@ class Sender: if rx.done(): return rx.result() - async def get_updates(self) -> list[Updates]: + async def step_updates(self) -> list[Updates]: await self.step() - result = self._updates.copy() - self._updates.clear() - return result + updates, self._updates = self._updates, [] + return updates async def step(self) -> None: - self._step_event.clear() - await self._try_fill_write() + if self._recv_task is None or self._send_task is None: + return - if not self._recv_lock.locked(): - recv_task = asyncio.create_task(self._do_recv()) - recv_task.add_done_callback(self._recv_callback) - store_task(recv_task) - if not self._send_lock.locked(): - send_task = asyncio.create_task(self._do_send()) - send_task.add_done_callback(self._send_callback) - store_task(send_task) + self._try_fill_write() - await self._step_event.wait() + await asyncio.wait( + (self._recv_task, self._send_task), return_when=FIRST_COMPLETED + ) - async def _do_recv(self) -> bytes: - async with self._recv_lock: + if self._recv_task.done(): try: - async with asyncio.timeout(PING_DELAY): - return await self._reader.read(MAXIMUM_DATA) + buff = await self._recv_task except TimeoutError: self._on_ping_timeout() - raise + else: + self._on_net_read(buff) + + self._recv_task = asyncio.create_task(self._do_recv()) + if self._send_task.done(): + await self._send_task + self._on_net_write() + self._send_task = asyncio.create_task(self._do_send()) + + async def _do_recv(self) -> bytes: + async with asyncio.timeout(PING_DELAY): + return await self._reader.read(MAXIMUM_DATA) async def _do_send(self) -> None: - async with self._send_lock: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False - else: - await self._request_event.wait() + if self._write_drain_pending: + await self._writer.drain() + self._write_drain_pending = False + else: + await self._request_event.wait() - def _recv_callback(self, fut: Future[bytes]) -> None: - try: - if fut.done(): - buffer = fut.result() - updates = self._on_net_read(buffer) - self._updates.extend(updates) - finally: - self._step_event.set() - - def _send_callback(self, fut: Future[None]) -> None: - try: - if fut.done(): - self._on_net_write() - if not self._requests: - self._request_event.clear() - finally: - self._step_event.set() - - async def _try_fill_write(self) -> None: + def _try_fill_write(self) -> None: if self._write_drain_pending: return @@ -324,13 +319,12 @@ class Sender: self._transport.pack(mtp_buffer, self._writer.write) self._write_drain_pending = True - def _on_net_read(self, read_buffer: bytes) -> list[Updates]: + def _on_net_read(self, read_buffer: bytes) -> None: if not read_buffer: raise ConnectionResetError("read 0 bytes") self._read_buffer += read_buffer - updates: list[Updates] = [] while self._read_buffer: self._mtp_buffer.clear() try: @@ -339,9 +333,7 @@ class Sender: break else: del self._read_buffer[:n] - self._process_mtp_buffer(updates) - - return updates + self._process_mtp_buffer() def _on_net_write(self) -> None: for req in self._requests: @@ -358,12 +350,12 @@ class Sender: ) ) - def _process_mtp_buffer(self, updates: list[Updates]) -> None: + def _process_mtp_buffer(self) -> None: results = self._mtp.deserialize(self._mtp_buffer) for result in results: if isinstance(result, Update): - self._process_update(updates, result.body) + self._process_update(result.body) elif isinstance(result, RpcResult): self._process_result(result) elif isinstance(result, RpcError): @@ -371,11 +363,9 @@ class Sender: else: self._process_bad_message(result) - def _process_update( - self, updates: list[Updates], update: bytes | bytearray | memoryview - ) -> None: + def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: - updates.append(Updates.from_bytes(update)) + self._updates.append(Updates.from_bytes(update)) except ValueError: cid = struct.unpack_from("I", update)[0] alt_classes: tuple[Type[Serializable], ...] = ( @@ -395,7 +385,7 @@ class Sender: AffectedMessages, ), ) - updates.append( + self._updates.append( UpdateShort( update=UpdateDeleteMessages( messages=[], diff --git a/client/src/telethon/_impl/mtsender/utils.py b/client/src/telethon/_impl/mtsender/utils.py deleted file mode 100644 index 5c69e661..00000000 --- a/client/src/telethon/_impl/mtsender/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -import asyncio -from asyncio import Task -from typing import Set - -_background_tasks: Set[Task] = set() - - -def store_task(task: Task) -> None: - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - - -async def cancel_tasks() -> None: - for task in _background_tasks: - task.cancel() - await asyncio.wait(_background_tasks)