Remove _out_pending_ack

Not really reliable.
Now that the sender has container IDs, it knows what to resend.
This commit is contained in:
Lonami Exo 2024-03-16 14:10:40 +01:00
parent c7d1a36969
commit b5db881415

View File

@ -2,7 +2,7 @@ import logging
import os import os
import struct import struct
import time import time
from typing import Dict, List, Optional, Tuple, Type, Union from typing import List, Optional, Tuple
from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2 from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2
from ...tl.core import Reader from ...tl.core import Reader
@ -143,9 +143,6 @@ class Encrypted(Mtp):
self._sequence: int self._sequence: int
self._last_msg_id: int self._last_msg_id: int
self._in_pending_ack: List[int] = [] self._in_pending_ack: List[int] = []
self._out_pending_ack: Dict[
int, Union[int, Type[Single], Type[Pending]] # msg_id: container_id
] = {}
self._msg_count: int self._msg_count: int
self._reset_session() self._reset_session()
@ -167,7 +164,6 @@ class Encrypted(Mtp):
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
self._in_pending_ack.clear() self._in_pending_ack.clear()
self._out_pending_ack.clear()
self._msg_count = 0 self._msg_count = 0
def _get_new_msg_id(self) -> int: def _get_new_msg_id(self) -> int:
@ -196,9 +192,6 @@ class Encrypted(Mtp):
self._buffer += body self._buffer += body
self._msg_count += 1 self._msg_count += 1
if content_related:
self._out_pending_ack[msg_id] = Pending
return MsgId(msg_id) return MsgId(msg_id)
def _get_current_salt(self) -> int: def _get_current_salt(self) -> int:
@ -228,10 +221,6 @@ class Encrypted(Mtp):
self._msg_count, self._msg_count,
) )
for m, c in self._out_pending_ack.items():
if c is Pending:
self._out_pending_ack[m] = container_msg_id
self._msg_count = 0 self._msg_count = 0
result = bytes(self._buffer) result = bytes(self._buffer)
self._buffer.clear() self._buffer.clear()
@ -246,15 +235,11 @@ class Encrypted(Mtp):
constructor_id = struct.unpack_from("<I", message.body)[0] constructor_id = struct.unpack_from("<I", message.body)[0]
self._handlers.get(constructor_id, self._handle_update)(message) self._handlers.get(constructor_id, self._handle_update)(message)
assert len(self._out_pending_ack) < 1000
def _handle_rpc_result(self, message: Message) -> None: def _handle_rpc_result(self, message: Message) -> None:
rpc_result = GeneratedRpcResult.from_bytes(message.body) rpc_result = GeneratedRpcResult.from_bytes(message.body)
req_msg_id = rpc_result.req_msg_id req_msg_id = rpc_result.req_msg_id
result = rpc_result.result result = rpc_result.result
del self._out_pending_ack[req_msg_id]
msg_id = MsgId(req_msg_id) msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0] inner_constructor = struct.unpack_from("<I", result)[0]
@ -285,10 +270,7 @@ class Encrypted(Mtp):
self._updates.append(body) self._updates.append(body)
def _handle_ack(self, message: Message) -> None: def _handle_ack(self, message: Message) -> None:
if __debug__: MsgsAck.from_bytes(message.body)
msgs_ack = MsgsAck.from_bytes(message.body)
for msg_id in msgs_ack.msg_ids:
assert msg_id in self._out_pending_ack
def _handle_bad_notification(self, message: Message) -> None: def _handle_bad_notification(self, message: Message) -> None:
bad_msg = AbcBadMsgNotification.from_bytes(message.body) bad_msg = AbcBadMsgNotification.from_bytes(message.body)
@ -296,24 +278,11 @@ class Encrypted(Mtp):
exc = BadMessage(code=bad_msg.error_code) exc = BadMessage(code=bad_msg.error_code)
bad_msg_id = bad_msg.bad_msg_id if bad_msg.bad_msg_id == self._salt_request_msg_id:
if bad_msg_id in self._out_pending_ack: # Response to internal request, do not propagate.
bad_msg_ids = [bad_msg.bad_msg_id] self._salt_request_msg_id = None
else: else:
# Search bad_msg_id in containers instead. self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
bad_msg_ids = [
m for m, c in self._out_pending_ack.items() if bad_msg_id == c
]
if not bad_msg_ids:
raise KeyError(f"bad_msg for unknown msg_id: {bad_msg_id}")
for bad_msg_id in bad_msg_ids:
if bad_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate.
self._salt_request_msg_id = None
else:
self._rpc_results.append((MsgId(bad_msg_id), exc))
del self._out_pending_ack[bad_msg_id]
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
# If we had no valid salt, this error is expected. # If we had no valid salt, this error is expected.
@ -357,7 +326,6 @@ class Encrypted(Mtp):
def _handle_future_salts(self, message: Message) -> None: def _handle_future_salts(self, message: Message) -> None:
salts = FutureSalts.from_bytes(message.body) salts = FutureSalts.from_bytes(message.body)
del self._out_pending_ack[salts.req_msg_id]
if salts.req_msg_id == self._salt_request_msg_id: if salts.req_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate. # Response to internal request, do not propagate.