mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-14 21:46:38 +03:00
Improve updates handling (#4483)
This commit is contained in:
parent
771954d010
commit
9b18304e9c
|
@ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
|||
|
||||
|
||||
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||
assert client._sender
|
||||
assert dc_id is not None
|
||||
sender, client._session.dcs = await connect_sender(
|
||||
client._config, client._session.dcs, DataCenter(id=dc_id)
|
||||
)
|
||||
async with client._sender_lock:
|
||||
|
||||
async with client._sender.lock:
|
||||
old_sender = client._sender
|
||||
client._sender = sender
|
||||
await old_sender.disconnect()
|
||||
|
||||
|
||||
async def bot_sign_in(self: Client, token: str) -> User:
|
||||
|
|
|
@ -220,14 +220,15 @@ class Client:
|
|||
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
||||
|
||||
self._sender: Optional[Sender] = None
|
||||
self._sender_lock = asyncio.Lock()
|
||||
self._sender_lock_flag = False
|
||||
|
||||
if isinstance(session, Storage):
|
||||
self._storage = session
|
||||
storage = session
|
||||
elif session is None:
|
||||
self._storage = MemorySession()
|
||||
storage = MemorySession()
|
||||
else:
|
||||
self._storage = SqliteSession(session)
|
||||
storage = SqliteSession(session)
|
||||
|
||||
self._storage = storage
|
||||
|
||||
self._config = Config(
|
||||
api_id=api_id,
|
||||
|
@ -2074,10 +2075,7 @@ class Client:
|
|||
return await upload(self, fd, size, name)
|
||||
|
||||
async def __call__(self, request: Request[Return]) -> Return:
|
||||
if not self._sender:
|
||||
raise ConnectionError("not connected")
|
||||
|
||||
return await invoke_request(self, self._sender, self._sender_lock, request)
|
||||
return await invoke_request(self, request)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await connect(self)
|
||||
|
|
|
@ -238,15 +238,16 @@ async def disconnect(self: Client) -> None:
|
|||
|
||||
async def invoke_request(
|
||||
client: Client,
|
||||
sender: Sender,
|
||||
lock: asyncio.Lock,
|
||||
request: Request[Return],
|
||||
) -> Return:
|
||||
if not client._sender:
|
||||
raise ConnectionError("not connected")
|
||||
|
||||
sleep_thresh = client._config.flood_sleep_threshold
|
||||
rx = sender.enqueue(request)
|
||||
rx = client._sender.enqueue(request)
|
||||
while True:
|
||||
while not rx.done():
|
||||
await step_sender(client, sender, lock)
|
||||
await step_sender(client)
|
||||
try:
|
||||
response = rx.result()
|
||||
break
|
||||
|
@ -254,43 +255,31 @@ async def invoke_request(
|
|||
if e.code == 420 and e.value is not None and e.value < sleep_thresh:
|
||||
await asyncio.sleep(e.value)
|
||||
sleep_thresh -= e.value
|
||||
rx = sender.enqueue(request)
|
||||
rx = client._sender.enqueue(request)
|
||||
continue
|
||||
else:
|
||||
raise adapt_rpc(e) from None
|
||||
return request.deserialize_response(response)
|
||||
|
||||
|
||||
async def step(client: Client) -> None:
|
||||
if client._sender:
|
||||
await step_sender(client, client._sender, client._sender_lock)
|
||||
|
||||
|
||||
async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None:
|
||||
flag = client._sender_lock_flag
|
||||
async with lock:
|
||||
if client._sender_lock_flag != flag:
|
||||
# different task already received an item from the network
|
||||
async def step_sender(client: Client) -> None:
|
||||
try:
|
||||
assert client._sender
|
||||
updates = await client._sender.step()
|
||||
except ConnectionError:
|
||||
if client.connected:
|
||||
raise
|
||||
else:
|
||||
# disconnect was called, so the socket returning 0 bytes is expected
|
||||
return
|
||||
# current task is responsible for receiving
|
||||
# toggle the flag so any other task that comes after does not run again
|
||||
client._sender_lock_flag = not client._sender_lock_flag
|
||||
|
||||
try:
|
||||
updates = await sender.step()
|
||||
except ConnectionError:
|
||||
if client.connected:
|
||||
raise
|
||||
else:
|
||||
# disconnect was called, so the socket returning 0 bytes is expected
|
||||
return
|
||||
|
||||
process_socket_updates(client, updates)
|
||||
process_socket_updates(client, updates)
|
||||
|
||||
|
||||
async def run_until_disconnected(self: Client) -> None:
|
||||
while self.connected:
|
||||
await step(self)
|
||||
if self._sender:
|
||||
await step_sender(self)
|
||||
|
||||
|
||||
def connected(client: Client) -> bool:
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import struct
|
||||
import time
|
||||
from abc import ABC
|
||||
from asyncio import FIRST_COMPLETED, Event, Future
|
||||
from asyncio import FIRST_COMPLETED, Event, Future, Lock
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Optional, Protocol, Type, TypeVar
|
||||
|
@ -162,6 +162,7 @@ class Request(Generic[Return]):
|
|||
class Sender:
|
||||
dc_id: int
|
||||
addr: str
|
||||
lock: Lock
|
||||
_logger: logging.Logger
|
||||
_reader: AsyncReader
|
||||
_writer: AsyncWriter
|
||||
|
@ -191,6 +192,7 @@ class Sender:
|
|||
return cls(
|
||||
dc_id=dc_id,
|
||||
addr=addr,
|
||||
lock=Lock(),
|
||||
_logger=base_logger.getChild("mtsender"),
|
||||
_reader=reader,
|
||||
_writer=writer,
|
||||
|
@ -233,6 +235,10 @@ class Sender:
|
|||
return rx.result()
|
||||
|
||||
async def step(self) -> list[Updates]:
|
||||
async with self.lock:
|
||||
return await self._step()
|
||||
|
||||
async def _step(self) -> list[Updates]:
|
||||
self._try_fill_write()
|
||||
|
||||
recv_req = asyncio.create_task(self._request_event.wait())
|
||||
|
|
Loading…
Reference in New Issue
Block a user