mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-25 10:53:44 +03:00
Unify the way deserialization is returned
This commit is contained in:
parent
b5db881415
commit
3250f9ec37
|
@ -2,7 +2,17 @@ from .authentication import CreatedKey, Step1, Step2, Step3, create_key
|
||||||
from .authentication import step1 as auth_step1
|
from .authentication import step1 as auth_step1
|
||||||
from .authentication import step2 as auth_step2
|
from .authentication import step2 as auth_step2
|
||||||
from .authentication import step3 as auth_step3
|
from .authentication import step3 as auth_step3
|
||||||
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError
|
from .mtp import (
|
||||||
|
BadMessage,
|
||||||
|
Deserialization,
|
||||||
|
Encrypted,
|
||||||
|
MsgId,
|
||||||
|
Mtp,
|
||||||
|
Plain,
|
||||||
|
RpcError,
|
||||||
|
RpcResult,
|
||||||
|
Update,
|
||||||
|
)
|
||||||
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
|
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
|
||||||
from .utils import DEFAULT_COMPRESSION_THRESHOLD
|
from .utils import DEFAULT_COMPRESSION_THRESHOLD
|
||||||
|
|
||||||
|
@ -22,6 +32,8 @@ __all__ = [
|
||||||
"Mtp",
|
"Mtp",
|
||||||
"Plain",
|
"Plain",
|
||||||
"RpcError",
|
"RpcError",
|
||||||
|
"RpcResult",
|
||||||
|
"Update",
|
||||||
"Abridged",
|
"Abridged",
|
||||||
"BadStatus",
|
"BadStatus",
|
||||||
"Full",
|
"Full",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .encrypted import Encrypted
|
from .encrypted import Encrypted
|
||||||
from .plain import Plain
|
from .plain import Plain
|
||||||
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError
|
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Encrypted",
|
"Encrypted",
|
||||||
|
@ -10,4 +10,6 @@ __all__ = [
|
||||||
"MsgId",
|
"MsgId",
|
||||||
"Mtp",
|
"Mtp",
|
||||||
"RpcError",
|
"RpcError",
|
||||||
|
"RpcResult",
|
||||||
|
"Update",
|
||||||
]
|
]
|
||||||
|
|
|
@ -60,7 +60,7 @@ from ..utils import (
|
||||||
gzip_decompress,
|
gzip_decompress,
|
||||||
message_requires_ack,
|
message_requires_ack,
|
||||||
)
|
)
|
||||||
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult
|
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
|
||||||
|
|
||||||
NUM_FUTURE_SALTS = 64
|
NUM_FUTURE_SALTS = 64
|
||||||
|
|
||||||
|
@ -112,8 +112,7 @@ class Encrypted(Mtp):
|
||||||
]
|
]
|
||||||
self._start_salt_time: Optional[Tuple[int, float]] = None
|
self._start_salt_time: Optional[Tuple[int, float]] = None
|
||||||
self._compression_threshold = compression_threshold
|
self._compression_threshold = compression_threshold
|
||||||
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
self._deserialization: List[Deserialization] = []
|
||||||
self._updates: List[bytes] = []
|
|
||||||
self._buffer = bytearray()
|
self._buffer = bytearray()
|
||||||
self._salt_request_msg_id: Optional[int] = None
|
self._salt_request_msg_id: Optional[int] = None
|
||||||
|
|
||||||
|
@ -244,12 +243,9 @@ class Encrypted(Mtp):
|
||||||
inner_constructor = struct.unpack_from("<I", result)[0]
|
inner_constructor = struct.unpack_from("<I", result)[0]
|
||||||
|
|
||||||
if inner_constructor == GeneratedRpcError.constructor_id():
|
if inner_constructor == GeneratedRpcError.constructor_id():
|
||||||
self._rpc_results.append(
|
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result))
|
||||||
(
|
error.msg_id = msg_id
|
||||||
msg_id,
|
self._deserialization.append(error)
|
||||||
RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif inner_constructor == RpcAnswerUnknown.constructor_id():
|
elif inner_constructor == RpcAnswerUnknown.constructor_id():
|
||||||
pass # msg_id = rpc_drop_answer.msg_id
|
pass # msg_id = rpc_drop_answer.msg_id
|
||||||
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
|
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
|
||||||
|
@ -259,15 +255,15 @@ class Encrypted(Mtp):
|
||||||
elif inner_constructor == GzipPacked.constructor_id():
|
elif inner_constructor == GzipPacked.constructor_id():
|
||||||
body = gzip_decompress(GzipPacked.from_bytes(result))
|
body = gzip_decompress(GzipPacked.from_bytes(result))
|
||||||
self._store_own_updates(body)
|
self._store_own_updates(body)
|
||||||
self._rpc_results.append((msg_id, body))
|
self._deserialization.append(RpcResult(msg_id, body))
|
||||||
else:
|
else:
|
||||||
self._store_own_updates(result)
|
self._store_own_updates(result)
|
||||||
self._rpc_results.append((msg_id, result))
|
self._deserialization.append(RpcResult(msg_id, result))
|
||||||
|
|
||||||
def _store_own_updates(self, body: bytes) -> None:
|
def _store_own_updates(self, body: bytes) -> None:
|
||||||
constructor_id = struct.unpack_from("I", body)[0]
|
constructor_id = struct.unpack_from("I", body)[0]
|
||||||
if constructor_id in UPDATE_IDS:
|
if constructor_id in UPDATE_IDS:
|
||||||
self._updates.append(body)
|
self._deserialization.append(Update(body))
|
||||||
|
|
||||||
def _handle_ack(self, message: Message) -> None:
|
def _handle_ack(self, message: Message) -> None:
|
||||||
MsgsAck.from_bytes(message.body)
|
MsgsAck.from_bytes(message.body)
|
||||||
|
@ -276,13 +272,13 @@ class Encrypted(Mtp):
|
||||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||||
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
||||||
|
|
||||||
exc = BadMessage(code=bad_msg.error_code)
|
exc = BadMessage(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code)
|
||||||
|
|
||||||
if bad_msg.bad_msg_id == self._salt_request_msg_id:
|
if bad_msg.bad_msg_id == self._salt_request_msg_id:
|
||||||
# Response to internal request, do not propagate.
|
# Response to internal request, do not propagate.
|
||||||
self._salt_request_msg_id = None
|
self._salt_request_msg_id = None
|
||||||
else:
|
else:
|
||||||
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
|
self._deserialization.append(exc)
|
||||||
|
|
||||||
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
||||||
# If we had no valid salt, this error is expected.
|
# If we had no valid salt, this error is expected.
|
||||||
|
@ -331,7 +327,9 @@ class Encrypted(Mtp):
|
||||||
# Response to internal request, do not propagate.
|
# Response to internal request, do not propagate.
|
||||||
self._salt_request_msg_id = None
|
self._salt_request_msg_id = None
|
||||||
else:
|
else:
|
||||||
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
|
self._deserialization.append(
|
||||||
|
RpcResult(MsgId(salts.req_msg_id), message.body)
|
||||||
|
)
|
||||||
|
|
||||||
self._start_salt_time = (salts.now, self._adjusted_now())
|
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||||
self._salts = salts.salts
|
self._salts = salts.salts
|
||||||
|
@ -343,7 +341,7 @@ class Encrypted(Mtp):
|
||||||
|
|
||||||
def _handle_pong(self, message: Message) -> None:
|
def _handle_pong(self, message: Message) -> None:
|
||||||
pong = Pong.from_bytes(message.body)
|
pong = Pong.from_bytes(message.body)
|
||||||
self._rpc_results.append((MsgId(pong.msg_id), message.body))
|
self._deserialization.append(RpcResult(MsgId(pong.msg_id), message.body))
|
||||||
|
|
||||||
def _handle_destroy_session(self, message: Message) -> None:
|
def _handle_destroy_session(self, message: Message) -> None:
|
||||||
DestroySessionRes.from_bytes(message.body)
|
DestroySessionRes.from_bytes(message.body)
|
||||||
|
@ -378,7 +376,7 @@ class Encrypted(Mtp):
|
||||||
HttpWait.from_bytes(message.body)
|
HttpWait.from_bytes(message.body)
|
||||||
|
|
||||||
def _handle_update(self, message: Message) -> None:
|
def _handle_update(self, message: Message) -> None:
|
||||||
self._updates.append(message.body)
|
self._deserialization.append(Update(message.body))
|
||||||
|
|
||||||
def _try_request_salts(self) -> None:
|
def _try_request_salts(self) -> None:
|
||||||
if (
|
if (
|
||||||
|
@ -441,7 +439,7 @@ class Encrypted(Mtp):
|
||||||
msg_id, buffer = result
|
msg_id, buffer = result
|
||||||
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
|
||||||
plaintext = decrypt_data_v2(payload, self._auth_key)
|
plaintext = decrypt_data_v2(payload, self._auth_key)
|
||||||
|
@ -452,7 +450,6 @@ class Encrypted(Mtp):
|
||||||
|
|
||||||
self._process_message(Message._read_from(Reader(memoryview(plaintext)[16:])))
|
self._process_message(Message._read_from(Reader(memoryview(plaintext)[16:])))
|
||||||
|
|
||||||
result = Deserialization(rpc_results=self._rpc_results, updates=self._updates)
|
result = self._deserialization[:]
|
||||||
self._rpc_results = []
|
self._deserialization.clear()
|
||||||
self._updates = []
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import struct
|
import struct
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..utils import check_message_buffer
|
from ..utils import check_message_buffer
|
||||||
from .types import Deserialization, MsgId, Mtp
|
from .types import Deserialization, MsgId, Mtp, RpcResult
|
||||||
|
|
||||||
|
|
||||||
class Plain(Mtp):
|
class Plain(Mtp):
|
||||||
|
@ -31,7 +31,7 @@ class Plain(Mtp):
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
return MsgId(0), result
|
return MsgId(0), result
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
|
||||||
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
||||||
|
@ -50,6 +50,4 @@ class Plain(Mtp):
|
||||||
f"message too short, expected: {20 + length}, got {len(payload)}"
|
f"message too short, expected: {20 + length}, got {len(payload)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return Deserialization(
|
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
|
||||||
rpc_results=[(MsgId(0), bytes(payload[20 : 20 + length]))], updates=[]
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, NewType, Optional, Self, Tuple
|
from typing import List, NewType, Optional, Self, Tuple
|
||||||
|
|
||||||
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
|
@ -9,6 +8,29 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
MsgId = NewType("MsgId", int)
|
MsgId = NewType("MsgId", int)
|
||||||
|
|
||||||
|
|
||||||
|
class Update:
|
||||||
|
"""
|
||||||
|
An update that does not belong to any RPC.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("body",)
|
||||||
|
|
||||||
|
def __init__(self, body: bytes):
|
||||||
|
self.body = body
|
||||||
|
|
||||||
|
|
||||||
|
class RpcResult:
|
||||||
|
"""
|
||||||
|
A response that belongs to the RPC associated with this message identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("msg_id", "body")
|
||||||
|
|
||||||
|
def __init__(self, msg_id: MsgId, body: bytes):
|
||||||
|
self.msg_id = msg_id
|
||||||
|
self.body = body
|
||||||
|
|
||||||
|
|
||||||
class RpcError(ValueError):
|
class RpcError(ValueError):
|
||||||
"""
|
"""
|
||||||
A Remote Procedure Call Error.
|
A Remote Procedure Call Error.
|
||||||
|
@ -30,15 +52,17 @@ class RpcError(ValueError):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*args: object,
|
||||||
|
msg_id: MsgId = MsgId(0),
|
||||||
code: int = 0,
|
code: int = 0,
|
||||||
name: str = "",
|
name: str = "",
|
||||||
value: Optional[int] = None,
|
value: Optional[int] = None,
|
||||||
caused_by: Optional[int] = None,
|
caused_by: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
append_value = f" ({value})" if value else ""
|
append_value = f" ({value})" if value else ""
|
||||||
super().__init__(f"rpc error {code}: {name}{append_value}")
|
super().__init__(f"rpc error {code}: {name}{append_value}", *args)
|
||||||
|
|
||||||
|
self.msg_id = msg_id
|
||||||
self._code = code
|
self._code = code
|
||||||
self._name = name
|
self._name = name
|
||||||
self._value = value
|
self._value = value
|
||||||
|
@ -121,13 +145,15 @@ NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
|
||||||
class BadMessage(ValueError):
|
class BadMessage(ValueError):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*args: object,
|
||||||
|
msg_id: MsgId = MsgId(0),
|
||||||
code: int,
|
code: int,
|
||||||
caused_by: Optional[int] = None,
|
caused_by: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
|
description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
|
||||||
super().__init__(f"bad msg={code}: {description}")
|
super().__init__(f"bad msg={code}: {description}", *args)
|
||||||
|
|
||||||
|
self.msg_id = msg_id
|
||||||
self._code = code
|
self._code = code
|
||||||
self._caused_by = caused_by
|
self._caused_by = caused_by
|
||||||
self.severity = (
|
self.severity = (
|
||||||
|
@ -152,13 +178,7 @@ class BadMessage(ValueError):
|
||||||
return self._code == other._code
|
return self._code == other._code
|
||||||
|
|
||||||
|
|
||||||
RpcResult = bytes | RpcError | BadMessage
|
Deserialization = Update | RpcResult | RpcError | BadMessage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Deserialization:
|
|
||||||
rpc_results: List[Tuple[MsgId, RpcResult]]
|
|
||||||
updates: List[bytes]
|
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description
|
# https://core.telegram.org/mtproto/description
|
||||||
|
@ -181,7 +201,7 @@ class Mtp(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||||
"""
|
"""
|
||||||
Deserialize incoming buffer payload.
|
Deserialize incoming buffer payload.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -5,7 +5,17 @@ import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from asyncio import FIRST_COMPLETED, Event, Future
|
from asyncio import FIRST_COMPLETED, Event, Future
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, List, Optional, Protocol, Self, Tuple, Type, TypeVar
|
from typing import (
|
||||||
|
Generic,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Self,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
from ..crypto import AuthKey
|
from ..crypto import AuthKey
|
||||||
from ..mtproto import (
|
from ..mtproto import (
|
||||||
|
@ -16,7 +26,9 @@ from ..mtproto import (
|
||||||
Mtp,
|
Mtp,
|
||||||
Plain,
|
Plain,
|
||||||
RpcError,
|
RpcError,
|
||||||
|
RpcResult,
|
||||||
Transport,
|
Transport,
|
||||||
|
Update,
|
||||||
authentication,
|
authentication,
|
||||||
)
|
)
|
||||||
from ..tl import Request as RemoteCall
|
from ..tl import Request as RemoteCall
|
||||||
|
@ -127,17 +139,19 @@ class NotSerialized(RequestState):
|
||||||
|
|
||||||
|
|
||||||
class Serialized(RequestState):
|
class Serialized(RequestState):
|
||||||
__slots__ = ("msg_id",)
|
__slots__ = ("msg_id", "container_msg_id")
|
||||||
|
|
||||||
def __init__(self, msg_id: MsgId):
|
def __init__(self, msg_id: MsgId):
|
||||||
self.msg_id = msg_id
|
self.msg_id = msg_id
|
||||||
|
self.container_msg_id = msg_id
|
||||||
|
|
||||||
|
|
||||||
class Sent(RequestState):
|
class Sent(RequestState):
|
||||||
__slots__ = ("msg_id",)
|
__slots__ = ("msg_id", "container_msg_id")
|
||||||
|
|
||||||
def __init__(self, msg_id: MsgId):
|
def __init__(self, msg_id: MsgId, container_msg_id: MsgId):
|
||||||
self.msg_id = msg_id
|
self.msg_id = msg_id
|
||||||
|
self.container_msg_id = container_msg_id
|
||||||
|
|
||||||
|
|
||||||
Return = TypeVar("Return")
|
Return = TypeVar("Return")
|
||||||
|
@ -273,7 +287,11 @@ class Sender:
|
||||||
|
|
||||||
result = self._mtp.finalize()
|
result = self._mtp.finalize()
|
||||||
if result:
|
if result:
|
||||||
_, mtp_buffer = result
|
container_msg_id, mtp_buffer = result
|
||||||
|
for request in self._requests:
|
||||||
|
if isinstance(request.state, Serialized):
|
||||||
|
request.state.container_msg_id = container_msg_id
|
||||||
|
|
||||||
self._transport.pack(mtp_buffer, self._writer.write)
|
self._transport.pack(mtp_buffer, self._writer.write)
|
||||||
self._write_drain_pending = True
|
self._write_drain_pending = True
|
||||||
|
|
||||||
|
@ -299,7 +317,7 @@ class Sender:
|
||||||
def _on_net_write(self) -> None:
|
def _on_net_write(self) -> None:
|
||||||
for req in self._requests:
|
for req in self._requests:
|
||||||
if isinstance(req.state, Serialized):
|
if isinstance(req.state, Serialized):
|
||||||
req.state = Sent(req.state.msg_id)
|
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
|
||||||
|
|
||||||
def _on_ping_timeout(self) -> None:
|
def _on_ping_timeout(self) -> None:
|
||||||
ping_id = generate_random_id()
|
ping_id = generate_random_id()
|
||||||
|
@ -313,11 +331,23 @@ class Sender:
|
||||||
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
|
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
|
||||||
|
|
||||||
def _process_mtp_buffer(self, updates: List[Updates]) -> None:
|
def _process_mtp_buffer(self, updates: List[Updates]) -> None:
|
||||||
result = self._mtp.deserialize(self._mtp_buffer)
|
results = self._mtp.deserialize(self._mtp_buffer)
|
||||||
|
|
||||||
for update in result.updates:
|
for result in results:
|
||||||
|
if isinstance(result, Update):
|
||||||
|
self._process_update(updates, result.body)
|
||||||
|
elif isinstance(result, RpcResult):
|
||||||
|
self._process_result(result)
|
||||||
|
elif isinstance(result, RpcError):
|
||||||
|
self._process_error(result)
|
||||||
|
elif isinstance(result, BadMessage):
|
||||||
|
self._process_bad_message(result)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("unexpected case")
|
||||||
|
|
||||||
|
def _process_update(self, updates: List[Updates], update: bytes) -> None:
|
||||||
try:
|
try:
|
||||||
u = Updates.from_bytes(update)
|
updates.append(Updates.from_bytes(update))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
cid = struct.unpack_from("I", update)[0]
|
cid = struct.unpack_from("I", update)[0]
|
||||||
alt_classes: Tuple[Type[Serializable], ...] = (
|
alt_classes: Tuple[Type[Serializable], ...] = (
|
||||||
|
@ -337,7 +367,8 @@ class Sender:
|
||||||
AffectedMessages,
|
AffectedMessages,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
u = UpdateShort(
|
updates.append(
|
||||||
|
UpdateShort(
|
||||||
update=UpdateDeleteMessages(
|
update=UpdateDeleteMessages(
|
||||||
messages=[],
|
messages=[],
|
||||||
pts=affected.pts,
|
pts=affected.pts,
|
||||||
|
@ -345,58 +376,83 @@ class Sender:
|
||||||
),
|
),
|
||||||
date=0,
|
date=0,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self._logger.warning(
|
self._logger.warning(
|
||||||
"failed to deserialize incoming update; make sure the session is not in use elsewhere: %s",
|
"failed to deserialize incoming update; make sure the session is not in use elsewhere: %s",
|
||||||
update.hex(),
|
update.hex(),
|
||||||
)
|
)
|
||||||
continue
|
return
|
||||||
|
|
||||||
updates.append(u)
|
def _process_result(self, result: RpcResult) -> None:
|
||||||
|
req = self._pop_request(result.msg_id)
|
||||||
|
|
||||||
for msg_id, ret in result.rpc_results:
|
if req:
|
||||||
for i, req in enumerate(self._requests):
|
assert len(result.body) >= 4
|
||||||
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
req.result.set_result(result.body)
|
||||||
raise RuntimeError("got rpc result for unsent request")
|
|
||||||
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
|
||||||
del self._requests[i]
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
self._logger.warning(
|
self._logger.warning(
|
||||||
"telegram sent rpc_result for unknown msg_id=%d: %s",
|
"telegram sent rpc_result for unknown msg_id=%d: %s",
|
||||||
msg_id,
|
result.msg_id,
|
||||||
ret.hex() if isinstance(ret, bytes) else repr(ret),
|
result.body.hex(),
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(ret, bytes):
|
def _process_error(self, result: RpcError) -> None:
|
||||||
assert len(ret) >= 4
|
req = self._pop_request(result.msg_id)
|
||||||
req.result.set_result(ret)
|
|
||||||
elif isinstance(ret, RpcError):
|
if req:
|
||||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
req.result.set_exception(ret)
|
req.result.set_exception(result)
|
||||||
elif isinstance(ret, BadMessage):
|
else:
|
||||||
if ret.retryable:
|
self._logger.warning(
|
||||||
|
"telegram sent rpc_error for unknown msg_id=%d: %s",
|
||||||
|
result.msg_id,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_bad_message(self, result: BadMessage) -> None:
|
||||||
|
for req in self._drain_requests(result.msg_id):
|
||||||
|
if result.retryable:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
ret.severity,
|
result.severity,
|
||||||
"telegram notified of bad msg_id=%d; will attempt to resend request: %s",
|
"telegram notified of bad msg_id=%d; will attempt to resend request: %s",
|
||||||
msg_id,
|
result.msg_id,
|
||||||
ret,
|
result,
|
||||||
)
|
)
|
||||||
req.state = NotSerialized()
|
req.state = NotSerialized()
|
||||||
self._requests.append(req)
|
self._requests.append(req)
|
||||||
else:
|
else:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
ret.severity,
|
result.severity,
|
||||||
"telegram notified of bad msg_id=%d; impossible to retry: %s",
|
"telegram notified of bad msg_id=%d; impossible to retry: %s",
|
||||||
msg_id,
|
result.msg_id,
|
||||||
ret,
|
result,
|
||||||
)
|
)
|
||||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
req.result.set_exception(ret)
|
req.result.set_exception(result)
|
||||||
else:
|
|
||||||
raise RuntimeError("unexpected case")
|
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
||||||
|
for i, req in enumerate(self._requests):
|
||||||
|
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
||||||
|
raise RuntimeError("got response for unsent request")
|
||||||
|
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
||||||
|
del self._requests[i]
|
||||||
|
return req
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]:
|
||||||
|
for i in reversed(range(len(self._requests))):
|
||||||
|
req = self._requests[i]
|
||||||
|
if isinstance(req.state, Serialized) and (
|
||||||
|
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
|
||||||
|
):
|
||||||
|
raise RuntimeError("got response for unsent request")
|
||||||
|
elif isinstance(req.state, Sent) and (
|
||||||
|
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
|
||||||
|
):
|
||||||
|
yield self._requests.pop(i)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_key(self) -> Optional[bytes]:
|
def auth_key(self) -> Optional[bytes]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user