Unify the way deserialization is returned

This commit is contained in:
Lonami Exo 2024-03-16 15:00:38 +01:00
parent b5db881415
commit 3250f9ec37
6 changed files with 202 additions and 117 deletions

View File

@ -2,7 +2,17 @@ from .authentication import CreatedKey, Step1, Step2, Step3, create_key
from .authentication import step1 as auth_step1 from .authentication import step1 as auth_step1
from .authentication import step2 as auth_step2 from .authentication import step2 as auth_step2
from .authentication import step3 as auth_step3 from .authentication import step3 as auth_step3
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError from .mtp import (
BadMessage,
Deserialization,
Encrypted,
MsgId,
Mtp,
Plain,
RpcError,
RpcResult,
Update,
)
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
from .utils import DEFAULT_COMPRESSION_THRESHOLD from .utils import DEFAULT_COMPRESSION_THRESHOLD
@ -22,6 +32,8 @@ __all__ = [
"Mtp", "Mtp",
"Plain", "Plain",
"RpcError", "RpcError",
"RpcResult",
"Update",
"Abridged", "Abridged",
"BadStatus", "BadStatus",
"Full", "Full",

View File

@ -1,6 +1,6 @@
from .encrypted import Encrypted from .encrypted import Encrypted
from .plain import Plain from .plain import Plain
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
__all__ = [ __all__ = [
"Encrypted", "Encrypted",
@ -10,4 +10,6 @@ __all__ = [
"MsgId", "MsgId",
"Mtp", "Mtp",
"RpcError", "RpcError",
"RpcResult",
"Update",
] ]

View File

@ -60,7 +60,7 @@ from ..utils import (
gzip_decompress, gzip_decompress,
message_requires_ack, message_requires_ack,
) )
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
NUM_FUTURE_SALTS = 64 NUM_FUTURE_SALTS = 64
@ -112,8 +112,7 @@ class Encrypted(Mtp):
] ]
self._start_salt_time: Optional[Tuple[int, float]] = None self._start_salt_time: Optional[Tuple[int, float]] = None
self._compression_threshold = compression_threshold self._compression_threshold = compression_threshold
self._rpc_results: List[Tuple[MsgId, RpcResult]] = [] self._deserialization: List[Deserialization] = []
self._updates: List[bytes] = []
self._buffer = bytearray() self._buffer = bytearray()
self._salt_request_msg_id: Optional[int] = None self._salt_request_msg_id: Optional[int] = None
@ -244,12 +243,9 @@ class Encrypted(Mtp):
inner_constructor = struct.unpack_from("<I", result)[0] inner_constructor = struct.unpack_from("<I", result)[0]
if inner_constructor == GeneratedRpcError.constructor_id(): if inner_constructor == GeneratedRpcError.constructor_id():
self._rpc_results.append( error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result))
( error.msg_id = msg_id
msg_id, self._deserialization.append(error)
RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result)),
)
)
elif inner_constructor == RpcAnswerUnknown.constructor_id(): elif inner_constructor == RpcAnswerUnknown.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id pass # msg_id = rpc_drop_answer.msg_id
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id(): elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
@ -259,15 +255,15 @@ class Encrypted(Mtp):
elif inner_constructor == GzipPacked.constructor_id(): elif inner_constructor == GzipPacked.constructor_id():
body = gzip_decompress(GzipPacked.from_bytes(result)) body = gzip_decompress(GzipPacked.from_bytes(result))
self._store_own_updates(body) self._store_own_updates(body)
self._rpc_results.append((msg_id, body)) self._deserialization.append(RpcResult(msg_id, body))
else: else:
self._store_own_updates(result) self._store_own_updates(result)
self._rpc_results.append((msg_id, result)) self._deserialization.append(RpcResult(msg_id, result))
def _store_own_updates(self, body: bytes) -> None: def _store_own_updates(self, body: bytes) -> None:
constructor_id = struct.unpack_from("I", body)[0] constructor_id = struct.unpack_from("I", body)[0]
if constructor_id in UPDATE_IDS: if constructor_id in UPDATE_IDS:
self._updates.append(body) self._deserialization.append(Update(body))
def _handle_ack(self, message: Message) -> None: def _handle_ack(self, message: Message) -> None:
MsgsAck.from_bytes(message.body) MsgsAck.from_bytes(message.body)
@ -276,13 +272,13 @@ class Encrypted(Mtp):
bad_msg = AbcBadMsgNotification.from_bytes(message.body) bad_msg = AbcBadMsgNotification.from_bytes(message.body)
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
exc = BadMessage(code=bad_msg.error_code) exc = BadMessage(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code)
if bad_msg.bad_msg_id == self._salt_request_msg_id: if bad_msg.bad_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate. # Response to internal request, do not propagate.
self._salt_request_msg_id = None self._salt_request_msg_id = None
else: else:
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc)) self._deserialization.append(exc)
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0: if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
# If we had no valid salt, this error is expected. # If we had no valid salt, this error is expected.
@ -331,7 +327,9 @@ class Encrypted(Mtp):
# Response to internal request, do not propagate. # Response to internal request, do not propagate.
self._salt_request_msg_id = None self._salt_request_msg_id = None
else: else:
self._rpc_results.append((MsgId(salts.req_msg_id), message.body)) self._deserialization.append(
RpcResult(MsgId(salts.req_msg_id), message.body)
)
self._start_salt_time = (salts.now, self._adjusted_now()) self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = salts.salts self._salts = salts.salts
@ -343,7 +341,7 @@ class Encrypted(Mtp):
def _handle_pong(self, message: Message) -> None: def _handle_pong(self, message: Message) -> None:
pong = Pong.from_bytes(message.body) pong = Pong.from_bytes(message.body)
self._rpc_results.append((MsgId(pong.msg_id), message.body)) self._deserialization.append(RpcResult(MsgId(pong.msg_id), message.body))
def _handle_destroy_session(self, message: Message) -> None: def _handle_destroy_session(self, message: Message) -> None:
DestroySessionRes.from_bytes(message.body) DestroySessionRes.from_bytes(message.body)
@ -378,7 +376,7 @@ class Encrypted(Mtp):
HttpWait.from_bytes(message.body) HttpWait.from_bytes(message.body)
def _handle_update(self, message: Message) -> None: def _handle_update(self, message: Message) -> None:
self._updates.append(message.body) self._deserialization.append(Update(message.body))
def _try_request_salts(self) -> None: def _try_request_salts(self) -> None:
if ( if (
@ -441,7 +439,7 @@ class Encrypted(Mtp):
msg_id, buffer = result msg_id, buffer = result
return msg_id, encrypt_data_v2(buffer, self._auth_key) return msg_id, encrypt_data_v2(buffer, self._auth_key)
def deserialize(self, payload: bytes) -> Deserialization: def deserialize(self, payload: bytes) -> List[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key) plaintext = decrypt_data_v2(payload, self._auth_key)
@ -452,7 +450,6 @@ class Encrypted(Mtp):
self._process_message(Message._read_from(Reader(memoryview(plaintext)[16:]))) self._process_message(Message._read_from(Reader(memoryview(plaintext)[16:])))
result = Deserialization(rpc_results=self._rpc_results, updates=self._updates) result = self._deserialization[:]
self._rpc_results = [] self._deserialization.clear()
self._updates = []
return result return result

View File

@ -1,8 +1,8 @@
import struct import struct
from typing import Optional, Tuple from typing import List, Optional, Tuple
from ..utils import check_message_buffer from ..utils import check_message_buffer
from .types import Deserialization, MsgId, Mtp from .types import Deserialization, MsgId, Mtp, RpcResult
class Plain(Mtp): class Plain(Mtp):
@ -31,7 +31,7 @@ class Plain(Mtp):
self._buffer.clear() self._buffer.clear()
return MsgId(0), result return MsgId(0), result
def deserialize(self, payload: bytes) -> Deserialization: def deserialize(self, payload: bytes) -> List[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload) auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
@ -50,6 +50,4 @@ class Plain(Mtp):
f"message too short, expected: {20 + length}, got {len(payload)}" f"message too short, expected: {20 + length}, got {len(payload)}"
) )
return Deserialization( return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
rpc_results=[(MsgId(0), bytes(payload[20 : 20 + length]))], updates=[]
)

