Fix handling of salts and container buffer

This commit is contained in:
Lonami Exo 2023-10-14 01:21:33 +02:00
parent 6ed279e773
commit c91ce98a25
4 changed files with 121 additions and 88 deletions

View File

@ -1,3 +1,4 @@
import logging
import os
import struct
import time
@ -89,16 +90,12 @@ class Encrypted(Mtp):
self._salts: List[FutureSalt] = [
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
]
self._start_salt_time: Optional[Tuple[int, int]] = None
self._client_id: int = struct.unpack("<q", os.urandom(8))[0]
self._sequence: int = 0
self._last_msg_id: int = 0
self._pending_ack: List[int] = []
self._start_salt_time: Optional[Tuple[int, float]] = None
self._compression_threshold = compression_threshold
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
self._updates: List[bytes] = []
self._buffer = bytearray()
self._msg_count: int = 0
self._salt_request_msg_id: Optional[int] = None
self._handlers = {
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
@ -122,6 +119,13 @@ class Encrypted(Mtp):
HttpWait.constructor_id(): self._handle_http_wait,
}
self._client_id: int
self._sequence: int
self._last_msg_id: int
self._pending_ack: List[int] = []
self._msg_count: int
self._reset_session()
@property
def auth_key(self) -> bytes:
return self._auth_key.data
@ -131,10 +135,18 @@ class Encrypted(Mtp):
correct = msg_id >> 32
self._time_offset = correct - int(now)
def _get_new_msg_id(self) -> int:
now = time.time()
def _adjusted_now(self) -> float:
return time.time() + self._time_offset
new_msg_id = int((now + self._time_offset) * 0x100000000)
def _reset_session(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._pending_ack.clear()
self._msg_count = 0
def _get_new_msg_id(self) -> int:
new_msg_id = int(self._adjusted_now() * 0x100000000)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
@ -149,6 +161,10 @@ class Encrypted(Mtp):
return self._sequence
def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId:
if not self._buffer:
# Reserve space for `finalize`
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(content_related)
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
@ -156,6 +172,9 @@ class Encrypted(Mtp):
self._msg_count += 1
return MsgId(msg_id)
def _get_current_salt(self) -> int:
return self._salts[-1].salt if self._salts else 0
def _finalize_plain(self) -> bytes:
if not self._msg_count:
return b""
@ -164,7 +183,7 @@ class Encrypted(Mtp):
del self._buffer[:CONTAINER_HEADER_LEN]
self._buffer[:HEADER_LEN] = struct.pack(
"<qq", self._salts[-1].salt if self._salts else 0, self._client_id
"<qq", self._get_current_salt(), self._client_id
)
if self._msg_count != 1:
@ -177,6 +196,7 @@ class Encrypted(Mtp):
self._msg_count,
)
print("packed", self._msg_count)
self._msg_count = 0
result = bytes(self._buffer)
self._buffer.clear()
@ -230,37 +250,28 @@ class Encrypted(Mtp):
def _handle_bad_notification(self, message: Message) -> None:
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
if isinstance(bad_msg, BadServerSalt):
self._rpc_results.append(
(
MsgId(bad_msg.bad_msg_id),
BadMessage(code=bad_msg.error_code),
)
)
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
exc = BadMessage(code=bad_msg.error_code)
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
# If we had no valid salt, this error is expected.
exc.severity = logging.INFO
if isinstance(bad_msg, BadServerSalt):
self._salts.clear()
self._salts.append(
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
)
)
self.push(get_future_salts(num=NUM_FUTURE_SALTS))
return
assert isinstance(bad_msg, BadMsgNotification)
self._rpc_results.append(
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
)
if bad_msg.error_code in (16, 17):
self._salt_request_msg_id = None
elif bad_msg.error_code not in (16, 17):
self._correct_time_offset(message.msg_id)
elif bad_msg.error_code == 32:
# TODO start with a fresh session rather than guessing
self._sequence += 64
elif bad_msg.error_code == 33:
# TODO start with a fresh session rather than guessing
self._sequence -= 16
elif bad_msg.error_code in (32, 33):
self._reset_session()
else:
raise exc
def _handle_state_req(self, message: Message) -> None:
MsgsStateReq.from_bytes(message.body)
@ -285,9 +296,14 @@ class Encrypted(Mtp):
def _handle_future_salts(self, message: Message) -> None:
salts = FutureSalts.from_bytes(message.body)
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
self._start_salt_time = (salts.now, int(time.time()))
if salts.req_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate.
self._salt_request_msg_id = None
else:
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = salts.salts
self._salts.sort(key=lambda salt: -salt.valid_since)
@ -334,28 +350,38 @@ class Encrypted(Mtp):
def _handle_update(self, message: Message) -> None:
self._updates.append(message.body)
def push(self, request: bytes) -> Optional[MsgId]:
if not self._buffer:
# Reserve space for `finalize`
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
def _try_request_salts(self) -> None:
if (
len(self._salts) == 1
and self._salt_request_msg_id is None
and self._get_current_salt() != 0
):
# If salts are requested in a container leading to bad_msg,
# the bad_msg_id will refer to the container, not the salts request.
#
# We don't keep track of containers and content-related messages they contain for simplicity.
# This would break, because we couldn't identify the response.
#
# So salts are only requested once we have a valid salt to reduce the chances of this happening.
self._salt_request_msg_id = self._serialize_msg(
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
)
def push(self, request: bytes) -> Optional[MsgId]:
if self._pending_ack:
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
self._pending_ack = []
if self._start_salt_time:
if self._start_salt_time and len(self._salts) >= 2:
start_secs, start_instant = self._start_salt_time
if len(self._salts) >= 2:
salt = self._salts[-2]
now = start_secs + (start_instant - int(time.time()))
if now >= salt.valid_since + SALT_USE_DELAY:
self._salts.pop()
if len(self._salts) == 1:
self._serialize_msg(
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
)
salt = self._salts[-2]
now = start_secs + (start_instant - self._adjusted_now())
if now >= salt.valid_since + SALT_USE_DELAY:
self._salts.pop()
if self._msg_count == CONTAINER_MAX_LENGTH:
self._try_request_salts()
if self._msg_count >= CONTAINER_MAX_LENGTH:
return None
assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE

