diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 49fa46b6..f49bd037 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -2,7 +2,7 @@ import logging import os import struct import time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type, Union from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2 from ...tl.core import Reader @@ -76,6 +76,18 @@ HEADER_LEN = 8 + 8 # salt, client_id CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len +class Single: + """ + Sentinel value. + """ + + +class Pending: + """ + Sentinel value. + """ + + class Encrypted(Mtp): def __init__( self, @@ -122,7 +134,10 @@ class Encrypted(Mtp): self._client_id: int self._sequence: int self._last_msg_id: int - self._pending_ack: List[int] = [] + self._in_pending_ack: List[int] = [] + self._out_pending_ack: Dict[ + int, Union[int, Type[Single], Type[Pending]] # msg_id: container_id + ] = {} self._msg_count: int self._reset_session() @@ -142,7 +157,8 @@ class Encrypted(Mtp): self._client_id = struct.unpack(" int: @@ -170,6 +186,10 @@ class Encrypted(Mtp): self._buffer += struct.pack(" int: @@ -186,16 +206,23 @@ class Encrypted(Mtp): " None: if message_requires_ack(message): - self._pending_ack.append(message.msg_id) + self._in_pending_ack.append(message.msg_id) # https://core.telegram.org/mtproto/service_messages # https://core.telegram.org/mtproto/service_messages_about_messages constructor_id = struct.unpack_from(" None: rpc_result = GeneratedRpcResult.from_bytes(message.body) req_msg_id = rpc_result.req_msg_id result = rpc_result.result + del self._out_pending_ack[req_msg_id] + msg_id = MsgId(req_msg_id) inner_constructor = struct.unpack_from(" None: - MsgsAck.from_bytes(message.body) + if __debug__: + msgs_ack = MsgsAck.from_bytes(message.body) + for msg_id in msgs_ack.msg_ids: + assert msg_id in self._out_pending_ack def _handle_bad_notification(self, message: Message) -> None: bad_msg = AbcBadMsgNotification.from_bytes(message.body) assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) exc = BadMessage(code=bad_msg.error_code) - self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc)) + + bad_msg_id = bad_msg.bad_msg_id + if self._out_pending_ack[bad_msg_id] is None: + # Search bad_msg_id in containers instead. + # Make a new list since pending ack needs to be mutated after. + bad_msg_ids = [ + m for m, c in self._out_pending_ack.items() if bad_msg_id == c + ] + if not bad_msg_ids: + raise KeyError(f"bad_msg for unknown msg_id: {bad_msg_id}") + + for bad_msg_id in bad_msg_id: + self._rpc_results.append((MsgId(bad_msg_id), exc)) + del self._out_pending_ack[bad_msg_id] + else: + self._rpc_results.append((MsgId(bad_msg_id), exc)) + del self._out_pending_ack[bad_msg_id] + if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: # If we had no valid salt, this error is expected. exc.severity = logging.INFO @@ -284,9 +335,9 @@ class Encrypted(Mtp): def _handle_detailed_info(self, message: Message) -> None: msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body) if isinstance(msg_detailed, MsgDetailedInfo): - self._pending_ack.append(msg_detailed.answer_msg_id) + self._in_pending_ack.append(msg_detailed.answer_msg_id) elif isinstance(msg_detailed, MsgNewDetailedInfo): - self._pending_ack.append(msg_detailed.answer_msg_id) + self._in_pending_ack.append(msg_detailed.answer_msg_id) else: assert False @@ -295,6 +346,7 @@ class Encrypted(Mtp): def _handle_future_salts(self, message: Message) -> None: salts = FutureSalts.from_bytes(message.body) + del self._out_pending_ack[salts.req_msg_id] if salts.req_msg_id == self._salt_request_msg_id: # Response to internal request, do not propagate. @@ -367,9 +419,9 @@ class Encrypted(Mtp): ) def push(self, request: bytes) -> Optional[MsgId]: - if self._pending_ack: - self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False) - self._pending_ack = [] + if self._in_pending_ack: + self._serialize_msg(bytes(MsgsAck(msg_ids=self._in_pending_ack)), False) + self._in_pending_ack = [] if self._start_salt_time and len(self._salts) >= 2: start_secs, start_instant = self._start_salt_time