View File

@ -1,7 +1,6 @@
import logging import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, NewType, Optional, Self, Tuple from typing import List, NewType, Optional, Self, Tuple
from ...tl.mtproto.types import RpcError as GeneratedRpcError from ...tl.mtproto.types import RpcError as GeneratedRpcError
@ -9,6 +8,29 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int) MsgId = NewType("MsgId", int)
class Update:
"""
An update that does not belong to any RPC.
"""
__slots__ = ("body",)
def __init__(self, body: bytes):
self.body = body
class RpcResult:
"""
A response that belongs to the RPC associated with this message identifier.
"""
__slots__ = ("msg_id", "body")
def __init__(self, msg_id: MsgId, body: bytes):
self.msg_id = msg_id
self.body = body
class RpcError(ValueError): class RpcError(ValueError):
""" """
A Remote Procedure Call Error. A Remote Procedure Call Error.
@ -30,15 +52,17 @@ class RpcError(ValueError):
def __init__( def __init__(
self, self,
*, *args: object,
msg_id: MsgId = MsgId(0),
code: int = 0, code: int = 0,
name: str = "", name: str = "",
value: Optional[int] = None, value: Optional[int] = None,
caused_by: Optional[int] = None, caused_by: Optional[int] = None,
) -> None: ) -> None:
append_value = f" ({value})" if value else "" append_value = f" ({value})" if value else ""
super().__init__(f"rpc error {code}: {name}{append_value}") super().__init__(f"rpc error {code}: {name}{append_value}", *args)
self.msg_id = msg_id
self._code = code self._code = code
self._name = name self._name = name
self._value = value self._value = value
@ -121,13 +145,15 @@ NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
class BadMessage(ValueError): class BadMessage(ValueError):
def __init__( def __init__(
self, self,
*, *args: object,
msg_id: MsgId = MsgId(0),
code: int, code: int,
caused_by: Optional[int] = None, caused_by: Optional[int] = None,
) -> None: ) -> None:
description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available" description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
super().__init__(f"bad msg={code}: {description}") super().__init__(f"bad msg={code}: {description}", *args)
self.msg_id = msg_id
self._code = code self._code = code
self._caused_by = caused_by self._caused_by = caused_by
self.severity = ( self.severity = (
@ -152,13 +178,7 @@ class BadMessage(ValueError):
return self._code == other._code return self._code == other._code
RpcResult = bytes | RpcError | BadMessage Deserialization = Update | RpcResult | RpcError | BadMessage
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, RpcResult]]
updates: List[bytes]
# https://core.telegram.org/mtproto/description # https://core.telegram.org/mtproto/description
@ -181,7 +201,7 @@ class Mtp(ABC):
""" """
@abstractmethod @abstractmethod
def deserialize(self, payload: bytes) -> Deserialization: def deserialize(self, payload: bytes) -> List[Deserialization]:
""" """
Deserialize incoming buffer payload. Deserialize incoming buffer payload.
""" """

