Add reconnection policy support to Sender and Config classes; refactor error handling in Sender

This commit is contained in:
Jahongir Qurbonov 2025-06-03 10:27:43 +05:00
parent ac611dbbd4
commit 3bfa64a5d6
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
4 changed files with 24 additions and 23 deletions

View File

@ -10,6 +10,7 @@ from typing_extensions import Self
from ....version import __version__ as default_version from ....version import __version__ as default_version
from ...mtsender import Connector, Sender from ...mtsender import Connector, Sender
from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy
from ...session import ( from ...session import (
ChannelRef, ChannelRef,
ChatHashCache, ChatHashCache,
@ -215,6 +216,7 @@ class Client:
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
datacenter: Optional[DataCenter] = None, datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None, connector: Optional[Connector] = None,
reconnection_policy: Optional[ReconnectionPolicy] = None,
) -> None: ) -> None:
assert __package__ assert __package__
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
@ -246,7 +248,7 @@ class Client:
update_queue_limit=update_queue_limit, update_queue_limit=update_queue_limit,
base_logger=base_logger, base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
reconnection_policy=None, reconnection_policy=reconnection_policy or NoReconnect(),
) )
self._session = Session() self._session = Session()

View File

@ -46,6 +46,7 @@ class Config:
api_hash: str api_hash: str
base_logger: logging.Logger base_logger: logging.Logger
connector: Connector connector: Connector
reconnection_policy: ReconnectionPolicy
device_model: str = field(default_factory=default_device_model) device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version) system_version: str = field(default_factory=default_system_version)
app_version: str = __version__ app_version: str = __version__
@ -55,7 +56,6 @@ class Config:
datacenter: Optional[DataCenter] = None datacenter: Optional[DataCenter] = None
flood_sleep_threshold: int = 60 flood_sleep_threshold: int = 60
update_queue_limit: Optional[int] = None update_queue_limit: Optional[int] = None
reconnection_policy: Optional[ReconnectionPolicy] = None
KNOWN_DCS = [ KNOWN_DCS = [

View File

@ -1,11 +1,6 @@
import io from struct import error as struct_error
from ..mtproto.mtp.types import DeserializationError from ..mtproto.mtp.types import DeserializationError
from ..mtproto.transport.abcs import TransportError from ..mtproto.transport.abcs import TransportError
ReadError = io.BlockingIOError | TransportError | DeserializationError ReadError = struct_error | TransportError | DeserializationError
class IOError(io.BlockingIOError):
def __init__(self, *args: object) -> None:
super().__init__(*args)

View File

@ -29,7 +29,7 @@ from ..tl import Request as RemoteCall
from ..tl.abcs import Updates from ..tl.abcs import Updates
from ..tl.core import Serializable from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .reconnection import ReconnectionPolicy from .reconnection import ReconnectionPolicy
@ -165,7 +165,7 @@ class Sender:
dc_id: int dc_id: int
addr: str addr: str
_connector: Connector _connector: Connector
_reconnection_policy: Optional[ReconnectionPolicy] _reconnection_policy: ReconnectionPolicy
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
@ -190,7 +190,7 @@ class Sender:
addr: str, addr: str,
*, *,
connector: Connector, connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy], reconnection_policy: ReconnectionPolicy,
base_logger: logging.Logger, base_logger: logging.Logger,
) -> Self: ) -> Self:
ip, port = addr.split(":") ip, port = addr.split(":")
@ -249,7 +249,7 @@ class Sender:
try: try:
await self._step() await self._step()
except Exception as error: except Exception as error:
self._on_error(error) await self._on_error(error)
async def _step(self) -> None: async def _step(self) -> None:
if not self._writing: if not self._writing:
@ -308,10 +308,7 @@ class Sender:
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
await asyncio.sleep(1) await asyncio.sleep(1)
delay = False delay = self._reconnection_policy.should_retry(attempts)
if self._reconnection_policy is not None:
delay = self._reconnection_policy.should_retry(attempts)
if delay: if delay:
if delay is not True: if delay is not True:
@ -379,7 +376,7 @@ class Sender:
if isinstance(req.state, Serialized): if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id) req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_error(self, error: Exception): async def _on_error(self, error: Exception) -> None:
self._logger.info(f"handling error: {error}") self._logger.info(f"handling error: {error}")
self._transport.reset() self._transport.reset()
self._mtp.reset() self._mtp.reset()
@ -392,10 +389,17 @@ class Sender:
self._read_buffer.clear() self._read_buffer.clear()
self._mtp_buffer.clear() self._mtp_buffer.clear()
match error: if isinstance(error, struct.error) and self._reconnection_policy.should_retry(
# TODO 0
case DeserializationFailure(): ):
pass self._logger.info(f"read error occurred: {error}")
await self._try_connect()
for req in self._requests:
req.state = NotSerialized()
self._updates.append(UpdatesTooLong())
return
self._logger.warning( self._logger.warning(
f"marking all {len(self._requests)} request(s) as failed: {error}" f"marking all {len(self._requests)} request(s) as failed: {error}"
@ -561,7 +565,7 @@ async def connect(
auth_key: Optional[bytes], auth_key: Optional[bytes],
base_logger: logging.Logger, base_logger: logging.Logger,
connector: Connector, connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy] = None, reconnection_policy: ReconnectionPolicy,
) -> Sender: ) -> Sender:
if auth_key is None: if auth_key is None:
sender = await Sender.connect( sender = await Sender.connect(