diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index c5f5b4d8..075f21fb 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -167,14 +167,10 @@ async def connect_sender( return sender, session_dcs -async def connect(self: Client, reconnect: bool = False) -> None: - if self._sender and not reconnect: +async def connect(self: Client) -> None: + if self._sender: return - if reconnect: - assert self._sender - await self._sender.disconnect() - if session := await self._storage.load(): self._session = session @@ -185,7 +181,7 @@ async def connect(self: Client, reconnect: bool = False) -> None: self._config, self._session.dcs, datacenter ) - if reconnect or (self._message_box.is_empty() and self._session.user): + if self._message_box.is_empty() and self._session.user: try: await self(functions.updates.get_state()) except RpcError as e: @@ -201,8 +197,7 @@ async def connect(self: Client, reconnect: bool = False) -> None: id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username ) - if not self._dispatcher or self._dispatcher.done(): - self._dispatcher = asyncio.create_task(dispatcher(self)) + self._dispatcher = asyncio.create_task(dispatcher(self)) async def disconnect(self: Client) -> None: @@ -271,9 +266,6 @@ async def step_sender(client: Client) -> None: try: assert client._sender updates = await client._sender.step_updates() - except ConnectionResetError: - await connect(client, reconnect=True) - return except ConnectionError: if client.connected: raise diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 5da2bd18..0ef59f6a 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -163,6 +163,7 @@ class Sender: dc_id: int addr: str _logger: logging.Logger + _lock: Lock _reader: AsyncReader _writer: AsyncWriter _transport: Transport @@ -172,7 +173,6 @@ class Sender: _requests: list[Request[object]] _request_event: Event _read_buffer: bytearray - _step_lock: Lock _step_counter: int _recv_task: Optional[Task[bytes]] = None _send_task: Optional[Task[None]] = None @@ -195,6 +195,7 @@ class Sender: dc_id=dc_id, addr=addr, _logger=base_logger.getChild("mtsender"), + _lock=Lock(), _reader=reader, _writer=writer, _transport=transport, @@ -204,7 +205,6 @@ class Sender: _requests=[], _request_event=Event(), _read_buffer=bytearray(), - _step_lock=Lock(), _step_counter=0, ) @@ -212,7 +212,9 @@ class Sender: assert self._recv_task assert self._send_task recv_task, send_task = self._recv_task, self._send_task - self._recv_task, self._send_task = None, None + + async with self._lock: + self._recv_task, self._send_task = None, None recv_task.cancel() send_task.cancel() @@ -251,7 +253,7 @@ class Sender: async def step(self) -> None: ticket_number = self._step_counter - async with self._step_lock: + async with self._lock: if self._step_counter == ticket_number: # We're the one to drive IO. await self._step()