diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 8fec6a95..91564209 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -10,6 +10,7 @@ from typing_extensions import Self from ....version import __version__ as default_version from ...mtsender import Connector, Sender +from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy from ...session import ( ChannelRef, ChatHashCache, @@ -215,6 +216,7 @@ class Client: lang_code: Optional[str] = None, datacenter: Optional[DataCenter] = None, connector: Optional[Connector] = None, + reconnection_policy: Optional[ReconnectionPolicy] = None, ) -> None: assert __package__ base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) @@ -246,7 +248,7 @@ class Client: update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), - reconnection_policy=None, + reconnection_policy=reconnection_policy or NoReconnect(), ) self._session = Session() diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 2770ed9a..e790457d 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -46,6 +46,7 @@ class Config: api_hash: str base_logger: logging.Logger connector: Connector + reconnection_policy: ReconnectionPolicy device_model: str = field(default_factory=default_device_model) system_version: str = field(default_factory=default_system_version) app_version: str = __version__ @@ -55,7 +56,6 @@ class Config: datacenter: Optional[DataCenter] = None flood_sleep_threshold: int = 60 update_queue_limit: Optional[int] = None - reconnection_policy: Optional[ReconnectionPolicy] = None KNOWN_DCS = [ diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py index 227d6fdd..26c94f06 100644 --- a/client/src/telethon/_impl/mtsender/errors.py +++ b/client/src/telethon/_impl/mtsender/errors.py @@ -1,11 +1,6 @@ -import io +from struct import error as struct_error from ..mtproto.mtp.types import DeserializationError from ..mtproto.transport.abcs import TransportError -ReadError = io.BlockingIOError | TransportError | DeserializationError - - -class IOError(io.BlockingIOError): - def __init__(self, *args: object) -> None: - super().__init__(*args) +ReadError = struct_error | TransportError | DeserializationError diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index f13f52b8..f9b8c1f5 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -29,7 +29,7 @@ from ..tl import Request as RemoteCall from ..tl.abcs import Updates from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect -from ..tl.types import UpdateDeleteMessages, UpdateShort +from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from .reconnection import ReconnectionPolicy @@ -165,7 +165,7 @@ class Sender: dc_id: int addr: str _connector: Connector - _reconnection_policy: Optional[ReconnectionPolicy] + _reconnection_policy: ReconnectionPolicy _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -190,7 +190,7 @@ class Sender: addr: str, *, connector: Connector, - reconnection_policy: Optional[ReconnectionPolicy], + reconnection_policy: ReconnectionPolicy, base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") @@ -249,7 +249,7 @@ class Sender: try: await self._step() except Exception as error: - self._on_error(error) + await self._on_error(error) async def _step(self) -> None: if not self._writing: @@ -308,10 +308,7 @@ class Sender: self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") await asyncio.sleep(1) - delay = False - - if self._reconnection_policy is not None: - delay = self._reconnection_policy.should_retry(attempts) + delay = self._reconnection_policy.should_retry(attempts) if delay: if delay is not True: @@ -379,7 +376,7 @@ class Sender: if isinstance(req.state, Serialized): req.state = Sent(req.state.msg_id, req.state.container_msg_id) - def _on_error(self, error: Exception): + async def _on_error(self, error: Exception) -> None: self._logger.info(f"handling error: {error}") self._transport.reset() self._mtp.reset() @@ -392,10 +389,17 @@ class Sender: self._read_buffer.clear() self._mtp_buffer.clear() - match error: - # TODO - case DeserializationFailure(): - pass + if isinstance(error, struct.error) and self._reconnection_policy.should_retry( + 0 + ): + self._logger.info(f"read error occurred: {error}") + await self._try_connect() + + for req in self._requests: + req.state = NotSerialized() + + self._updates.append(UpdatesTooLong()) + return self._logger.warning( f"marking all {len(self._requests)} request(s) as failed: {error}" @@ -561,7 +565,7 @@ async def connect( auth_key: Optional[bytes], base_logger: logging.Logger, connector: Connector, - reconnection_policy: Optional[ReconnectionPolicy] = None, + reconnection_policy: ReconnectionPolicy, ) -> Sender: if auth_key is None: sender = await Sender.connect(