From 3250f9ec376b80c8c5a0941f3c70c04d30593290 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sat, 16 Mar 2024 15:00:38 +0100 Subject: [PATCH] Unify the way deserialization is returned --- client/src/telethon/_impl/mtproto/__init__.py | 14 +- .../telethon/_impl/mtproto/mtp/__init__.py | 4 +- .../telethon/_impl/mtproto/mtp/encrypted.py | 39 ++-- .../src/telethon/_impl/mtproto/mtp/plain.py | 10 +- .../src/telethon/_impl/mtproto/mtp/types.py | 46 ++-- client/src/telethon/_impl/mtsender/sender.py | 206 +++++++++++------- 6 files changed, 202 insertions(+), 117 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/__init__.py b/client/src/telethon/_impl/mtproto/__init__.py index 08e53086..d0ff7b43 100644 --- a/client/src/telethon/_impl/mtproto/__init__.py +++ b/client/src/telethon/_impl/mtproto/__init__.py @@ -2,7 +2,17 @@ from .authentication import CreatedKey, Step1, Step2, Step3, create_key from .authentication import step1 as auth_step1 from .authentication import step2 as auth_step2 from .authentication import step3 as auth_step3 -from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError +from .mtp import ( + BadMessage, + Deserialization, + Encrypted, + MsgId, + Mtp, + Plain, + RpcError, + RpcResult, + Update, +) from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport from .utils import DEFAULT_COMPRESSION_THRESHOLD @@ -22,6 +32,8 @@ __all__ = [ "Mtp", "Plain", "RpcError", + "RpcResult", + "Update", "Abridged", "BadStatus", "Full", diff --git a/client/src/telethon/_impl/mtproto/mtp/__init__.py b/client/src/telethon/_impl/mtproto/mtp/__init__.py index 64c6424f..4f572684 100644 --- a/client/src/telethon/_impl/mtproto/mtp/__init__.py +++ b/client/src/telethon/_impl/mtproto/mtp/__init__.py @@ -1,6 +1,6 @@ from .encrypted import Encrypted from .plain import Plain -from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError +from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update __all__ = [ "Encrypted", @@ -10,4 +10,6 @@ __all__ = [ "MsgId", "Mtp", "RpcError", + "RpcResult", + "Update", ] diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 6d6c79dd..07ec78c7 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -60,7 +60,7 @@ from ..utils import ( gzip_decompress, message_requires_ack, ) -from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult +from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update NUM_FUTURE_SALTS = 64 @@ -112,8 +112,7 @@ class Encrypted(Mtp): ] self._start_salt_time: Optional[Tuple[int, float]] = None self._compression_threshold = compression_threshold - self._rpc_results: List[Tuple[MsgId, RpcResult]] = [] - self._updates: List[bytes] = [] + self._deserialization: List[Deserialization] = [] self._buffer = bytearray() self._salt_request_msg_id: Optional[int] = None @@ -244,12 +243,9 @@ class Encrypted(Mtp): inner_constructor = struct.unpack_from(" None: constructor_id = struct.unpack_from("I", body)[0] if constructor_id in UPDATE_IDS: - self._updates.append(body) + self._deserialization.append(Update(body)) def _handle_ack(self, message: Message) -> None: MsgsAck.from_bytes(message.body) @@ -276,13 +272,13 @@ class Encrypted(Mtp): bad_msg = AbcBadMsgNotification.from_bytes(message.body) assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) - exc = BadMessage(code=bad_msg.error_code) + exc = BadMessage(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code) if bad_msg.bad_msg_id == self._salt_request_msg_id: # Response to internal request, do not propagate. self._salt_request_msg_id = None else: - self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc)) + self._deserialization.append(exc) if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: # If we had no valid salt, this error is expected. @@ -331,7 +327,9 @@ class Encrypted(Mtp): # Response to internal request, do not propagate. self._salt_request_msg_id = None else: - self._rpc_results.append((MsgId(salts.req_msg_id), message.body)) + self._deserialization.append( + RpcResult(MsgId(salts.req_msg_id), message.body) + ) self._start_salt_time = (salts.now, self._adjusted_now()) self._salts = salts.salts @@ -343,7 +341,7 @@ class Encrypted(Mtp): def _handle_pong(self, message: Message) -> None: pong = Pong.from_bytes(message.body) - self._rpc_results.append((MsgId(pong.msg_id), message.body)) + self._deserialization.append(RpcResult(MsgId(pong.msg_id), message.body)) def _handle_destroy_session(self, message: Message) -> None: DestroySessionRes.from_bytes(message.body) @@ -378,7 +376,7 @@ class Encrypted(Mtp): HttpWait.from_bytes(message.body) def _handle_update(self, message: Message) -> None: - self._updates.append(message.body) + self._deserialization.append(Update(message.body)) def _try_request_salts(self) -> None: if ( @@ -441,7 +439,7 @@ class Encrypted(Mtp): msg_id, buffer = result return msg_id, encrypt_data_v2(buffer, self._auth_key) - def deserialize(self, payload: bytes) -> Deserialization: + def deserialize(self, payload: bytes) -> List[Deserialization]: check_message_buffer(payload) plaintext = decrypt_data_v2(payload, self._auth_key) @@ -452,7 +450,6 @@ class Encrypted(Mtp): self._process_message(Message._read_from(Reader(memoryview(plaintext)[16:]))) - result = Deserialization(rpc_results=self._rpc_results, updates=self._updates) - self._rpc_results = [] - self._updates = [] + result = self._deserialization[:] + self._deserialization.clear() return result diff --git a/client/src/telethon/_impl/mtproto/mtp/plain.py b/client/src/telethon/_impl/mtproto/mtp/plain.py index 89dda0d3..acb16838 100644 --- a/client/src/telethon/_impl/mtproto/mtp/plain.py +++ b/client/src/telethon/_impl/mtproto/mtp/plain.py @@ -1,8 +1,8 @@ import struct -from typing import Optional, Tuple +from typing import List, Optional, Tuple from ..utils import check_message_buffer -from .types import Deserialization, MsgId, Mtp +from .types import Deserialization, MsgId, Mtp, RpcResult class Plain(Mtp): @@ -31,7 +31,7 @@ class Plain(Mtp): self._buffer.clear() return MsgId(0), result - def deserialize(self, payload: bytes) -> Deserialization: + def deserialize(self, payload: bytes) -> List[Deserialization]: check_message_buffer(payload) auth_key_id, msg_id, length = struct.unpack_from(" None: append_value = f" ({value})" if value else "" - super().__init__(f"rpc error {code}: {name}{append_value}") + super().__init__(f"rpc error {code}: {name}{append_value}", *args) + self.msg_id = msg_id self._code = code self._name = name self._value = value @@ -121,13 +145,15 @@ NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33} class BadMessage(ValueError): def __init__( self, - *, + *args: object, + msg_id: MsgId = MsgId(0), code: int, caused_by: Optional[int] = None, ) -> None: description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available" - super().__init__(f"bad msg={code}: {description}") + super().__init__(f"bad msg={code}: {description}", *args) + self.msg_id = msg_id self._code = code self._caused_by = caused_by self.severity = ( @@ -152,13 +178,7 @@ class BadMessage(ValueError): return self._code == other._code -RpcResult = bytes | RpcError | BadMessage - - -@dataclass -class Deserialization: - rpc_results: List[Tuple[MsgId, RpcResult]] - updates: List[bytes] +Deserialization = Update | RpcResult | RpcError | BadMessage # https://core.telegram.org/mtproto/description @@ -181,7 +201,7 @@ class Mtp(ABC): """ @abstractmethod - def deserialize(self, payload: bytes) -> Deserialization: + def deserialize(self, payload: bytes) -> List[Deserialization]: """ Deserialize incoming buffer payload. """ diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 57de56f0..d416c382 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -5,7 +5,17 @@ import time from abc import ABC from asyncio import FIRST_COMPLETED, Event, Future from dataclasses import dataclass -from typing import Generic, List, Optional, Protocol, Self, Tuple, Type, TypeVar +from typing import ( + Generic, + Iterator, + List, + Optional, + Protocol, + Self, + Tuple, + Type, + TypeVar, +) from ..crypto import AuthKey from ..mtproto import ( @@ -16,7 +26,9 @@ from ..mtproto import ( Mtp, Plain, RpcError, + RpcResult, Transport, + Update, authentication, ) from ..tl import Request as RemoteCall @@ -127,17 +139,19 @@ class NotSerialized(RequestState): class Serialized(RequestState): - __slots__ = ("msg_id",) + __slots__ = ("msg_id", "container_msg_id") def __init__(self, msg_id: MsgId): self.msg_id = msg_id + self.container_msg_id = msg_id class Sent(RequestState): - __slots__ = ("msg_id",) + __slots__ = ("msg_id", "container_msg_id") - def __init__(self, msg_id: MsgId): + def __init__(self, msg_id: MsgId, container_msg_id: MsgId): self.msg_id = msg_id + self.container_msg_id = container_msg_id Return = TypeVar("Return") @@ -273,7 +287,11 @@ class Sender: result = self._mtp.finalize() if result: - _, mtp_buffer = result + container_msg_id, mtp_buffer = result + for request in self._requests: + if isinstance(request.state, Serialized): + request.state.container_msg_id = container_msg_id + self._transport.pack(mtp_buffer, self._writer.write) self._write_drain_pending = True @@ -299,7 +317,7 @@ class Sender: def _on_net_write(self) -> None: for req in self._requests: if isinstance(req.state, Serialized): - req.state = Sent(req.state.msg_id) + req.state = Sent(req.state.msg_id, req.state.container_msg_id) def _on_ping_timeout(self) -> None: ping_id = generate_random_id() @@ -313,31 +331,44 @@ class Sender: self._next_ping = asyncio.get_running_loop().time() + PING_DELAY def _process_mtp_buffer(self, updates: List[Updates]) -> None: - result = self._mtp.deserialize(self._mtp_buffer) + results = self._mtp.deserialize(self._mtp_buffer) - for update in result.updates: - try: - u = Updates.from_bytes(update) - except ValueError: - cid = struct.unpack_from("I", update)[0] - alt_classes: Tuple[Type[Serializable], ...] = ( - AffectedFoundMessages, - AffectedHistory, - AffectedMessages, - ) - for cls in alt_classes: - if cid == cls.constructor_id(): - affected = cls.from_bytes(update) - # mypy struggles with the types here quite a bit - assert isinstance( - affected, - ( - AffectedFoundMessages, - AffectedHistory, - AffectedMessages, - ), - ) - u = UpdateShort( + for result in results: + if isinstance(result, Update): + self._process_update(updates, result.body) + elif isinstance(result, RpcResult): + self._process_result(result) + elif isinstance(result, RpcError): + self._process_error(result) + elif isinstance(result, BadMessage): + self._process_bad_message(result) + else: + raise RuntimeError("unexpected case") + + def _process_update(self, updates: List[Updates], update: bytes) -> None: + try: + updates.append(Updates.from_bytes(update)) + except ValueError: + cid = struct.unpack_from("I", update)[0] + alt_classes: Tuple[Type[Serializable], ...] = ( + AffectedFoundMessages, + AffectedHistory, + AffectedMessages, + ) + for cls in alt_classes: + if cid == cls.constructor_id(): + affected = cls.from_bytes(update) + # mypy struggles with the types here quite a bit + assert isinstance( + affected, + ( + AffectedFoundMessages, + AffectedHistory, + AffectedMessages, + ), + ) + updates.append( + UpdateShort( update=UpdateDeleteMessages( messages=[], pts=affected.pts, @@ -345,58 +376,83 @@ class Sender: ), date=0, ) - break - else: - self._logger.warning( - "failed to deserialize incoming update; make sure the session is not in use elsewhere: %s", - update.hex(), ) - continue - - updates.append(u) - - for msg_id, ret in result.rpc_results: - for i, req in enumerate(self._requests): - if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: - raise RuntimeError("got rpc result for unsent request") - elif isinstance(req.state, Sent) and req.state.msg_id == msg_id: - del self._requests[i] break else: self._logger.warning( - "telegram sent rpc_result for unknown msg_id=%d: %s", - msg_id, - ret.hex() if isinstance(ret, bytes) else repr(ret), + "failed to deserialize incoming update; make sure the session is not in use elsewhere: %s", + update.hex(), ) - continue + return - if isinstance(ret, bytes): - assert len(ret) >= 4 - req.result.set_result(ret) - elif isinstance(ret, RpcError): - ret._caused_by = struct.unpack_from(" None: + req = self._pop_request(result.msg_id) + + if req: + assert len(result.body) >= 4 + req.result.set_result(result.body) + else: + self._logger.warning( + "telegram sent rpc_result for unknown msg_id=%d: %s", + result.msg_id, + result.body.hex(), + ) + + def _process_error(self, result: RpcError) -> None: + req = self._pop_request(result.msg_id) + + if req: + result._caused_by = struct.unpack_from(" None: + for req in self._drain_requests(result.msg_id): + if result.retryable: + self._logger.log( + result.severity, + "telegram notified of bad msg_id=%d; will attempt to resend request: %s", + result.msg_id, + result, + ) + req.state = NotSerialized() + self._requests.append(req) else: - raise RuntimeError("unexpected case") + self._logger.log( + result.severity, + "telegram notified of bad msg_id=%d; impossible to retry: %s", + result.msg_id, + result, + ) + result._caused_by = struct.unpack_from(" Optional[Request[object]]: + for i, req in enumerate(self._requests): + if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: + raise RuntimeError("got response for unsent request") + elif isinstance(req.state, Sent) and req.state.msg_id == msg_id: + del self._requests[i] + return req + + return None + + def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]: + for i in reversed(range(len(self._requests))): + req = self._requests[i] + if isinstance(req.state, Serialized) and ( + req.state.msg_id == msg_id or req.state.container_msg_id == msg_id + ): + raise RuntimeError("got response for unsent request") + elif isinstance(req.state, Sent) and ( + req.state.msg_id == msg_id or req.state.container_msg_id == msg_id + ): + yield self._requests.pop(i) @property def auth_key(self) -> Optional[bytes]: