diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 8cfb1248..2100082f 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -1,3 +1,4 @@ +import logging import os import struct import time @@ -89,16 +90,12 @@ class Encrypted(Mtp): self._salts: List[FutureSalt] = [ FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0) ] - self._start_salt_time: Optional[Tuple[int, int]] = None - self._client_id: int = struct.unpack(" bytes: return self._auth_key.data @@ -131,10 +135,18 @@ class Encrypted(Mtp): correct = msg_id >> 32 self._time_offset = correct - int(now) - def _get_new_msg_id(self) -> int: - now = time.time() + def _adjusted_now(self) -> float: + return time.time() + self._time_offset - new_msg_id = int((now + self._time_offset) * 0x100000000) + def _reset_session(self) -> None: + self._client_id = struct.unpack(" int: + new_msg_id = int(self._adjusted_now() * 0x100000000) if self._last_msg_id >= new_msg_id: new_msg_id = self._last_msg_id + 4 @@ -149,6 +161,10 @@ class Encrypted(Mtp): return self._sequence def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId: + if not self._buffer: + # Reserve space for `finalize` + self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN) + msg_id = self._get_new_msg_id() seq_no = self._get_seq_no(content_related) self._buffer += struct.pack(" int: + return self._salts[-1].salt if self._salts else 0 + def _finalize_plain(self) -> bytes: if not self._msg_count: return b"" @@ -164,7 +183,7 @@ class Encrypted(Mtp): del self._buffer[:CONTAINER_HEADER_LEN] self._buffer[:HEADER_LEN] = struct.pack( - " None: bad_msg = AbcBadMsgNotification.from_bytes(message.body) - if isinstance(bad_msg, BadServerSalt): - self._rpc_results.append( - ( - MsgId(bad_msg.bad_msg_id), - BadMessage(code=bad_msg.error_code), - ) - ) + assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) + exc = BadMessage(code=bad_msg.error_code) + self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc)) + if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: + # If we had no valid salt, this error is expected. + exc.severity = logging.INFO + + if isinstance(bad_msg, BadServerSalt): self._salts.clear() self._salts.append( FutureSalt( valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt ) ) - - self.push(get_future_salts(num=NUM_FUTURE_SALTS)) - return - - assert isinstance(bad_msg, BadMsgNotification) - self._rpc_results.append( - (MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code)) - ) - - if bad_msg.error_code in (16, 17): + self._salt_request_msg_id = None + elif bad_msg.error_code not in (16, 17): self._correct_time_offset(message.msg_id) - elif bad_msg.error_code == 32: - # TODO start with a fresh session rather than guessing - self._sequence += 64 - elif bad_msg.error_code == 33: - # TODO start with a fresh session rather than guessing - self._sequence -= 16 + elif bad_msg.error_code in (32, 33): + self._reset_session() + else: + raise exc def _handle_state_req(self, message: Message) -> None: MsgsStateReq.from_bytes(message.body) @@ -285,9 +296,14 @@ class Encrypted(Mtp): def _handle_future_salts(self, message: Message) -> None: salts = FutureSalts.from_bytes(message.body) - self._rpc_results.append((MsgId(salts.req_msg_id), message.body)) - self._start_salt_time = (salts.now, int(time.time())) + if salts.req_msg_id == self._salt_request_msg_id: + # Response to internal request, do not propagate. + self._salt_request_msg_id = None + else: + self._rpc_results.append((MsgId(salts.req_msg_id), message.body)) + + self._start_salt_time = (salts.now, self._adjusted_now()) self._salts = salts.salts self._salts.sort(key=lambda salt: -salt.valid_since) @@ -334,28 +350,38 @@ class Encrypted(Mtp): def _handle_update(self, message: Message) -> None: self._updates.append(message.body) - def push(self, request: bytes) -> Optional[MsgId]: - if not self._buffer: - # Reserve space for `finalize` - self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN) + def _try_request_salts(self) -> None: + if ( + len(self._salts) == 1 + and self._salt_request_msg_id is None + and self._get_current_salt() != 0 + ): + # If salts are requested in a container leading to bad_msg, + # the bad_msg_id will refer to the container, not the salts request. + # + # We don't keep track of containers and content-related messages they contain for simplicity. + # This would break, because we couldn't identify the response. + # + # So salts are only requested once we have a valid salt to reduce the chances of this happening. + self._salt_request_msg_id = self._serialize_msg( + bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True + ) + def push(self, request: bytes) -> Optional[MsgId]: if self._pending_ack: self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False) self._pending_ack = [] - if self._start_salt_time: + if self._start_salt_time and len(self._salts) >= 2: start_secs, start_instant = self._start_salt_time - if len(self._salts) >= 2: - salt = self._salts[-2] - now = start_secs + (start_instant - int(time.time())) - if now >= salt.valid_since + SALT_USE_DELAY: - self._salts.pop() - if len(self._salts) == 1: - self._serialize_msg( - bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True - ) + salt = self._salts[-2] + now = start_secs + (start_instant - self._adjusted_now()) + if now >= salt.valid_since + SALT_USE_DELAY: + self._salts.pop() - if self._msg_count == CONTAINER_MAX_LENGTH: + self._try_request_salts() + + if self._msg_count >= CONTAINER_MAX_LENGTH: return None assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 72f9a4fc..52b7a8ee 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -186,24 +186,17 @@ class Sender: if self._write_drain_pending: return - # TODO test that the a request is only ever sent onrece - requests = [r for r in self._requests if isinstance(r.state, NotSerialized)] - if not requests: - return - - msg_ids = [] - for request in requests: - if (msg_id := self._mtp.push(request.body)) is not None: - msg_ids.append(msg_id) - else: - break + for request in self._requests: + if isinstance(request.state, NotSerialized): + if (msg_id := self._mtp.push(request.body)) is not None: + request.state = Serialized(msg_id) + else: + break mtp_buffer = self._mtp.finalize() - self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True - - for req, msg_id in zip(requests, msg_ids): - req.state = Serialized(msg_id) + if mtp_buffer: + self._transport.pack(mtp_buffer, self._writer.write) + self._write_drain_pending = True def _on_net_read(self, read_buffer: bytes) -> List[Updates]: if not read_buffer: @@ -255,34 +248,47 @@ class Sender: updates.append(u) for msg_id, ret in result.rpc_results: - found = False - for i in reversed(range(len(self._requests))): - req = self._requests[i] + for i, req in enumerate(self._requests): if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: raise RuntimeError("got rpc result for unsent request") - if isinstance(req.state, Sent) and req.state.msg_id == msg_id: - found = True - if isinstance(ret, bytes): - assert len(ret) >= 4 - elif isinstance(ret, RpcError): - ret._caused_by = struct.unpack_from("= 4 + req.result.set_result(ret) + elif isinstance(ret, RpcError): + ret._caused_by = struct.unpack_from(" Optional[bytes]: diff --git a/client/src/telethon/_impl/session/chat/packed.py b/client/src/telethon/_impl/session/chat/packed.py index 00e27164..9fe030b4 100644 --- a/client/src/telethon/_impl/session/chat/packed.py +++ b/client/src/telethon/_impl/session/chat/packed.py @@ -66,6 +66,7 @@ class PackedChat: """ return bytes(self).hex() + @classmethod def from_hex(cls, hex: str) -> Self: """ Convenience method to convert hexadecimal numbers into bytes then passed to :meth:`from_bytes`: diff --git a/client/src/telethon/_impl/session/session.py b/client/src/telethon/_impl/session/session.py index cbfd2160..9aa739ae 100644 --- a/client/src/telethon/_impl/session/session.py +++ b/client/src/telethon/_impl/session/session.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from ..tl.core.serializable import obj_repr @@ -13,7 +13,7 @@ class DataCenter: :param auth: See below. """ - __slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth") + __slots__: Tuple[str, ...] = ("id", "ipv4_addr", "ipv6_addr", "auth") def __init__( self,