View File

@ -186,24 +186,17 @@ class Sender:
if self._write_drain_pending:
return
# TODO test that the a request is only ever sent onrece
requests = [r for r in self._requests if isinstance(r.state, NotSerialized)]
if not requests:
return
msg_ids = []
for request in requests:
if (msg_id := self._mtp.push(request.body)) is not None:
msg_ids.append(msg_id)
else:
break
for request in self._requests:
if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None:
request.state = Serialized(msg_id)
else:
break
mtp_buffer = self._mtp.finalize()
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
for req, msg_id in zip(requests, msg_ids):
req.state = Serialized(msg_id)
if mtp_buffer:
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> List[Updates]:
if not read_buffer:
@ -255,34 +248,47 @@ class Sender:
updates.append(u)
for msg_id, ret in result.rpc_results:
found = False
for i in reversed(range(len(self._requests))):
req = self._requests[i]
for i, req in enumerate(self._requests):
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
raise RuntimeError("got rpc result for unsent request")
if isinstance(req.state, Sent) and req.state.msg_id == msg_id:
found = True
if isinstance(ret, bytes):
assert len(ret) >= 4
elif isinstance(ret, RpcError):
ret._caused_by = struct.unpack_from("<I", req.body)[0]
raise ret
elif isinstance(ret, BadMessage):
# TODO test that we resend the request
req.state = NotSerialized()
break
else:
raise RuntimeError("unexpected case")
req = self._requests.pop(i)
req.result.set_result(ret)
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
del self._requests[i]
break
if not found:
else:
self._logger.warning(
"telegram sent rpc_result for unknown msg_id=%d: %s",
msg_id,
ret.hex() if isinstance(ret, bytes) else repr(ret),
)
continue
if isinstance(ret, bytes):
assert len(ret) >= 4
req.result.set_result(ret)
elif isinstance(ret, RpcError):
ret._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(ret)
elif isinstance(ret, BadMessage):
if ret.retryable:
self._logger.log(
ret.severity,
"telegram notified of bad msg_id=%d; will attempt to resend request: %s",
msg_id,
ret,
)
req.state = NotSerialized()
self._requests.append(req)
else:
self._logger.log(
ret.severity,
"telegram notified of bad msg_id=%d; impossible to retry: %s",
msg_id,
ret,
)
ret._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(ret)
else:
raise RuntimeError("unexpected case")
@property
def auth_key(self) -> Optional[bytes]:

View File

@ -66,6 +66,7 @@ class PackedChat:
"""
return bytes(self).hex()
@classmethod
def from_hex(cls, hex: str) -> Self:
"""
Convenience method to convert hexadecimal numbers into bytes then passed to :meth:`from_bytes`:

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple
from ..tl.core.serializable import obj_repr
@ -13,7 +13,7 @@ class DataCenter:
:param auth: See below.
"""
__slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
__slots__: Tuple[str, ...] = ("id", "ipv4_addr", "ipv6_addr", "auth")
def __init__(
self,