mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-31 07:57:38 +03:00 
			
		
		
		
	Improve updates handling (#4483)
This commit is contained in:
		
							parent
							
								
									771954d010
								
							
						
					
					
						commit
						9b18304e9c
					
				|  | @ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: | |||
| 
 | ||||
| 
 | ||||
| async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: | ||||
|     assert client._sender | ||||
|     assert dc_id is not None | ||||
|     sender, client._session.dcs = await connect_sender( | ||||
|         client._config, client._session.dcs, DataCenter(id=dc_id) | ||||
|     ) | ||||
|     async with client._sender_lock: | ||||
| 
 | ||||
|     async with client._sender.lock: | ||||
|         old_sender = client._sender | ||||
|         client._sender = sender | ||||
|         await old_sender.disconnect() | ||||
| 
 | ||||
| 
 | ||||
| async def bot_sign_in(self: Client, token: str) -> User: | ||||
|  |  | |||
|  | @ -220,14 +220,15 @@ class Client: | |||
|         base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) | ||||
| 
 | ||||
|         self._sender: Optional[Sender] = None | ||||
|         self._sender_lock = asyncio.Lock() | ||||
|         self._sender_lock_flag = False | ||||
| 
 | ||||
|         if isinstance(session, Storage): | ||||
|             self._storage = session | ||||
|             storage = session | ||||
|         elif session is None: | ||||
|             self._storage = MemorySession() | ||||
|             storage = MemorySession() | ||||
|         else: | ||||
|             self._storage = SqliteSession(session) | ||||
|             storage = SqliteSession(session) | ||||
| 
 | ||||
|         self._storage = storage | ||||
| 
 | ||||
|         self._config = Config( | ||||
|             api_id=api_id, | ||||
|  | @ -2074,10 +2075,7 @@ class Client: | |||
|         return await upload(self, fd, size, name) | ||||
| 
 | ||||
|     async def __call__(self, request: Request[Return]) -> Return: | ||||
|         if not self._sender: | ||||
|             raise ConnectionError("not connected") | ||||
| 
 | ||||
|         return await invoke_request(self, self._sender, self._sender_lock, request) | ||||
|         return await invoke_request(self, request) | ||||
| 
 | ||||
|     async def __aenter__(self) -> Self: | ||||
|         await connect(self) | ||||
|  |  | |||
|  | @ -238,15 +238,16 @@ async def disconnect(self: Client) -> None: | |||
| 
 | ||||
| async def invoke_request( | ||||
|     client: Client, | ||||
|     sender: Sender, | ||||
|     lock: asyncio.Lock, | ||||
|     request: Request[Return], | ||||
| ) -> Return: | ||||
|     if not client._sender: | ||||
|         raise ConnectionError("not connected") | ||||
| 
 | ||||
|     sleep_thresh = client._config.flood_sleep_threshold | ||||
|     rx = sender.enqueue(request) | ||||
|     rx = client._sender.enqueue(request) | ||||
|     while True: | ||||
|         while not rx.done(): | ||||
|             await step_sender(client, sender, lock) | ||||
|             await step_sender(client) | ||||
|         try: | ||||
|             response = rx.result() | ||||
|             break | ||||
|  | @ -254,43 +255,31 @@ async def invoke_request( | |||
|             if e.code == 420 and e.value is not None and e.value < sleep_thresh: | ||||
|                 await asyncio.sleep(e.value) | ||||
|                 sleep_thresh -= e.value | ||||
|                 rx = sender.enqueue(request) | ||||
|                 rx = client._sender.enqueue(request) | ||||
|                 continue | ||||
|             else: | ||||
|                 raise adapt_rpc(e) from None | ||||
|     return request.deserialize_response(response) | ||||
| 
 | ||||
| 
 | ||||
| async def step(client: Client) -> None: | ||||
|     if client._sender: | ||||
|         await step_sender(client, client._sender, client._sender_lock) | ||||
| 
 | ||||
| 
 | ||||
| async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None: | ||||
|     flag = client._sender_lock_flag | ||||
|     async with lock: | ||||
|         if client._sender_lock_flag != flag: | ||||
|             # different task already received an item from the network | ||||
| async def step_sender(client: Client) -> None: | ||||
|     try: | ||||
|         assert client._sender | ||||
|         updates = await client._sender.step() | ||||
|     except ConnectionError: | ||||
|         if client.connected: | ||||
|             raise | ||||
|         else: | ||||
|             # disconnect was called, so the socket returning 0 bytes is expected | ||||
|             return | ||||
|         # current task is responsible for receiving | ||||
|         # toggle the flag so any other task that comes after does not run again | ||||
|         client._sender_lock_flag = not client._sender_lock_flag | ||||
| 
 | ||||
|         try: | ||||
|             updates = await sender.step() | ||||
|         except ConnectionError: | ||||
|             if client.connected: | ||||
|                 raise | ||||
|             else: | ||||
|                 # disconnect was called, so the socket returning 0 bytes is expected | ||||
|                 return | ||||
| 
 | ||||
|         process_socket_updates(client, updates) | ||||
|     process_socket_updates(client, updates) | ||||
| 
 | ||||
| 
 | ||||
| async def run_until_disconnected(self: Client) -> None: | ||||
|     while self.connected: | ||||
|         await step(self) | ||||
|         if self._sender: | ||||
|             await step_sender(self) | ||||
| 
 | ||||
| 
 | ||||
| def connected(client: Client) -> bool: | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ import logging | |||
| import struct | ||||
| import time | ||||
| from abc import ABC | ||||
| from asyncio import FIRST_COMPLETED, Event, Future | ||||
| from asyncio import FIRST_COMPLETED, Event, Future, Lock | ||||
| from collections.abc import Iterator | ||||
| from dataclasses import dataclass | ||||
| from typing import Generic, Optional, Protocol, Type, TypeVar | ||||
|  | @ -162,6 +162,7 @@ class Request(Generic[Return]): | |||
| class Sender: | ||||
|     dc_id: int | ||||
|     addr: str | ||||
|     lock: Lock | ||||
|     _logger: logging.Logger | ||||
|     _reader: AsyncReader | ||||
|     _writer: AsyncWriter | ||||
|  | @ -191,6 +192,7 @@ class Sender: | |||
|         return cls( | ||||
|             dc_id=dc_id, | ||||
|             addr=addr, | ||||
|             lock=Lock(), | ||||
|             _logger=base_logger.getChild("mtsender"), | ||||
|             _reader=reader, | ||||
|             _writer=writer, | ||||
|  | @ -233,6 +235,10 @@ class Sender: | |||
|                 return rx.result() | ||||
| 
 | ||||
|     async def step(self) -> list[Updates]: | ||||
|         async with self.lock: | ||||
|             return await self._step() | ||||
| 
 | ||||
|     async def _step(self) -> list[Updates]: | ||||
|         self._try_fill_write() | ||||
| 
 | ||||
|         recv_req = asyncio.create_task(self._request_event.wait()) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user