mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-05 14:13:06 +03:00
Improve updates handling (#4483)
This commit is contained in:
parent
771954d010
commit
9b18304e9c
|
@ -55,12 +55,16 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
||||||
|
|
||||||
|
|
||||||
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||||
|
assert client._sender
|
||||||
assert dc_id is not None
|
assert dc_id is not None
|
||||||
sender, client._session.dcs = await connect_sender(
|
sender, client._session.dcs = await connect_sender(
|
||||||
client._config, client._session.dcs, DataCenter(id=dc_id)
|
client._config, client._session.dcs, DataCenter(id=dc_id)
|
||||||
)
|
)
|
||||||
async with client._sender_lock:
|
|
||||||
|
async with client._sender.lock:
|
||||||
|
old_sender = client._sender
|
||||||
client._sender = sender
|
client._sender = sender
|
||||||
|
await old_sender.disconnect()
|
||||||
|
|
||||||
|
|
||||||
async def bot_sign_in(self: Client, token: str) -> User:
|
async def bot_sign_in(self: Client, token: str) -> User:
|
||||||
|
|
|
@ -220,14 +220,15 @@ class Client:
|
||||||
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
||||||
|
|
||||||
self._sender: Optional[Sender] = None
|
self._sender: Optional[Sender] = None
|
||||||
self._sender_lock = asyncio.Lock()
|
|
||||||
self._sender_lock_flag = False
|
|
||||||
if isinstance(session, Storage):
|
if isinstance(session, Storage):
|
||||||
self._storage = session
|
storage = session
|
||||||
elif session is None:
|
elif session is None:
|
||||||
self._storage = MemorySession()
|
storage = MemorySession()
|
||||||
else:
|
else:
|
||||||
self._storage = SqliteSession(session)
|
storage = SqliteSession(session)
|
||||||
|
|
||||||
|
self._storage = storage
|
||||||
|
|
||||||
self._config = Config(
|
self._config = Config(
|
||||||
api_id=api_id,
|
api_id=api_id,
|
||||||
|
@ -2074,10 +2075,7 @@ class Client:
|
||||||
return await upload(self, fd, size, name)
|
return await upload(self, fd, size, name)
|
||||||
|
|
||||||
async def __call__(self, request: Request[Return]) -> Return:
|
async def __call__(self, request: Request[Return]) -> Return:
|
||||||
if not self._sender:
|
return await invoke_request(self, request)
|
||||||
raise ConnectionError("not connected")
|
|
||||||
|
|
||||||
return await invoke_request(self, self._sender, self._sender_lock, request)
|
|
||||||
|
|
||||||
async def __aenter__(self) -> Self:
|
async def __aenter__(self) -> Self:
|
||||||
await connect(self)
|
await connect(self)
|
||||||
|
|
|
@ -238,15 +238,16 @@ async def disconnect(self: Client) -> None:
|
||||||
|
|
||||||
async def invoke_request(
|
async def invoke_request(
|
||||||
client: Client,
|
client: Client,
|
||||||
sender: Sender,
|
|
||||||
lock: asyncio.Lock,
|
|
||||||
request: Request[Return],
|
request: Request[Return],
|
||||||
) -> Return:
|
) -> Return:
|
||||||
|
if not client._sender:
|
||||||
|
raise ConnectionError("not connected")
|
||||||
|
|
||||||
sleep_thresh = client._config.flood_sleep_threshold
|
sleep_thresh = client._config.flood_sleep_threshold
|
||||||
rx = sender.enqueue(request)
|
rx = client._sender.enqueue(request)
|
||||||
while True:
|
while True:
|
||||||
while not rx.done():
|
while not rx.done():
|
||||||
await step_sender(client, sender, lock)
|
await step_sender(client)
|
||||||
try:
|
try:
|
||||||
response = rx.result()
|
response = rx.result()
|
||||||
break
|
break
|
||||||
|
@ -254,30 +255,17 @@ async def invoke_request(
|
||||||
if e.code == 420 and e.value is not None and e.value < sleep_thresh:
|
if e.code == 420 and e.value is not None and e.value < sleep_thresh:
|
||||||
await asyncio.sleep(e.value)
|
await asyncio.sleep(e.value)
|
||||||
sleep_thresh -= e.value
|
sleep_thresh -= e.value
|
||||||
rx = sender.enqueue(request)
|
rx = client._sender.enqueue(request)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise adapt_rpc(e) from None
|
raise adapt_rpc(e) from None
|
||||||
return request.deserialize_response(response)
|
return request.deserialize_response(response)
|
||||||
|
|
||||||
|
|
||||||
async def step(client: Client) -> None:
|
async def step_sender(client: Client) -> None:
|
||||||
if client._sender:
|
|
||||||
await step_sender(client, client._sender, client._sender_lock)
|
|
||||||
|
|
||||||
|
|
||||||
async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None:
|
|
||||||
flag = client._sender_lock_flag
|
|
||||||
async with lock:
|
|
||||||
if client._sender_lock_flag != flag:
|
|
||||||
# different task already received an item from the network
|
|
||||||
return
|
|
||||||
# current task is responsible for receiving
|
|
||||||
# toggle the flag so any other task that comes after does not run again
|
|
||||||
client._sender_lock_flag = not client._sender_lock_flag
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updates = await sender.step()
|
assert client._sender
|
||||||
|
updates = await client._sender.step()
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
if client.connected:
|
if client.connected:
|
||||||
raise
|
raise
|
||||||
|
@ -290,7 +278,8 @@ async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> Non
|
||||||
|
|
||||||
async def run_until_disconnected(self: Client) -> None:
|
async def run_until_disconnected(self: Client) -> None:
|
||||||
while self.connected:
|
while self.connected:
|
||||||
await step(self)
|
if self._sender:
|
||||||
|
await step_sender(self)
|
||||||
|
|
||||||
|
|
||||||
def connected(client: Client) -> bool:
|
def connected(client: Client) -> bool:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from asyncio import FIRST_COMPLETED, Event, Future
|
from asyncio import FIRST_COMPLETED, Event, Future, Lock
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, Optional, Protocol, Type, TypeVar
|
from typing import Generic, Optional, Protocol, Type, TypeVar
|
||||||
|
@ -162,6 +162,7 @@ class Request(Generic[Return]):
|
||||||
class Sender:
|
class Sender:
|
||||||
dc_id: int
|
dc_id: int
|
||||||
addr: str
|
addr: str
|
||||||
|
lock: Lock
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
_reader: AsyncReader
|
_reader: AsyncReader
|
||||||
_writer: AsyncWriter
|
_writer: AsyncWriter
|
||||||
|
@ -191,6 +192,7 @@ class Sender:
|
||||||
return cls(
|
return cls(
|
||||||
dc_id=dc_id,
|
dc_id=dc_id,
|
||||||
addr=addr,
|
addr=addr,
|
||||||
|
lock=Lock(),
|
||||||
_logger=base_logger.getChild("mtsender"),
|
_logger=base_logger.getChild("mtsender"),
|
||||||
_reader=reader,
|
_reader=reader,
|
||||||
_writer=writer,
|
_writer=writer,
|
||||||
|
@ -233,6 +235,10 @@ class Sender:
|
||||||
return rx.result()
|
return rx.result()
|
||||||
|
|
||||||
async def step(self) -> list[Updates]:
|
async def step(self) -> list[Updates]:
|
||||||
|
async with self.lock:
|
||||||
|
return await self._step()
|
||||||
|
|
||||||
|
async def _step(self) -> list[Updates]:
|
||||||
self._try_fill_write()
|
self._try_fill_write()
|
||||||
|
|
||||||
recv_req = asyncio.create_task(self._request_event.wait())
|
recv_req = asyncio.create_task(self._request_event.wait())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user