Add reconnection policy support to Sender and Config classes; refactor error handling in Sender

This commit is contained in:
Jahongir Qurbonov 2025-06-03 10:27:43 +05:00
parent ac611dbbd4
commit 3bfa64a5d6
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
4 changed files with 24 additions and 23 deletions

View File

@ -10,6 +10,7 @@ from typing_extensions import Self
from ....version import __version__ as default_version
from ...mtsender import Connector, Sender
from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy
from ...session import (
ChannelRef,
ChatHashCache,
@ -215,6 +216,7 @@ class Client:
lang_code: Optional[str] = None,
datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None,
reconnection_policy: Optional[ReconnectionPolicy] = None,
) -> None:
assert __package__
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
@ -246,7 +248,7 @@ class Client:
update_queue_limit=update_queue_limit,
base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
reconnection_policy=None,
reconnection_policy=reconnection_policy or NoReconnect(),
)
self._session = Session()

View File

@ -46,6 +46,7 @@ class Config:
api_hash: str
base_logger: logging.Logger
connector: Connector
reconnection_policy: ReconnectionPolicy
device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version)
app_version: str = __version__
@ -55,7 +56,6 @@ class Config:
datacenter: Optional[DataCenter] = None
flood_sleep_threshold: int = 60
update_queue_limit: Optional[int] = None
reconnection_policy: Optional[ReconnectionPolicy] = None
KNOWN_DCS = [

View File

@ -1,11 +1,6 @@
import io
from struct import error as struct_error
from ..mtproto.mtp.types import DeserializationError
from ..mtproto.transport.abcs import TransportError
ReadError = io.BlockingIOError | TransportError | DeserializationError
class IOError(io.BlockingIOError):
def __init__(self, *args: object) -> None:
super().__init__(*args)
ReadError = struct_error | TransportError | DeserializationError

View File

@ -29,7 +29,7 @@ from ..tl import Request as RemoteCall
from ..tl.abcs import Updates
from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .reconnection import ReconnectionPolicy
@ -165,7 +165,7 @@ class Sender:
dc_id: int
addr: str
_connector: Connector
_reconnection_policy: Optional[ReconnectionPolicy]
_reconnection_policy: ReconnectionPolicy
_logger: logging.Logger
_reader: AsyncReader
_writer: AsyncWriter
@ -190,7 +190,7 @@ class Sender:
addr: str,
*,
connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy],
reconnection_policy: ReconnectionPolicy,
base_logger: logging.Logger,
) -> Self:
ip, port = addr.split(":")
@ -249,7 +249,7 @@ class Sender:
try:
await self._step()
except Exception as error:
self._on_error(error)
await self._on_error(error)
async def _step(self) -> None:
if not self._writing:
@ -308,10 +308,7 @@ class Sender:
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
await asyncio.sleep(1)
delay = False
if self._reconnection_policy is not None:
delay = self._reconnection_policy.should_retry(attempts)
delay = self._reconnection_policy.should_retry(attempts)
if delay:
if delay is not True:
@ -379,7 +376,7 @@ class Sender:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_error(self, error: Exception):
async def _on_error(self, error: Exception) -> None:
self._logger.info(f"handling error: {error}")
self._transport.reset()
self._mtp.reset()
@ -392,10 +389,17 @@ class Sender:
self._read_buffer.clear()
self._mtp_buffer.clear()
match error:
# TODO
case DeserializationFailure():
pass
if isinstance(error, struct.error) and self._reconnection_policy.should_retry(
0
):
self._logger.info(f"read error occurred: {error}")
await self._try_connect()
for req in self._requests:
req.state = NotSerialized()
self._updates.append(UpdatesTooLong())
return
self._logger.warning(
f"marking all {len(self._requests)} request(s) as failed: {error}"
@ -561,7 +565,7 @@ async def connect(
auth_key: Optional[bytes],
base_logger: logging.Logger,
connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy] = None,
reconnection_policy: ReconnectionPolicy,
) -> Sender:
if auth_key is None:
sender = await Sender.connect(