Improve updates handling (#4483)

This commit is contained in:
Jahongir Qurbonov 2024-10-11 19:57:18 +05:00 committed by GitHub
parent 771954d010
commit 9b18304e9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 40 deletions

View File

@ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert client._sender
assert dc_id is not None assert dc_id is not None
sender, client._session.dcs = await connect_sender( sender, client._session.dcs = await connect_sender(
client._config, client._session.dcs, DataCenter(id=dc_id) client._config, client._session.dcs, DataCenter(id=dc_id)
) )
async with client._sender_lock:
async with client._sender.lock:
old_sender = client._sender
client._sender = sender client._sender = sender
await old_sender.disconnect()
async def bot_sign_in(self: Client, token: str) -> User: async def bot_sign_in(self: Client, token: str) -> User:

View File

@ -220,14 +220,15 @@ class Client:
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
self._sender: Optional[Sender] = None self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
self._sender_lock_flag = False
if isinstance(session, Storage): if isinstance(session, Storage):
self._storage = session storage = session
elif session is None: elif session is None:
self._storage = MemorySession() storage = MemorySession()
else: else:
self._storage = SqliteSession(session) storage = SqliteSession(session)
self._storage = storage
self._config = Config( self._config = Config(
api_id=api_id, api_id=api_id,
@ -2074,10 +2075,7 @@ class Client:
return await upload(self, fd, size, name) return await upload(self, fd, size, name)
async def __call__(self, request: Request[Return]) -> Return: async def __call__(self, request: Request[Return]) -> Return:
if not self._sender: return await invoke_request(self, request)
raise ConnectionError("not connected")
return await invoke_request(self, self._sender, self._sender_lock, request)
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
await connect(self) await connect(self)

View File

@ -238,15 +238,16 @@ async def disconnect(self: Client) -> None:
async def invoke_request( async def invoke_request(
client: Client, client: Client,
sender: Sender,
lock: asyncio.Lock,
request: Request[Return], request: Request[Return],
) -> Return: ) -> Return:
if not client._sender:
raise ConnectionError("not connected")
sleep_thresh = client._config.flood_sleep_threshold sleep_thresh = client._config.flood_sleep_threshold
rx = sender.enqueue(request) rx = client._sender.enqueue(request)
while True: while True:
while not rx.done(): while not rx.done():
await step_sender(client, sender, lock) await step_sender(client)
try: try:
response = rx.result() response = rx.result()
break break
@ -254,30 +255,17 @@ async def invoke_request(
if e.code == 420 and e.value is not None and e.value < sleep_thresh: if e.code == 420 and e.value is not None and e.value < sleep_thresh:
await asyncio.sleep(e.value) await asyncio.sleep(e.value)
sleep_thresh -= e.value sleep_thresh -= e.value
rx = sender.enqueue(request) rx = client._sender.enqueue(request)
continue continue
else: else:
raise adapt_rpc(e) from None raise adapt_rpc(e) from None
return request.deserialize_response(response) return request.deserialize_response(response)
async def step(client: Client) -> None: async def step_sender(client: Client) -> None:
if client._sender:
await step_sender(client, client._sender, client._sender_lock)
async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None:
flag = client._sender_lock_flag
async with lock:
if client._sender_lock_flag != flag:
# different task already received an item from the network
return
# current task is responsible for receiving
# toggle the flag so any other task that comes after does not run again
client._sender_lock_flag = not client._sender_lock_flag
try: try:
updates = await sender.step() assert client._sender
updates = await client._sender.step()
except ConnectionError: except ConnectionError:
if client.connected: if client.connected:
raise raise
@ -290,7 +278,8 @@ async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> Non
async def run_until_disconnected(self: Client) -> None: async def run_until_disconnected(self: Client) -> None:
while self.connected: while self.connected:
await step(self) if self._sender:
await step_sender(self)
def connected(client: Client) -> bool: def connected(client: Client) -> bool:

View File

@ -3,7 +3,7 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future from asyncio import FIRST_COMPLETED, Event, Future, Lock
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import Generic, Optional, Protocol, Type, TypeVar
@ -162,6 +162,7 @@ class Request(Generic[Return]):
class Sender: class Sender:
dc_id: int dc_id: int
addr: str addr: str
lock: Lock
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
@ -191,6 +192,7 @@ class Sender:
return cls( return cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
lock=Lock(),
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
@ -233,6 +235,10 @@ class Sender:
return rx.result() return rx.result()
async def step(self) -> list[Updates]: async def step(self) -> list[Updates]:
async with self.lock:
return await self._step()
async def _step(self) -> list[Updates]:
self._try_fill_write() self._try_fill_write()
recv_req = asyncio.create_task(self._request_event.wait()) recv_req = asyncio.create_task(self._request_event.wait())