mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-07-16 11:02:19 +03:00
Fix handling of bad_msg referring to containers
This commit is contained in:
parent
186dd38ff4
commit
5604f530c0
|
@ -2,7 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2
|
from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2
|
||||||
from ...tl.core import Reader
|
from ...tl.core import Reader
|
||||||
|
@ -76,6 +76,18 @@ HEADER_LEN = 8 + 8 # salt, client_id
|
||||||
CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len
|
CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len
|
||||||
|
|
||||||
|
|
||||||
|
class Single:
|
||||||
|
"""
|
||||||
|
Sentinel value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Pending:
|
||||||
|
"""
|
||||||
|
Sentinel value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Encrypted(Mtp):
|
class Encrypted(Mtp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -122,7 +134,10 @@ class Encrypted(Mtp):
|
||||||
self._client_id: int
|
self._client_id: int
|
||||||
self._sequence: int
|
self._sequence: int
|
||||||
self._last_msg_id: int
|
self._last_msg_id: int
|
||||||
self._pending_ack: List[int] = []
|
self._in_pending_ack: List[int] = []
|
||||||
|
self._out_pending_ack: Dict[
|
||||||
|
int, Union[int, Type[Single], Type[Pending]] # msg_id: container_id
|
||||||
|
] = {}
|
||||||
self._msg_count: int
|
self._msg_count: int
|
||||||
self._reset_session()
|
self._reset_session()
|
||||||
|
|
||||||
|
@ -142,7 +157,8 @@ class Encrypted(Mtp):
|
||||||
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||||
self._sequence = 0
|
self._sequence = 0
|
||||||
self._last_msg_id = 0
|
self._last_msg_id = 0
|
||||||
self._pending_ack.clear()
|
self._in_pending_ack.clear()
|
||||||
|
self._out_pending_ack.clear()
|
||||||
self._msg_count = 0
|
self._msg_count = 0
|
||||||
|
|
||||||
def _get_new_msg_id(self) -> int:
|
def _get_new_msg_id(self) -> int:
|
||||||
|
@ -170,6 +186,10 @@ class Encrypted(Mtp):
|
||||||
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
|
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
|
||||||
self._buffer += body
|
self._buffer += body
|
||||||
self._msg_count += 1
|
self._msg_count += 1
|
||||||
|
|
||||||
|
if content_related:
|
||||||
|
self._out_pending_ack[msg_id] = Pending
|
||||||
|
|
||||||
return MsgId(msg_id)
|
return MsgId(msg_id)
|
||||||
|
|
||||||
def _get_current_salt(self) -> int:
|
def _get_current_salt(self) -> int:
|
||||||
|
@ -186,16 +206,23 @@ class Encrypted(Mtp):
|
||||||
"<qq", self._get_current_salt(), self._client_id
|
"<qq", self._get_current_salt(), self._client_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._msg_count != 1:
|
if self._msg_count == 1:
|
||||||
|
container_msg_id = Single
|
||||||
|
else:
|
||||||
|
container_msg_id = self._get_new_msg_id()
|
||||||
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
|
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
|
||||||
"<qiiIi",
|
"<qiiIi",
|
||||||
self._get_new_msg_id(),
|
container_msg_id,
|
||||||
self._get_seq_no(False),
|
self._get_seq_no(False),
|
||||||
len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8,
|
len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8,
|
||||||
MsgContainer.constructor_id(),
|
MsgContainer.constructor_id(),
|
||||||
self._msg_count,
|
self._msg_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for m, c in self._out_pending_ack.items():
|
||||||
|
if c is Pending:
|
||||||
|
self._out_pending_ack[m] = container_msg_id
|
||||||
|
|
||||||
self._msg_count = 0
|
self._msg_count = 0
|
||||||
result = bytes(self._buffer)
|
result = bytes(self._buffer)
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
|
@ -203,18 +230,22 @@ class Encrypted(Mtp):
|
||||||
|
|
||||||
def _process_message(self, message: Message) -> None:
|
def _process_message(self, message: Message) -> None:
|
||||||
if message_requires_ack(message):
|
if message_requires_ack(message):
|
||||||
self._pending_ack.append(message.msg_id)
|
self._in_pending_ack.append(message.msg_id)
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/service_messages
|
# https://core.telegram.org/mtproto/service_messages
|
||||||
# https://core.telegram.org/mtproto/service_messages_about_messages
|
# https://core.telegram.org/mtproto/service_messages_about_messages
|
||||||
constructor_id = struct.unpack_from("<I", message.body)[0]
|
constructor_id = struct.unpack_from("<I", message.body)[0]
|
||||||
self._handlers.get(constructor_id, self._handle_update)(message)
|
self._handlers.get(constructor_id, self._handle_update)(message)
|
||||||
|
|
||||||
|
assert len(self._out_pending_ack) < 1000
|
||||||
|
|
||||||
def _handle_rpc_result(self, message: Message) -> None:
|
def _handle_rpc_result(self, message: Message) -> None:
|
||||||
rpc_result = GeneratedRpcResult.from_bytes(message.body)
|
rpc_result = GeneratedRpcResult.from_bytes(message.body)
|
||||||
req_msg_id = rpc_result.req_msg_id
|
req_msg_id = rpc_result.req_msg_id
|
||||||
result = rpc_result.result
|
result = rpc_result.result
|
||||||
|
|
||||||
|
del self._out_pending_ack[req_msg_id]
|
||||||
|
|
||||||
msg_id = MsgId(req_msg_id)
|
msg_id = MsgId(req_msg_id)
|
||||||
inner_constructor = struct.unpack_from("<I", result)[0]
|
inner_constructor = struct.unpack_from("<I", result)[0]
|
||||||
|
|
||||||
|
@ -245,14 +276,34 @@ class Encrypted(Mtp):
|
||||||
self._updates.append(body)
|
self._updates.append(body)
|
||||||
|
|
||||||
def _handle_ack(self, message: Message) -> None:
|
def _handle_ack(self, message: Message) -> None:
|
||||||
MsgsAck.from_bytes(message.body)
|
if __debug__:
|
||||||
|
msgs_ack = MsgsAck.from_bytes(message.body)
|
||||||
|
for msg_id in msgs_ack.msg_ids:
|
||||||
|
assert msg_id in self._out_pending_ack
|
||||||
|
|
||||||
def _handle_bad_notification(self, message: Message) -> None:
|
def _handle_bad_notification(self, message: Message) -> None:
|
||||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||||
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
||||||
|
|
||||||
exc = BadMessage(code=bad_msg.error_code)
|
exc = BadMessage(code=bad_msg.error_code)
|
||||||
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
|
|
||||||
|
bad_msg_id = bad_msg.bad_msg_id
|
||||||
|
if self._out_pending_ack[bad_msg_id] is None:
|
||||||
|
# Search bad_msg_id in containers instead.
|
||||||
|
# Make a new list since pending ack needs to be mutated after.
|
||||||
|
bad_msg_ids = [
|
||||||
|
m for m, c in self._out_pending_ack.items() if bad_msg_id == c
|
||||||
|
]
|
||||||
|
if not bad_msg_ids:
|
||||||
|
raise KeyError(f"bad_msg for unknown msg_id: {bad_msg_id}")
|
||||||
|
|
||||||
|
for bad_msg_id in bad_msg_id:
|
||||||
|
self._rpc_results.append((MsgId(bad_msg_id), exc))
|
||||||
|
del self._out_pending_ack[bad_msg_id]
|
||||||
|
else:
|
||||||
|
self._rpc_results.append((MsgId(bad_msg_id), exc))
|
||||||
|
del self._out_pending_ack[bad_msg_id]
|
||||||
|
|
||||||
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
||||||
# If we had no valid salt, this error is expected.
|
# If we had no valid salt, this error is expected.
|
||||||
exc.severity = logging.INFO
|
exc.severity = logging.INFO
|
||||||
|
@ -284,9 +335,9 @@ class Encrypted(Mtp):
|
||||||
def _handle_detailed_info(self, message: Message) -> None:
|
def _handle_detailed_info(self, message: Message) -> None:
|
||||||
msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body)
|
msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body)
|
||||||
if isinstance(msg_detailed, MsgDetailedInfo):
|
if isinstance(msg_detailed, MsgDetailedInfo):
|
||||||
self._pending_ack.append(msg_detailed.answer_msg_id)
|
self._in_pending_ack.append(msg_detailed.answer_msg_id)
|
||||||
elif isinstance(msg_detailed, MsgNewDetailedInfo):
|
elif isinstance(msg_detailed, MsgNewDetailedInfo):
|
||||||
self._pending_ack.append(msg_detailed.answer_msg_id)
|
self._in_pending_ack.append(msg_detailed.answer_msg_id)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
@ -295,6 +346,7 @@ class Encrypted(Mtp):
|
||||||
|
|
||||||
def _handle_future_salts(self, message: Message) -> None:
|
def _handle_future_salts(self, message: Message) -> None:
|
||||||
salts = FutureSalts.from_bytes(message.body)
|
salts = FutureSalts.from_bytes(message.body)
|
||||||
|
del self._out_pending_ack[salts.req_msg_id]
|
||||||
|
|
||||||
if salts.req_msg_id == self._salt_request_msg_id:
|
if salts.req_msg_id == self._salt_request_msg_id:
|
||||||
# Response to internal request, do not propagate.
|
# Response to internal request, do not propagate.
|
||||||
|
@ -367,9 +419,9 @@ class Encrypted(Mtp):
|
||||||
)
|
)
|
||||||
|
|
||||||
def push(self, request: bytes) -> Optional[MsgId]:
|
def push(self, request: bytes) -> Optional[MsgId]:
|
||||||
if self._pending_ack:
|
if self._in_pending_ack:
|
||||||
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
|
self._serialize_msg(bytes(MsgsAck(msg_ids=self._in_pending_ack)), False)
|
||||||
self._pending_ack = []
|
self._in_pending_ack = []
|
||||||
|
|
||||||
if self._start_salt_time and len(self._salts) >= 2:
|
if self._start_salt_time and len(self._salts) >= 2:
|
||||||
start_secs, start_instant = self._start_salt_time
|
start_secs, start_instant = self._start_salt_time
|
||||||
|
|
Loading…
Reference in New Issue
Block a user