From 9b18304e9c75f7beac8b7b00a66c2234d9488cc7 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov <109198731+Jahongir-Qurbonov@users.noreply.github.com> Date: Fri, 11 Oct 2024 19:57:18 +0500 Subject: [PATCH] Improve updates handling (#4483) --- .../src/telethon/_impl/client/client/auth.py | 6 ++- .../telethon/_impl/client/client/client.py | 16 +++---- .../src/telethon/_impl/client/client/net.py | 47 +++++++------------ client/src/telethon/_impl/mtsender/sender.py | 8 +++- 4 files changed, 37 insertions(+), 40 deletions(-) diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 7220a4ec..0a33d35c 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: + assert client._sender assert dc_id is not None sender, client._session.dcs = await connect_sender( client._config, client._session.dcs, DataCenter(id=dc_id) ) - async with client._sender_lock: + + async with client._sender.lock: + old_sender = client._sender client._sender = sender + await old_sender.disconnect() async def bot_sign_in(self: Client, token: str) -> User: diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 3f2c5f23..1c199d25 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -220,14 +220,15 @@ class Client: base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) self._sender: Optional[Sender] = None - self._sender_lock = asyncio.Lock() - self._sender_lock_flag = False + if isinstance(session, Storage): - self._storage = session + storage = session elif session is None: - self._storage = MemorySession() + storage = MemorySession() else: - self._storage = SqliteSession(session) + storage = SqliteSession(session) + + self._storage = storage self._config = Config( api_id=api_id, @@ -2074,10 +2075,7 @@ class Client: return await upload(self, fd, size, name) async def __call__(self, request: Request[Return]) -> Return: - if not self._sender: - raise ConnectionError("not connected") - - return await invoke_request(self, self._sender, self._sender_lock, request) + return await invoke_request(self, request) async def __aenter__(self) -> Self: await connect(self) diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 8fcaf1bb..d7f6e31b 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -238,15 +238,16 @@ async def disconnect(self: Client) -> None: async def invoke_request( client: Client, - sender: Sender, - lock: asyncio.Lock, request: Request[Return], ) -> Return: + if not client._sender: + raise ConnectionError("not connected") + sleep_thresh = client._config.flood_sleep_threshold - rx = sender.enqueue(request) + rx = client._sender.enqueue(request) while True: while not rx.done(): - await step_sender(client, sender, lock) + await step_sender(client) try: response = rx.result() break @@ -254,43 +255,31 @@ async def invoke_request( if e.code == 420 and e.value is not None and e.value < sleep_thresh: await asyncio.sleep(e.value) sleep_thresh -= e.value - rx = sender.enqueue(request) + rx = client._sender.enqueue(request) continue else: raise adapt_rpc(e) from None return request.deserialize_response(response) -async def step(client: Client) -> None: - if client._sender: - await step_sender(client, client._sender, client._sender_lock) - - -async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None: - flag = client._sender_lock_flag - async with lock: - if client._sender_lock_flag != flag: - # different task already received an item from the network +async def step_sender(client: Client) -> None: + try: + assert client._sender + updates = await client._sender.step() + except ConnectionError: + if client.connected: + raise + else: + # disconnect was called, so the socket returning 0 bytes is expected return - # current task is responsible for receiving - # toggle the flag so any other task that comes after does not run again - client._sender_lock_flag = not client._sender_lock_flag - try: - updates = await sender.step() - except ConnectionError: - if client.connected: - raise - else: - # disconnect was called, so the socket returning 0 bytes is expected - return - - process_socket_updates(client, updates) + process_socket_updates(client, updates) async def run_until_disconnected(self: Client) -> None: while self.connected: - await step(self) + if self._sender: + await step_sender(self) def connected(client: Client) -> bool: diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 72e1652a..5bbb9e4c 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,7 +3,7 @@ import logging import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future +from asyncio import FIRST_COMPLETED, Event, Future, Lock from collections.abc import Iterator from dataclasses import dataclass from typing import Generic, Optional, Protocol, Type, TypeVar @@ -162,6 +162,7 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str + lock: Lock _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -191,6 +192,7 @@ class Sender: return cls( dc_id=dc_id, addr=addr, + lock=Lock(), _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -233,6 +235,10 @@ class Sender: return rx.result() async def step(self) -> list[Updates]: + async with self.lock: + return await self._step() + + async def _step(self) -> list[Updates]: self._try_fill_write() recv_req = asyncio.create_task(self._request_event.wait())