mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-04-04 01:04:16 +03:00
Fix handling of salts and container buffer
This commit is contained in:
parent
6ed279e773
commit
c91ce98a25
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
|
@ -89,16 +90,12 @@ class Encrypted(Mtp):
|
|||
self._salts: List[FutureSalt] = [
|
||||
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
|
||||
]
|
||||
self._start_salt_time: Optional[Tuple[int, int]] = None
|
||||
self._client_id: int = struct.unpack("<q", os.urandom(8))[0]
|
||||
self._sequence: int = 0
|
||||
self._last_msg_id: int = 0
|
||||
self._pending_ack: List[int] = []
|
||||
self._start_salt_time: Optional[Tuple[int, float]] = None
|
||||
self._compression_threshold = compression_threshold
|
||||
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
||||
self._updates: List[bytes] = []
|
||||
self._buffer = bytearray()
|
||||
self._msg_count: int = 0
|
||||
self._salt_request_msg_id: Optional[int] = None
|
||||
|
||||
self._handlers = {
|
||||
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
|
||||
|
@ -122,6 +119,13 @@ class Encrypted(Mtp):
|
|||
HttpWait.constructor_id(): self._handle_http_wait,
|
||||
}
|
||||
|
||||
self._client_id: int
|
||||
self._sequence: int
|
||||
self._last_msg_id: int
|
||||
self._pending_ack: List[int] = []
|
||||
self._msg_count: int
|
||||
self._reset_session()
|
||||
|
||||
@property
|
||||
def auth_key(self) -> bytes:
|
||||
return self._auth_key.data
|
||||
|
@ -131,10 +135,18 @@ class Encrypted(Mtp):
|
|||
correct = msg_id >> 32
|
||||
self._time_offset = correct - int(now)
|
||||
|
||||
def _get_new_msg_id(self) -> int:
|
||||
now = time.time()
|
||||
def _adjusted_now(self) -> float:
|
||||
return time.time() + self._time_offset
|
||||
|
||||
new_msg_id = int((now + self._time_offset) * 0x100000000)
|
||||
def _reset_session(self) -> None:
|
||||
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
self._pending_ack.clear()
|
||||
self._msg_count = 0
|
||||
|
||||
def _get_new_msg_id(self) -> int:
|
||||
new_msg_id = int(self._adjusted_now() * 0x100000000)
|
||||
if self._last_msg_id >= new_msg_id:
|
||||
new_msg_id = self._last_msg_id + 4
|
||||
|
||||
|
@ -149,6 +161,10 @@ class Encrypted(Mtp):
|
|||
return self._sequence
|
||||
|
||||
def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId:
|
||||
if not self._buffer:
|
||||
# Reserve space for `finalize`
|
||||
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
|
||||
|
||||
msg_id = self._get_new_msg_id()
|
||||
seq_no = self._get_seq_no(content_related)
|
||||
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
|
||||
|
@ -156,6 +172,9 @@ class Encrypted(Mtp):
|
|||
self._msg_count += 1
|
||||
return MsgId(msg_id)
|
||||
|
||||
def _get_current_salt(self) -> int:
|
||||
return self._salts[-1].salt if self._salts else 0
|
||||
|
||||
def _finalize_plain(self) -> bytes:
|
||||
if not self._msg_count:
|
||||
return b""
|
||||
|
@ -164,7 +183,7 @@ class Encrypted(Mtp):
|
|||
del self._buffer[:CONTAINER_HEADER_LEN]
|
||||
|
||||
self._buffer[:HEADER_LEN] = struct.pack(
|
||||
"<qq", self._salts[-1].salt if self._salts else 0, self._client_id
|
||||
"<qq", self._get_current_salt(), self._client_id
|
||||
)
|
||||
|
||||
if self._msg_count != 1:
|
||||
|
@ -177,6 +196,7 @@ class Encrypted(Mtp):
|
|||
self._msg_count,
|
||||
)
|
||||
|
||||
print("packed", self._msg_count)
|
||||
self._msg_count = 0
|
||||
result = bytes(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
@ -230,37 +250,28 @@ class Encrypted(Mtp):
|
|||
|
||||
def _handle_bad_notification(self, message: Message) -> None:
|
||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||
if isinstance(bad_msg, BadServerSalt):
|
||||
self._rpc_results.append(
|
||||
(
|
||||
MsgId(bad_msg.bad_msg_id),
|
||||
BadMessage(code=bad_msg.error_code),
|
||||
)
|
||||
)
|
||||
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
||||
|
||||
exc = BadMessage(code=bad_msg.error_code)
|
||||
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
|
||||
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
||||
# If we had no valid salt, this error is expected.
|
||||
exc.severity = logging.INFO
|
||||
|
||||
if isinstance(bad_msg, BadServerSalt):
|
||||
self._salts.clear()
|
||||
self._salts.append(
|
||||
FutureSalt(
|
||||
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
|
||||
)
|
||||
)
|
||||
|
||||
self.push(get_future_salts(num=NUM_FUTURE_SALTS))
|
||||
return
|
||||
|
||||
assert isinstance(bad_msg, BadMsgNotification)
|
||||
self._rpc_results.append(
|
||||
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
|
||||
)
|
||||
|
||||
if bad_msg.error_code in (16, 17):
|
||||
self._salt_request_msg_id = None
|
||||
elif bad_msg.error_code not in (16, 17):
|
||||
self._correct_time_offset(message.msg_id)
|
||||
elif bad_msg.error_code == 32:
|
||||
# TODO start with a fresh session rather than guessing
|
||||
self._sequence += 64
|
||||
elif bad_msg.error_code == 33:
|
||||
# TODO start with a fresh session rather than guessing
|
||||
self._sequence -= 16
|
||||
elif bad_msg.error_code in (32, 33):
|
||||
self._reset_session()
|
||||
else:
|
||||
raise exc
|
||||
|
||||
def _handle_state_req(self, message: Message) -> None:
|
||||
MsgsStateReq.from_bytes(message.body)
|
||||
|
@ -285,9 +296,14 @@ class Encrypted(Mtp):
|
|||
|
||||
def _handle_future_salts(self, message: Message) -> None:
|
||||
salts = FutureSalts.from_bytes(message.body)
|
||||
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
|
||||
|
||||
self._start_salt_time = (salts.now, int(time.time()))
|
||||
if salts.req_msg_id == self._salt_request_msg_id:
|
||||
# Response to internal request, do not propagate.
|
||||
self._salt_request_msg_id = None
|
||||
else:
|
||||
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
|
||||
|
||||
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||
self._salts = salts.salts
|
||||
self._salts.sort(key=lambda salt: -salt.valid_since)
|
||||
|
||||
|
@ -334,28 +350,38 @@ class Encrypted(Mtp):
|
|||
def _handle_update(self, message: Message) -> None:
|
||||
self._updates.append(message.body)
|
||||
|
||||
def push(self, request: bytes) -> Optional[MsgId]:
|
||||
if not self._buffer:
|
||||
# Reserve space for `finalize`
|
||||
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
|
||||
def _try_request_salts(self) -> None:
|
||||
if (
|
||||
len(self._salts) == 1
|
||||
and self._salt_request_msg_id is None
|
||||
and self._get_current_salt() != 0
|
||||
):
|
||||
# If salts are requested in a container leading to bad_msg,
|
||||
# the bad_msg_id will refer to the container, not the salts request.
|
||||
#
|
||||
# We don't keep track of containers and content-related messages they contain for simplicity.
|
||||
# This would break, because we couldn't identify the response.
|
||||
#
|
||||
# So salts are only requested once we have a valid salt to reduce the chances of this happening.
|
||||
self._salt_request_msg_id = self._serialize_msg(
|
||||
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
|
||||
)
|
||||
|
||||
def push(self, request: bytes) -> Optional[MsgId]:
|
||||
if self._pending_ack:
|
||||
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
|
||||
self._pending_ack = []
|
||||
|
||||
if self._start_salt_time:
|
||||
if self._start_salt_time and len(self._salts) >= 2:
|
||||
start_secs, start_instant = self._start_salt_time
|
||||
if len(self._salts) >= 2:
|
||||
salt = self._salts[-2]
|
||||
now = start_secs + (start_instant - int(time.time()))
|
||||
if now >= salt.valid_since + SALT_USE_DELAY:
|
||||
self._salts.pop()
|
||||
if len(self._salts) == 1:
|
||||
self._serialize_msg(
|
||||
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
|
||||
)
|
||||
salt = self._salts[-2]
|
||||
now = start_secs + (start_instant - self._adjusted_now())
|
||||
if now >= salt.valid_since + SALT_USE_DELAY:
|
||||
self._salts.pop()
|
||||
|
||||
if self._msg_count == CONTAINER_MAX_LENGTH:
|
||||
self._try_request_salts()
|
||||
|
||||
if self._msg_count >= CONTAINER_MAX_LENGTH:
|
||||
return None
|
||||
|
||||
assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE
|
||||
|
|
|
@ -186,24 +186,17 @@ class Sender:
|
|||
if self._write_drain_pending:
|
||||
return
|
||||
|
||||
# TODO test that the a request is only ever sent onrece
|
||||
requests = [r for r in self._requests if isinstance(r.state, NotSerialized)]
|
||||
if not requests:
|
||||
return
|
||||
|
||||
msg_ids = []
|
||||
for request in requests:
|
||||
if (msg_id := self._mtp.push(request.body)) is not None:
|
||||
msg_ids.append(msg_id)
|
||||
else:
|
||||
break
|
||||
for request in self._requests:
|
||||
if isinstance(request.state, NotSerialized):
|
||||
if (msg_id := self._mtp.push(request.body)) is not None:
|
||||
request.state = Serialized(msg_id)
|
||||
else:
|
||||
break
|
||||
|
||||
mtp_buffer = self._mtp.finalize()
|
||||
self._transport.pack(mtp_buffer, self._writer.write)
|
||||
self._write_drain_pending = True
|
||||
|
||||
for req, msg_id in zip(requests, msg_ids):
|
||||
req.state = Serialized(msg_id)
|
||||
if mtp_buffer:
|
||||
self._transport.pack(mtp_buffer, self._writer.write)
|
||||
self._write_drain_pending = True
|
||||
|
||||
def _on_net_read(self, read_buffer: bytes) -> List[Updates]:
|
||||
if not read_buffer:
|
||||
|
@ -255,34 +248,47 @@ class Sender:
|
|||
updates.append(u)
|
||||
|
||||
for msg_id, ret in result.rpc_results:
|
||||
found = False
|
||||
for i in reversed(range(len(self._requests))):
|
||||
req = self._requests[i]
|
||||
for i, req in enumerate(self._requests):
|
||||
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
||||
raise RuntimeError("got rpc result for unsent request")
|
||||
if isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
||||
found = True
|
||||
if isinstance(ret, bytes):
|
||||
assert len(ret) >= 4
|
||||
elif isinstance(ret, RpcError):
|
||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||
raise ret
|
||||
elif isinstance(ret, BadMessage):
|
||||
# TODO test that we resend the request
|
||||
req.state = NotSerialized()
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
req = self._requests.pop(i)
|
||||
req.result.set_result(ret)
|
||||
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
||||
del self._requests[i]
|
||||
break
|
||||
if not found:
|
||||
else:
|
||||
self._logger.warning(
|
||||
"telegram sent rpc_result for unknown msg_id=%d: %s",
|
||||
msg_id,
|
||||
ret.hex() if isinstance(ret, bytes) else repr(ret),
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(ret, bytes):
|
||||
assert len(ret) >= 4
|
||||
req.result.set_result(ret)
|
||||
elif isinstance(ret, RpcError):
|
||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||
req.result.set_exception(ret)
|
||||
elif isinstance(ret, BadMessage):
|
||||
if ret.retryable:
|
||||
self._logger.log(
|
||||
ret.severity,
|
||||
"telegram notified of bad msg_id=%d; will attempt to resend request: %s",
|
||||
msg_id,
|
||||
ret,
|
||||
)
|
||||
req.state = NotSerialized()
|
||||
self._requests.append(req)
|
||||
else:
|
||||
self._logger.log(
|
||||
ret.severity,
|
||||
"telegram notified of bad msg_id=%d; impossible to retry: %s",
|
||||
msg_id,
|
||||
ret,
|
||||
)
|
||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||
req.result.set_exception(ret)
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
@property
|
||||
def auth_key(self) -> Optional[bytes]:
|
||||
|
|
|
@ -66,6 +66,7 @@ class PackedChat:
|
|||
"""
|
||||
return bytes(self).hex()
|
||||
|
||||
@classmethod
|
||||
def from_hex(cls, hex: str) -> Self:
|
||||
"""
|
||||
Convenience method to convert hexadecimal numbers into bytes then passed to :meth:`from_bytes`:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..tl.core.serializable import obj_repr
|
||||
|
||||
|
@ -13,7 +13,7 @@ class DataCenter:
|
|||
:param auth: See below.
|
||||
"""
|
||||
|
||||
__slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
|
||||
__slots__: Tuple[str, ...] = ("id", "ipv4_addr", "ipv6_addr", "auth")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user