Improve updates handling (#4483)

This commit is contained in:
Jahongir Qurbonov 2024-10-11 19:57:18 +05:00 committed by GitHub
parent 771954d010
commit 9b18304e9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 40 deletions

View File

@ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert client._sender
assert dc_id is not None
sender, client._session.dcs = await connect_sender(
client._config, client._session.dcs, DataCenter(id=dc_id)
)
async with client._sender_lock:
async with client._sender.lock:
old_sender = client._sender
client._sender = sender
await old_sender.disconnect()
async def bot_sign_in(self: Client, token: str) -> User:

View File

@ -220,14 +220,15 @@ class Client:
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
self._sender_lock_flag = False
if isinstance(session, Storage):
self._storage = session
storage = session
elif session is None:
self._storage = MemorySession()
storage = MemorySession()
else:
self._storage = SqliteSession(session)
storage = SqliteSession(session)
self._storage = storage
self._config = Config(
api_id=api_id,
@ -2074,10 +2075,7 @@ class Client:
return await upload(self, fd, size, name)
async def __call__(self, request: Request[Return]) -> Return:
if not self._sender:
raise ConnectionError("not connected")
return await invoke_request(self, self._sender, self._sender_lock, request)
return await invoke_request(self, request)
async def __aenter__(self) -> Self:
await connect(self)

View File

@ -238,15 +238,16 @@ async def disconnect(self: Client) -> None:
async def invoke_request(
client: Client,
sender: Sender,
lock: asyncio.Lock,
request: Request[Return],
) -> Return:
if not client._sender:
raise ConnectionError("not connected")
sleep_thresh = client._config.flood_sleep_threshold
rx = sender.enqueue(request)
rx = client._sender.enqueue(request)
while True:
while not rx.done():
await step_sender(client, sender, lock)
await step_sender(client)
try:
response = rx.result()
break
@ -254,30 +255,17 @@ async def invoke_request(
if e.code == 420 and e.value is not None and e.value < sleep_thresh:
await asyncio.sleep(e.value)
sleep_thresh -= e.value
rx = sender.enqueue(request)
rx = client._sender.enqueue(request)
continue
else:
raise adapt_rpc(e) from None
return request.deserialize_response(response)
async def step(client: Client) -> None:
if client._sender:
await step_sender(client, client._sender, client._sender_lock)
async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None:
flag = client._sender_lock_flag
async with lock:
if client._sender_lock_flag != flag:
# different task already received an item from the network
return
# current task is responsible for receiving
# toggle the flag so any other task that comes after does not run again
client._sender_lock_flag = not client._sender_lock_flag
async def step_sender(client: Client) -> None:
try:
updates = await sender.step()
assert client._sender
updates = await client._sender.step()
except ConnectionError:
if client.connected:
raise
@ -290,7 +278,8 @@ async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> Non
async def run_until_disconnected(self: Client) -> None:
while self.connected:
await step(self)
if self._sender:
await step_sender(self)
def connected(client: Client) -> bool:

View File

@ -3,7 +3,7 @@ import logging
import struct
import time
from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future
from asyncio import FIRST_COMPLETED, Event, Future, Lock
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar
@ -162,6 +162,7 @@ class Request(Generic[Return]):
class Sender:
dc_id: int
addr: str
lock: Lock
_logger: logging.Logger
_reader: AsyncReader
_writer: AsyncWriter
@ -191,6 +192,7 @@ class Sender:
return cls(
dc_id=dc_id,
addr=addr,
lock=Lock(),
_logger=base_logger.getChild("mtsender"),
_reader=reader,
_writer=writer,
@ -233,6 +235,10 @@ class Sender:
return rx.result()
async def step(self) -> list[Updates]:
async with self.lock:
return await self._step()
async def _step(self) -> list[Updates]:
self._try_fill_write()
recv_req = asyncio.create_task(self._request_event.wait())