Fix handling of bad_msg referring to containers

This commit is contained in:
Lonami Exo 2023-10-18 19:27:43 +02:00
parent 186dd38ff4
commit 5604f530c0

View File

@ -2,7 +2,7 @@ import logging
import os import os
import struct import struct
import time import time
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Type, Union
from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2 from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2
from ...tl.core import Reader from ...tl.core import Reader
@ -76,6 +76,18 @@ HEADER_LEN = 8 + 8 # salt, client_id
CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len
class Single:
"""
Sentinel value.
"""
class Pending:
"""
Sentinel value.
"""
class Encrypted(Mtp): class Encrypted(Mtp):
def __init__( def __init__(
self, self,
@ -122,7 +134,10 @@ class Encrypted(Mtp):
self._client_id: int self._client_id: int
self._sequence: int self._sequence: int
self._last_msg_id: int self._last_msg_id: int
self._pending_ack: List[int] = [] self._in_pending_ack: List[int] = []
self._out_pending_ack: Dict[
int, Union[int, Type[Single], Type[Pending]] # msg_id: container_id
] = {}
self._msg_count: int self._msg_count: int
self._reset_session() self._reset_session()
@ -142,7 +157,8 @@ class Encrypted(Mtp):
self._client_id = struct.unpack("<q", os.urandom(8))[0] self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
self._pending_ack.clear() self._in_pending_ack.clear()
self._out_pending_ack.clear()
self._msg_count = 0 self._msg_count = 0
def _get_new_msg_id(self) -> int: def _get_new_msg_id(self) -> int:
@ -170,6 +186,10 @@ class Encrypted(Mtp):
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body)) self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
self._buffer += body self._buffer += body
self._msg_count += 1 self._msg_count += 1
if content_related:
self._out_pending_ack[msg_id] = Pending
return MsgId(msg_id) return MsgId(msg_id)
def _get_current_salt(self) -> int: def _get_current_salt(self) -> int:
@ -186,16 +206,23 @@ class Encrypted(Mtp):
"<qq", self._get_current_salt(), self._client_id "<qq", self._get_current_salt(), self._client_id
) )
if self._msg_count != 1: if self._msg_count == 1:
container_msg_id = Single
else:
container_msg_id = self._get_new_msg_id()
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack( self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
"<qiiIi", "<qiiIi",
self._get_new_msg_id(), container_msg_id,
self._get_seq_no(False), self._get_seq_no(False),
len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8, len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8,
MsgContainer.constructor_id(), MsgContainer.constructor_id(),
self._msg_count, self._msg_count,
) )
for m, c in self._out_pending_ack.items():
if c is Pending:
self._out_pending_ack[m] = container_msg_id
self._msg_count = 0 self._msg_count = 0
result = bytes(self._buffer) result = bytes(self._buffer)
self._buffer.clear() self._buffer.clear()
@ -203,18 +230,22 @@ class Encrypted(Mtp):
def _process_message(self, message: Message) -> None: def _process_message(self, message: Message) -> None:
if message_requires_ack(message): if message_requires_ack(message):
self._pending_ack.append(message.msg_id) self._in_pending_ack.append(message.msg_id)
# https://core.telegram.org/mtproto/service_messages # https://core.telegram.org/mtproto/service_messages
# https://core.telegram.org/mtproto/service_messages_about_messages # https://core.telegram.org/mtproto/service_messages_about_messages
constructor_id = struct.unpack_from("<I", message.body)[0] constructor_id = struct.unpack_from("<I", message.body)[0]
self._handlers.get(constructor_id, self._handle_update)(message) self._handlers.get(constructor_id, self._handle_update)(message)
assert len(self._out_pending_ack) < 1000
def _handle_rpc_result(self, message: Message) -> None: def _handle_rpc_result(self, message: Message) -> None:
rpc_result = GeneratedRpcResult.from_bytes(message.body) rpc_result = GeneratedRpcResult.from_bytes(message.body)
req_msg_id = rpc_result.req_msg_id req_msg_id = rpc_result.req_msg_id
result = rpc_result.result result = rpc_result.result
del self._out_pending_ack[req_msg_id]
msg_id = MsgId(req_msg_id) msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0] inner_constructor = struct.unpack_from("<I", result)[0]
@ -245,14 +276,34 @@ class Encrypted(Mtp):
self._updates.append(body) self._updates.append(body)
def _handle_ack(self, message: Message) -> None: def _handle_ack(self, message: Message) -> None:
MsgsAck.from_bytes(message.body) if __debug__:
msgs_ack = MsgsAck.from_bytes(message.body)
for msg_id in msgs_ack.msg_ids:
assert msg_id in self._out_pending_ack
def _handle_bad_notification(self, message: Message) -> None: def _handle_bad_notification(self, message: Message) -> None:
bad_msg = AbcBadMsgNotification.from_bytes(message.body) bad_msg = AbcBadMsgNotification.from_bytes(message.body)
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
exc = BadMessage(code=bad_msg.error_code) exc = BadMessage(code=bad_msg.error_code)
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
bad_msg_id = bad_msg.bad_msg_id
if self._out_pending_ack[bad_msg_id] is None:
# Search bad_msg_id in containers instead.
# Make a new list since pending ack needs to be mutated after.
bad_msg_ids = [
m for m, c in self._out_pending_ack.items() if bad_msg_id == c
]
if not bad_msg_ids:
raise KeyError(f"bad_msg for unknown msg_id: {bad_msg_id}")
for bad_msg_id in bad_msg_id:
self._rpc_results.append((MsgId(bad_msg_id), exc))
del self._out_pending_ack[bad_msg_id]
else:
self._rpc_results.append((MsgId(bad_msg_id), exc))
del self._out_pending_ack[bad_msg_id]
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
# If we had no valid salt, this error is expected. # If we had no valid salt, this error is expected.
exc.severity = logging.INFO exc.severity = logging.INFO
@ -284,9 +335,9 @@ class Encrypted(Mtp):
def _handle_detailed_info(self, message: Message) -> None: def _handle_detailed_info(self, message: Message) -> None:
msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body) msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body)
if isinstance(msg_detailed, MsgDetailedInfo): if isinstance(msg_detailed, MsgDetailedInfo):
self._pending_ack.append(msg_detailed.answer_msg_id) self._in_pending_ack.append(msg_detailed.answer_msg_id)
elif isinstance(msg_detailed, MsgNewDetailedInfo): elif isinstance(msg_detailed, MsgNewDetailedInfo):
self._pending_ack.append(msg_detailed.answer_msg_id) self._in_pending_ack.append(msg_detailed.answer_msg_id)
else: else:
assert False assert False
@ -295,6 +346,7 @@ class Encrypted(Mtp):
def _handle_future_salts(self, message: Message) -> None: def _handle_future_salts(self, message: Message) -> None:
salts = FutureSalts.from_bytes(message.body) salts = FutureSalts.from_bytes(message.body)
del self._out_pending_ack[salts.req_msg_id]
if salts.req_msg_id == self._salt_request_msg_id: if salts.req_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate. # Response to internal request, do not propagate.
@ -367,9 +419,9 @@ class Encrypted(Mtp):
) )
def push(self, request: bytes) -> Optional[MsgId]: def push(self, request: bytes) -> Optional[MsgId]:
if self._pending_ack: if self._in_pending_ack:
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False) self._serialize_msg(bytes(MsgsAck(msg_ids=self._in_pending_ack)), False)
self._pending_ack = [] self._in_pending_ack = []
if self._start_salt_time and len(self._salts) >= 2: if self._start_salt_time and len(self._salts) >= 2:
start_secs, start_instant = self._start_salt_time start_secs, start_instant = self._start_salt_time