View File

@ -5,7 +5,17 @@ import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future from asyncio import FIRST_COMPLETED, Event, Future
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, List, Optional, Protocol, Self, Tuple, Type, TypeVar from typing import (
Generic,
Iterator,
List,
Optional,
Protocol,
Self,
Tuple,
Type,
TypeVar,
)
from ..crypto import AuthKey from ..crypto import AuthKey
from ..mtproto import ( from ..mtproto import (
@ -16,7 +26,9 @@ from ..mtproto import (
Mtp, Mtp,
Plain, Plain,
RpcError, RpcError,
RpcResult,
Transport, Transport,
Update,
authentication, authentication,
) )
from ..tl import Request as RemoteCall from ..tl import Request as RemoteCall
@ -127,17 +139,19 @@ class NotSerialized(RequestState):
class Serialized(RequestState): class Serialized(RequestState):
__slots__ = ("msg_id",) __slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId): def __init__(self, msg_id: MsgId):
self.msg_id = msg_id self.msg_id = msg_id
self.container_msg_id = msg_id
class Sent(RequestState): class Sent(RequestState):
__slots__ = ("msg_id",) __slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId): def __init__(self, msg_id: MsgId, container_msg_id: MsgId):
self.msg_id = msg_id self.msg_id = msg_id
self.container_msg_id = container_msg_id
Return = TypeVar("Return") Return = TypeVar("Return")
@ -273,7 +287,11 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
_, mtp_buffer = result container_msg_id, mtp_buffer = result
for request in self._requests:
if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True self._write_drain_pending = True
@ -299,7 +317,7 @@ class Sender:
def _on_net_write(self) -> None: def _on_net_write(self) -> None:
for req in self._requests: for req in self._requests:
if isinstance(req.state, Serialized): if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id) req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_ping_timeout(self) -> None: def _on_ping_timeout(self) -> None:
ping_id = generate_random_id() ping_id = generate_random_id()
@ -313,11 +331,23 @@ class Sender:
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self, updates: List[Updates]) -> None: def _process_mtp_buffer(self, updates: List[Updates]) -> None:
result = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(self._mtp_buffer)
for update in result.updates: for result in results:
if isinstance(result, Update):
self._process_update(updates, result.body)
elif isinstance(result, RpcResult):
self._process_result(result)
elif isinstance(result, RpcError):
self._process_error(result)
elif isinstance(result, BadMessage):
self._process_bad_message(result)
else:
raise RuntimeError("unexpected case")
def _process_update(self, updates: List[Updates], update: bytes) -> None:
try: try:
u = Updates.from_bytes(update) updates.append(Updates.from_bytes(update))
except ValueError: except ValueError:
cid = struct.unpack_from("I", update)[0] cid = struct.unpack_from("I", update)[0]
alt_classes: Tuple[Type[Serializable], ...] = ( alt_classes: Tuple[Type[Serializable], ...] = (
@ -337,7 +367,8 @@ class Sender:
AffectedMessages, AffectedMessages,
), ),
) )
u = UpdateShort( updates.append(
UpdateShort(
update=UpdateDeleteMessages( update=UpdateDeleteMessages(
messages=[], messages=[],
pts=affected.pts, pts=affected.pts,
@ -345,58 +376,83 @@ class Sender:
), ),
date=0, date=0,
) )
)
break break
else: else:
self._logger.warning( self._logger.warning(
"failed to deserialize incoming update; make sure the session is not in use elsewhere: %s", "failed to deserialize incoming update; make sure the session is not in use elsewhere: %s",
update.hex(), update.hex(),
) )
continue return
updates.append(u) def _process_result(self, result: RpcResult) -> None:
req = self._pop_request(result.msg_id)
for msg_id, ret in result.rpc_results: if req:
for i, req in enumerate(self._requests): assert len(result.body) >= 4
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: req.result.set_result(result.body)
raise RuntimeError("got rpc result for unsent request")
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
del self._requests[i]
break
else: else:
self._logger.warning( self._logger.warning(
"telegram sent rpc_result for unknown msg_id=%d: %s", "telegram sent rpc_result for unknown msg_id=%d: %s",
msg_id, result.msg_id,
ret.hex() if isinstance(ret, bytes) else repr(ret), result.body.hex(),
) )
continue
if isinstance(ret, bytes): def _process_error(self, result: RpcError) -> None:
assert len(ret) >= 4 req = self._pop_request(result.msg_id)
req.result.set_result(ret)
elif isinstance(ret, RpcError): if req:
ret._caused_by = struct.unpack_from("<I", req.body)[0] result._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(ret) req.result.set_exception(result)
elif isinstance(ret, BadMessage): else:
if ret.retryable: self._logger.warning(
"telegram sent rpc_error for unknown msg_id=%d: %s",
result.msg_id,
result,
)
def _process_bad_message(self, result: BadMessage) -> None:
for req in self._drain_requests(result.msg_id):
if result.retryable:
self._logger.log( self._logger.log(
ret.severity, result.severity,
"telegram notified of bad msg_id=%d; will attempt to resend request: %s", "telegram notified of bad msg_id=%d; will attempt to resend request: %s",
msg_id, result.msg_id,
ret, result,
) )
req.state = NotSerialized() req.state = NotSerialized()
self._requests.append(req) self._requests.append(req)
else: else:
self._logger.log( self._logger.log(
ret.severity, result.severity,
"telegram notified of bad msg_id=%d; impossible to retry: %s", "telegram notified of bad msg_id=%d; impossible to retry: %s",
msg_id, result.msg_id,
ret, result,
) )
ret._caused_by = struct.unpack_from("<I", req.body)[0] result._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(ret) req.result.set_exception(result)
else:
raise RuntimeError("unexpected case") def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
for i, req in enumerate(self._requests):
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
raise RuntimeError("got response for unsent request")
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
del self._requests[i]
return req
return None
def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]:
for i in reversed(range(len(self._requests))):
req = self._requests[i]
if isinstance(req.state, Serialized) and (
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
):
raise RuntimeError("got response for unsent request")
elif isinstance(req.state, Sent) and (
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
):
yield self._requests.pop(i)
@property @property
def auth_key(self) -> Optional[bytes]: def auth_key(self) -> Optional[bytes]: