diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index aaae9b0a..c5f5b4d8 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -171,6 +171,10 @@ async def connect(self: Client, reconnect: bool = False) -> None: if self._sender and not reconnect: return + if reconnect: + assert self._sender + await self._sender.disconnect() + if session := await self._storage.load(): self._session = session diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 722e74d5..5da2bd18 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,7 +3,7 @@ import logging import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future, Task +from asyncio import FIRST_COMPLETED, Event, Future, Lock, Task from collections.abc import Iterator from dataclasses import dataclass from typing import Generic, Optional, Protocol, Type, TypeVar @@ -172,8 +172,8 @@ class Sender: _requests: list[Request[object]] _request_event: Event _read_buffer: bytearray - _write_drain_pending: bool - _step_event: Event + _step_lock: Lock + _step_counter: int _recv_task: Optional[Task[bytes]] = None _send_task: Optional[Task[None]] = None @@ -191,7 +191,7 @@ class Sender: ip, port = addr.split(":") reader, writer = await connector(ip, int(port)) - sender = cls( + return cls( dc_id=dc_id, addr=addr, _logger=base_logger.getChild("mtsender"), @@ -204,28 +204,15 @@ class Sender: _requests=[], _request_event=Event(), _read_buffer=bytearray(), - _write_drain_pending=False, - _step_event=Event(), + _step_lock=Lock(), + _step_counter=0, ) - sender._recv_task = asyncio.create_task(sender._do_recv()) - sender._send_task = asyncio.create_task(sender._do_send()) - return sender - async def disconnect(self) -> None: - assert self._recv_task and self._send_task - - ( - recv_task, - send_task, - self._recv_task, - self._send_task, - ) = ( - self._recv_task, - self._send_task, - None, - None, - ) + assert self._recv_task + assert self._send_task + recv_task, send_task = self._recv_task, self._send_task + self._recv_task, self._send_task = None, None recv_task.cancel() send_task.cancel() @@ -253,10 +240,9 @@ class Sender: return oneshot async def _step_until_receive(self, rx: Future[bytes]) -> bytes: - while True: + while not rx.done(): await self.step() - if rx.done(): - return rx.result() + return rx.result() async def step_updates(self) -> list[Updates]: await self.step() @@ -264,11 +250,22 @@ class Sender: return updates async def step(self) -> None: - if self._recv_task is None or self._send_task is None: - return + ticket_number = self._step_counter + async with self._step_lock: + if self._step_counter == ticket_number: + # We're the one to drive IO. + await self._step() + self._step_counter += 1 - self._step_event.clear() - self._try_fill_write() + async def _step(self) -> None: + if self._step_counter == 0: + self._try_fill_write() + self._recv_task = asyncio.create_task(self._do_recv()) + self._send_task = asyncio.create_task(self._do_send()) + + if self._recv_task is None or self._send_task is None: + # Disconnected + return await asyncio.wait( (self._recv_task, self._send_task), return_when=FIRST_COMPLETED @@ -276,7 +273,7 @@ class Sender: if self._recv_task.done(): try: - buff = await self._recv_task + buff = self._recv_task.result() except TimeoutError: self._on_ping_timeout() else: @@ -284,31 +281,24 @@ class Sender: self._recv_task = asyncio.create_task(self._do_recv()) if self._send_task.done(): - await self._send_task + self._send_task.result() self._on_net_write() + + self._try_fill_write() self._send_task = asyncio.create_task(self._do_send()) - await self._step_event.wait() - async def _do_recv(self) -> bytes: - try: - async with asyncio.timeout(PING_DELAY): - return await self._reader.read(MAXIMUM_DATA) - finally: - self._step_event.set() + async with asyncio.timeout(PING_DELAY): + return await self._reader.read(MAXIMUM_DATA) async def _do_send(self) -> None: - try: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False - else: - await self._request_event.wait() - finally: - self._step_event.set() + await self._request_event.wait() + await self._writer.drain() + if not self._requests: + self._request_event.clear() def _try_fill_write(self) -> None: - if self._write_drain_pending: + if not self._requests: return for request in self._requests: @@ -326,7 +316,6 @@ class Sender: request.state.container_msg_id = container_msg_id self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True def _on_net_read(self, read_buffer: bytes) -> None: if not read_buffer: