mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-24 15:30:48 +03:00
Implement MessageBox
This commit is contained in:
parent
c77c10b48f
commit
69d7941852
|
@ -1,8 +1,8 @@
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
from ...tl import abcs, types
|
from ...tl import abcs, types
|
||||||
from ..chat.hash_cache import ChatHashCache
|
from ..chat.hash_cache import ChatHashCache
|
||||||
from .defs import ACCOUNT_WIDE, NO_SEQ, SECRET_CHATS, Gap
|
from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo
|
||||||
|
|
||||||
|
|
||||||
def updates_(updates: types.Updates) -> types.UpdatesCombined:
|
def updates_(updates: types.Updates) -> types.UpdatesCombined:
|
||||||
|
@ -180,64 +180,64 @@ def message_channel_id(message: abcs.Message) -> Optional[int]:
|
||||||
return peer.channel_id if isinstance(peer, types.PeerChannel) else None
|
return peer.channel_id if isinstance(peer, types.PeerChannel) else None
|
||||||
|
|
||||||
|
|
||||||
def pts_info_from_update(update: abcs.Update) -> Optional[Tuple[int | str, int, int]]:
|
def pts_info_from_update(update: abcs.Update) -> Optional[PtsInfo]:
|
||||||
if isinstance(update, types.UpdateNewMessage):
|
if isinstance(update, types.UpdateNewMessage):
|
||||||
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateDeleteMessages):
|
elif isinstance(update, types.UpdateDeleteMessages):
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateNewEncryptedMessage):
|
elif isinstance(update, types.UpdateNewEncryptedMessage):
|
||||||
return SECRET_CHATS, update.qts, 1
|
return PtsInfo(ENTRY_SECRET, update.qts, 1)
|
||||||
elif isinstance(update, types.UpdateReadHistoryInbox):
|
elif isinstance(update, types.UpdateReadHistoryInbox):
|
||||||
assert not isinstance(update.peer, types.PeerChannel)
|
assert not isinstance(update.peer, types.PeerChannel)
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateReadHistoryOutbox):
|
elif isinstance(update, types.UpdateReadHistoryOutbox):
|
||||||
assert not isinstance(update.peer, types.PeerChannel)
|
assert not isinstance(update.peer, types.PeerChannel)
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateWebPage):
|
elif isinstance(update, types.UpdateWebPage):
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateReadMessagesContents):
|
elif isinstance(update, types.UpdateReadMessagesContents):
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateChannelTooLong):
|
elif isinstance(update, types.UpdateChannelTooLong):
|
||||||
if update.pts is not None:
|
if update.pts is not None:
|
||||||
return update.channel_id, update.pts, 0
|
return PtsInfo(update.channel_id, update.pts, 0)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif isinstance(update, types.UpdateNewChannelMessage):
|
elif isinstance(update, types.UpdateNewChannelMessage):
|
||||||
channel_id = message_channel_id(update.message)
|
channel_id = message_channel_id(update.message)
|
||||||
if channel_id is not None:
|
if channel_id is not None:
|
||||||
return channel_id, update.pts, update.pts_count
|
return PtsInfo(channel_id, update.pts, update.pts_count)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif isinstance(update, types.UpdateReadChannelInbox):
|
elif isinstance(update, types.UpdateReadChannelInbox):
|
||||||
return update.channel_id, update.pts, 0
|
return PtsInfo(update.channel_id, update.pts, 0)
|
||||||
elif isinstance(update, types.UpdateDeleteChannelMessages):
|
elif isinstance(update, types.UpdateDeleteChannelMessages):
|
||||||
return update.channel_id, update.pts, update.pts_count
|
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateEditChannelMessage):
|
elif isinstance(update, types.UpdateEditChannelMessage):
|
||||||
channel_id = message_channel_id(update.message)
|
channel_id = message_channel_id(update.message)
|
||||||
if channel_id is not None:
|
if channel_id is not None:
|
||||||
return channel_id, update.pts, update.pts_count
|
return PtsInfo(channel_id, update.pts, update.pts_count)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif isinstance(update, types.UpdateEditMessage):
|
elif isinstance(update, types.UpdateEditMessage):
|
||||||
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateChannelWebPage):
|
elif isinstance(update, types.UpdateChannelWebPage):
|
||||||
return update.channel_id, update.pts, update.pts_count
|
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateFolderPeers):
|
elif isinstance(update, types.UpdateFolderPeers):
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdatePinnedMessages):
|
elif isinstance(update, types.UpdatePinnedMessages):
|
||||||
assert not isinstance(update.peer, types.PeerChannel)
|
assert not isinstance(update.peer, types.PeerChannel)
|
||||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdatePinnedChannelMessages):
|
elif isinstance(update, types.UpdatePinnedChannelMessages):
|
||||||
return update.channel_id, update.pts, update.pts_count
|
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||||
elif isinstance(update, types.UpdateChatParticipant):
|
elif isinstance(update, types.UpdateChatParticipant):
|
||||||
return SECRET_CHATS, update.qts, 0
|
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||||
elif isinstance(update, types.UpdateChannelParticipant):
|
elif isinstance(update, types.UpdateChannelParticipant):
|
||||||
return SECRET_CHATS, update.qts, 0
|
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||||
elif isinstance(update, types.UpdateBotStopped):
|
elif isinstance(update, types.UpdateBotStopped):
|
||||||
return SECRET_CHATS, update.qts, 0
|
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||||
elif isinstance(update, types.UpdateBotChatInviteRequester):
|
elif isinstance(update, types.UpdateBotChatInviteRequester):
|
||||||
return SECRET_CHATS, update.qts, 0
|
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -1,3 +1,133 @@
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from ...tl import abcs
|
||||||
|
|
||||||
|
|
||||||
|
class DataCenter:
|
||||||
|
__slots__ = ("id", "ip", "port", "auth")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, ip: str, port: int, auth: Optional[bytes]) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.ip = ip
|
||||||
|
self.port = port
|
||||||
|
self.auth = auth
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
__slots__ = ("id", "dc", "bot")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.dc = dc
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelState:
|
||||||
|
__slots__ = ("id", "pts")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, pts: int) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.pts = pts
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateState:
|
||||||
|
__slots__ = (
|
||||||
|
"pts",
|
||||||
|
"qts",
|
||||||
|
"date",
|
||||||
|
"seq",
|
||||||
|
"channels",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pts: int,
|
||||||
|
qts: int,
|
||||||
|
date: int,
|
||||||
|
seq: int,
|
||||||
|
channels: List[ChannelState],
|
||||||
|
) -> None:
|
||||||
|
self.pts = pts
|
||||||
|
self.qts = qts
|
||||||
|
self.date = date
|
||||||
|
self.seq = seq
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
__slots__ = ("dcs", "user", "state")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dcs: List[DataCenter],
|
||||||
|
user: Optional[User],
|
||||||
|
state: Optional[UpdateState],
|
||||||
|
):
|
||||||
|
self.dcs = dcs
|
||||||
|
self.user = user
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
|
||||||
|
class PtsInfo:
|
||||||
|
__slots__ = ("pts", "pts_count", "entry")
|
||||||
|
|
||||||
|
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None:
|
||||||
|
self.pts = pts
|
||||||
|
self.pts_count = pts_count
|
||||||
|
self.entry = entry
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
__slots__ = ("pts", "deadline")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pts: int,
|
||||||
|
deadline: float,
|
||||||
|
) -> None:
|
||||||
|
self.pts = pts
|
||||||
|
self.deadline = deadline
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"State(pts={self.pts}, deadline={self.deadline})"
|
||||||
|
|
||||||
|
|
||||||
|
class PossibleGap:
|
||||||
|
__slots__ = ("deadline", "updates")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
deadline: float,
|
||||||
|
updates: List[abcs.Update],
|
||||||
|
) -> None:
|
||||||
|
self.deadline = deadline
|
||||||
|
self.updates = updates
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrematureEndReason(Enum):
|
||||||
|
TEMPORARY_SERVER_ISSUES = "tmp"
|
||||||
|
BANNED = "ban"
|
||||||
|
|
||||||
|
|
||||||
|
class Gap(ValueError):
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "Gap()"
|
||||||
|
|
||||||
|
|
||||||
NO_SEQ = 0
|
NO_SEQ = 0
|
||||||
|
|
||||||
NO_PTS = 0
|
NO_PTS = 0
|
||||||
|
@ -11,9 +141,10 @@ POSSIBLE_GAP_TIMEOUT = 0.5
|
||||||
# https://core.telegram.org/api/updates
|
# https://core.telegram.org/api/updates
|
||||||
NO_UPDATES_TIMEOUT = 15 * 60
|
NO_UPDATES_TIMEOUT = 15 * 60
|
||||||
|
|
||||||
ACCOUNT_WIDE = "ACCOUNT"
|
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
||||||
SECRET_CHATS = "SECRET"
|
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
||||||
|
Entry = Union[Literal["ACCOUNT"], Literal["SECRET"], int]
|
||||||
|
|
||||||
|
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
||||||
class Gap(ValueError):
|
# We don't define a name for this as libraries shouldn't do that though.
|
||||||
pass
|
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
||||||
|
|
637
client/src/telethon/_impl/session/message_box/messagebox.py
Normal file
637
client/src/telethon/_impl/session/message_box/messagebox.py
Normal file
|
@ -0,0 +1,637 @@
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from ...tl import abcs, functions, types
|
||||||
|
from ...tl.core.request import Request
|
||||||
|
from ..chat.hash_cache import ChatHashCache
|
||||||
|
from .adaptor import adapt, pts_info_from_update
|
||||||
|
from .defs import (
|
||||||
|
BOT_CHANNEL_DIFF_LIMIT,
|
||||||
|
ENTRY_ACCOUNT,
|
||||||
|
ENTRY_SECRET,
|
||||||
|
LOG_LEVEL_TRACE,
|
||||||
|
NO_PTS,
|
||||||
|
NO_SEQ,
|
||||||
|
NO_UPDATES_TIMEOUT,
|
||||||
|
POSSIBLE_GAP_TIMEOUT,
|
||||||
|
USER_CHANNEL_DIFF_LIMIT,
|
||||||
|
ChannelState,
|
||||||
|
Entry,
|
||||||
|
Gap,
|
||||||
|
PossibleGap,
|
||||||
|
PrematureEndReason,
|
||||||
|
State,
|
||||||
|
UpdateState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def next_updates_deadline() -> float:
|
||||||
|
return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
def epoch() -> datetime.datetime:
|
||||||
|
return datetime.datetime(*time.gmtime(0)[:6]).replace(tzinfo=datetime.timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# https://core.telegram.org/api/updates#message-related-event-sequences.
|
||||||
|
class MessageBox:
|
||||||
|
__slots__ = (
|
||||||
|
"_log",
|
||||||
|
"map",
|
||||||
|
"date",
|
||||||
|
"seq",
|
||||||
|
"possible_gaps",
|
||||||
|
"getting_diff_for",
|
||||||
|
"next_deadline",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
log: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
self._log = log or logging.getLogger("telethon.messagebox")
|
||||||
|
self.map: Dict[Entry, State] = {}
|
||||||
|
self.date = epoch()
|
||||||
|
self.seq = NO_SEQ
|
||||||
|
self.possible_gaps: Dict[Entry, PossibleGap] = {}
|
||||||
|
self.getting_diff_for: Set[Entry] = set()
|
||||||
|
self.next_deadline: Optional[Entry] = None
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
self._trace("MessageBox initialized")
|
||||||
|
|
||||||
|
def _trace(self, msg: str, *args: object) -> None:
|
||||||
|
# Calls to trace can't really be removed beforehand without some dark magic.
|
||||||
|
# So every call to trace is prefixed with `if __debug__`` instead, to remove
|
||||||
|
# it when using `python -O`. Probably unnecessary, but it's nice to avoid
|
||||||
|
# paying the cost for something that is not used.
|
||||||
|
self._log.log(
|
||||||
|
LOG_LEVEL_TRACE,
|
||||||
|
"Current MessageBox state: seq = %r, date = %s, map = %r",
|
||||||
|
self.seq,
|
||||||
|
self.date.isoformat(),
|
||||||
|
self.map,
|
||||||
|
)
|
||||||
|
self._log.log(LOG_LEVEL_TRACE, msg, *args)
|
||||||
|
|
||||||
|
def load(self, state: UpdateState) -> None:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Loading MessageBox with state = %r",
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
|
deadline = next_updates_deadline()
|
||||||
|
|
||||||
|
self.map.clear()
|
||||||
|
if state.pts != NO_SEQ:
|
||||||
|
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
|
||||||
|
if state.qts != NO_SEQ:
|
||||||
|
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
|
||||||
|
self.map.update(
|
||||||
|
(s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
|
||||||
|
)
|
||||||
|
|
||||||
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
state.date, tz=datetime.timezone.utc
|
||||||
|
)
|
||||||
|
self.seq = state.seq
|
||||||
|
self.possible_gaps.clear()
|
||||||
|
self.getting_diff_for.clear()
|
||||||
|
self.next_deadline = ENTRY_ACCOUNT
|
||||||
|
|
||||||
|
def session_state(self) -> UpdateState:
|
||||||
|
return UpdateState(
|
||||||
|
pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else NO_PTS,
|
||||||
|
qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_PTS,
|
||||||
|
date=int(self.date.timestamp()),
|
||||||
|
seq=self.seq,
|
||||||
|
channels=[
|
||||||
|
ChannelState(id=int(entry), pts=state.pts)
|
||||||
|
for entry, state in self.map.items()
|
||||||
|
if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return (self.map.get(ENTRY_ACCOUNT) or NO_PTS) == NO_PTS
|
||||||
|
|
||||||
|
def check_deadlines(self) -> float:
|
||||||
|
now = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
|
if self.getting_diff_for:
|
||||||
|
return now
|
||||||
|
|
||||||
|
default_deadline = next_updates_deadline()
|
||||||
|
|
||||||
|
if self.possible_gaps:
|
||||||
|
deadline = min(
|
||||||
|
default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
|
||||||
|
)
|
||||||
|
elif self.next_deadline in self.map:
|
||||||
|
deadline = min(default_deadline, self.map[self.next_deadline].deadline)
|
||||||
|
else:
|
||||||
|
deadline = default_deadline
|
||||||
|
|
||||||
|
if now >= deadline:
|
||||||
|
self.getting_diff_for.update(
|
||||||
|
entry
|
||||||
|
for entry, gap in self.possible_gaps.items()
|
||||||
|
if now >= gap.deadline
|
||||||
|
)
|
||||||
|
self.getting_diff_for.update(
|
||||||
|
entry for entry, state in self.map.items() if now >= state.deadline
|
||||||
|
)
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Deadlines met, now getting diff for %r", self.getting_diff_for
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry in self.getting_diff_for:
|
||||||
|
self.possible_gaps.pop(entry, None)
|
||||||
|
|
||||||
|
return deadline
|
||||||
|
|
||||||
|
def reset_deadlines(self, entries: Set[Entry], deadline: float) -> None:
|
||||||
|
if not entries:
|
||||||
|
return
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if entry not in self.map:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Called reset_deadline on an entry for which we do not have state"
|
||||||
|
)
|
||||||
|
self.map[entry].deadline = deadline
|
||||||
|
|
||||||
|
if self.next_deadline in entries:
|
||||||
|
self.next_deadline = min(
|
||||||
|
self.map.items(), key=lambda entry_state: entry_state[1].deadline
|
||||||
|
)[0]
|
||||||
|
elif (
|
||||||
|
self.next_deadline in self.map
|
||||||
|
and deadline < self.map[self.next_deadline].deadline
|
||||||
|
):
|
||||||
|
self.next_deadline = entry
|
||||||
|
|
||||||
|
def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
|
||||||
|
self.reset_deadlines(
|
||||||
|
{channel_id},
|
||||||
|
asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT),
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_state(self, state: abcs.updates.State) -> None:
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Setting state %s", state)
|
||||||
|
|
||||||
|
deadline = next_updates_deadline()
|
||||||
|
assert isinstance(state, types.updates.State)
|
||||||
|
self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
|
||||||
|
self.map[ENTRY_SECRET] = State(state.qts, deadline)
|
||||||
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
state.date, tz=datetime.timezone.utc
|
||||||
|
)
|
||||||
|
self.seq = state.seq
|
||||||
|
|
||||||
|
def try_set_channel_state(self, id: int, pts: int) -> None:
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Trying to set channel state for %r: %r", id, pts)
|
||||||
|
|
||||||
|
if id not in self.map:
|
||||||
|
self.map[id] = State(pts=pts, deadline=next_updates_deadline())
|
||||||
|
|
||||||
|
def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
|
||||||
|
if entry not in self.map:
|
||||||
|
if entry in self.possible_gaps:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Should not have a possible_gap for an entry not in the state map"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Marking %r as needing difference because %s", entry, reason)
|
||||||
|
self.getting_diff_for.add(entry)
|
||||||
|
self.possible_gaps.pop(entry, None)
|
||||||
|
|
||||||
|
def end_get_diff(self, entry: Entry) -> None:
|
||||||
|
try:
|
||||||
|
self.getting_diff_for.remove(entry)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Called end_get_diff on an entry which was not getting diff for"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reset_deadlines({entry}, next_updates_deadline())
|
||||||
|
assert (
|
||||||
|
entry not in self.possible_gaps
|
||||||
|
), "gaps shouldn't be created while getting difference"
|
||||||
|
|
||||||
|
def ensure_known_peer_hashes(
|
||||||
|
self,
|
||||||
|
updates: abcs.Updates,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> None:
|
||||||
|
if not chat_hashes.extend_from_updates(updates):
|
||||||
|
can_recover = (
|
||||||
|
not isinstance(updates, types.UpdateShort)
|
||||||
|
or pts_info_from_update(updates.update) is not None
|
||||||
|
)
|
||||||
|
if can_recover:
|
||||||
|
raise Gap
|
||||||
|
|
||||||
|
# https://core.telegram.org/api/updates
|
||||||
|
def process_updates(
|
||||||
|
self,
|
||||||
|
updates: abcs.Updates,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||||
|
result: List[abcs.Update] = []
|
||||||
|
combined = adapt(updates, chat_hashes)
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Processing updates with seq = %r, seq_start = %r, date = %r: %s",
|
||||||
|
combined.seq,
|
||||||
|
combined.seq_start,
|
||||||
|
combined.date,
|
||||||
|
updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
if combined.seq_start != NO_SEQ:
|
||||||
|
if self.seq + 1 > combined.seq_start:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Skipping updates as they should have already been handled"
|
||||||
|
)
|
||||||
|
return result, combined.users, combined.chats
|
||||||
|
elif self.seq + 1 < combined.seq_start:
|
||||||
|
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
|
||||||
|
raise Gap
|
||||||
|
|
||||||
|
def update_sort_key(update: abcs.Update) -> int:
|
||||||
|
pts = pts_info_from_update(update)
|
||||||
|
return pts.pts - pts.pts_count if pts else 0
|
||||||
|
|
||||||
|
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
|
||||||
|
|
||||||
|
any_pts_applied = False
|
||||||
|
reset_deadlines_for = set()
|
||||||
|
for update in sorted_updates:
|
||||||
|
entry, applied = self.apply_pts_info(update)
|
||||||
|
if entry is not None:
|
||||||
|
reset_deadlines_for.add(entry)
|
||||||
|
if applied is not None:
|
||||||
|
result.append(applied)
|
||||||
|
any_pts_applied |= entry is not None
|
||||||
|
|
||||||
|
self.reset_deadlines(reset_deadlines_for, next_updates_deadline())
|
||||||
|
|
||||||
|
if any_pts_applied:
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Updating seq as local pts was updated too")
|
||||||
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
combined.date, tz=datetime.timezone.utc
|
||||||
|
)
|
||||||
|
if combined.seq != NO_SEQ:
|
||||||
|
self.seq = combined.seq
|
||||||
|
|
||||||
|
if self.possible_gaps:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Trying to re-apply %r possible gaps", len(self.possible_gaps)
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in list(self.possible_gaps.keys()):
|
||||||
|
self.possible_gaps[key].updates.sort(key=update_sort_key)
|
||||||
|
|
||||||
|
for _ in range(len(self.possible_gaps[key].updates)):
|
||||||
|
update = self.possible_gaps[key].updates.pop(0)
|
||||||
|
_, applied = self.apply_pts_info(update)
|
||||||
|
if applied is not None:
|
||||||
|
result.append(applied)
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Resolved gap with %r: %s",
|
||||||
|
pts_info_from_update(applied),
|
||||||
|
applied,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.possible_gaps = {
|
||||||
|
entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, combined.users, combined.chats
|
||||||
|
|
||||||
|
def apply_pts_info(
|
||||||
|
self,
|
||||||
|
update: abcs.Update,
|
||||||
|
) -> Tuple[Optional[Entry], Optional[abcs.Update]]:
|
||||||
|
if isinstance(update, types.UpdateChannelTooLong):
|
||||||
|
self.try_begin_get_diff(update.channel_id, "received updateChannelTooLong")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
pts = pts_info_from_update(update)
|
||||||
|
if not pts:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"No pts in update, so it can be applied in any order: %s", update
|
||||||
|
)
|
||||||
|
return None, update
|
||||||
|
|
||||||
|
if pts.entry in self.getting_diff_for:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Skipping update with %r as its difference is being fetched", pts
|
||||||
|
)
|
||||||
|
return pts.entry, None
|
||||||
|
|
||||||
|
if state := self.map.get(pts.entry):
|
||||||
|
local_pts = state.pts
|
||||||
|
if local_pts + pts.pts_count > pts.pts:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Skipping update since local pts %r > %r: %s",
|
||||||
|
local_pts,
|
||||||
|
pts,
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
return pts.entry, None
|
||||||
|
elif local_pts + pts.pts_count < pts.pts:
|
||||||
|
# TODO store chats too?
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Possible gap since local pts %r < %r: %s",
|
||||||
|
local_pts,
|
||||||
|
pts,
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
if pts.entry not in self.possible_gaps:
|
||||||
|
self.possible_gaps[pts.entry] = PossibleGap(
|
||||||
|
deadline=asyncio.get_running_loop().time()
|
||||||
|
+ POSSIBLE_GAP_TIMEOUT,
|
||||||
|
updates=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.possible_gaps[pts.entry].updates.append(update)
|
||||||
|
return pts.entry, None
|
||||||
|
else:
|
||||||
|
if __debug__:
|
||||||
|
self._trace(
|
||||||
|
"Applying update pts since local pts %r = %r: %s",
|
||||||
|
local_pts,
|
||||||
|
pts,
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pts.entry not in self.map:
|
||||||
|
self.map[pts.entry] = State(
|
||||||
|
pts=0,
|
||||||
|
deadline=next_updates_deadline(),
|
||||||
|
)
|
||||||
|
self.map[pts.entry].pts = pts.pts
|
||||||
|
|
||||||
|
return pts.entry, update
|
||||||
|
|
||||||
|
def get_difference(self) -> Optional[Request[abcs.updates.Difference]]:
|
||||||
|
for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
|
||||||
|
if entry in self.getting_diff_for:
|
||||||
|
if entry not in self.map:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Should not try to get difference for an entry without known state"
|
||||||
|
)
|
||||||
|
|
||||||
|
gd = functions.updates.get_difference(
|
||||||
|
pts=self.map[ENTRY_ACCOUNT].pts,
|
||||||
|
pts_total_limit=None,
|
||||||
|
date=int(self.date.timestamp()),
|
||||||
|
qts=self.map[ENTRY_SECRET].pts
|
||||||
|
if ENTRY_SECRET in self.map
|
||||||
|
else NO_SEQ,
|
||||||
|
)
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Requesting account difference %s", gd)
|
||||||
|
return gd
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def apply_difference(
|
||||||
|
self,
|
||||||
|
diff: abcs.updates.Difference,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Applying account difference %s", diff)
|
||||||
|
|
||||||
|
finish: bool
|
||||||
|
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
|
||||||
|
if isinstance(diff, types.updates.DifferenceEmpty):
|
||||||
|
finish = True
|
||||||
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
diff.date, tz=datetime.timezone.utc
|
||||||
|
)
|
||||||
|
self.seq = diff.seq
|
||||||
|
result = [], [], []
|
||||||
|
elif isinstance(diff, types.updates.Difference):
|
||||||
|
chat_hashes.extend(diff.users, diff.chats)
|
||||||
|
finish = True
|
||||||
|
result = self.apply_difference_type(diff, chat_hashes)
|
||||||
|
elif isinstance(diff, types.updates.DifferenceSlice):
|
||||||
|
chat_hashes.extend(diff.users, diff.chats)
|
||||||
|
finish = False
|
||||||
|
result = self.apply_difference_type(
|
||||||
|
types.updates.Difference(
|
||||||
|
new_messages=diff.new_messages,
|
||||||
|
new_encrypted_messages=diff.new_encrypted_messages,
|
||||||
|
other_updates=diff.other_updates,
|
||||||
|
chats=diff.chats,
|
||||||
|
users=diff.users,
|
||||||
|
state=diff.intermediate_state,
|
||||||
|
),
|
||||||
|
chat_hashes,
|
||||||
|
)
|
||||||
|
elif isinstance(diff, types.updates.DifferenceTooLong):
|
||||||
|
finish = True
|
||||||
|
self.map[ENTRY_ACCOUNT].pts = diff.pts
|
||||||
|
result = [], [], []
|
||||||
|
else:
|
||||||
|
raise RuntimeError("unexpected case")
|
||||||
|
|
||||||
|
if finish:
|
||||||
|
account = ENTRY_ACCOUNT in self.getting_diff_for
|
||||||
|
secret = ENTRY_SECRET in self.getting_diff_for
|
||||||
|
|
||||||
|
if not account and not secret:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Should not be applying the difference when neither account or secret was diff was active"
|
||||||
|
)
|
||||||
|
|
||||||
|
if account:
|
||||||
|
self.end_get_diff(ENTRY_ACCOUNT)
|
||||||
|
if secret:
|
||||||
|
self.end_get_diff(ENTRY_SECRET)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def apply_difference_type(
|
||||||
|
self,
|
||||||
|
diff: types.updates.Difference,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||||
|
state = diff.state
|
||||||
|
assert isinstance(state, types.updates.State)
|
||||||
|
self.map[ENTRY_ACCOUNT].pts = state.pts
|
||||||
|
self.map[ENTRY_SECRET].pts = state.qts
|
||||||
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
state.date, tz=datetime.timezone.utc
|
||||||
|
)
|
||||||
|
self.seq = state.seq
|
||||||
|
|
||||||
|
updates, users, chats = self.process_updates(
|
||||||
|
types.Updates(
|
||||||
|
updates=diff.other_updates,
|
||||||
|
users=diff.users,
|
||||||
|
chats=diff.chats,
|
||||||
|
date=int(epoch().timestamp()),
|
||||||
|
seq=NO_SEQ,
|
||||||
|
),
|
||||||
|
chat_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
updates.extend(
|
||||||
|
types.UpdateNewMessage(
|
||||||
|
message=m,
|
||||||
|
pts=NO_PTS,
|
||||||
|
pts_count=0,
|
||||||
|
)
|
||||||
|
for m in diff.new_messages
|
||||||
|
)
|
||||||
|
updates.extend(
|
||||||
|
types.UpdateNewEncryptedMessage(
|
||||||
|
message=m,
|
||||||
|
qts=NO_PTS,
|
||||||
|
)
|
||||||
|
for m in diff.new_encrypted_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
return updates, users, chats
|
||||||
|
|
||||||
|
def get_channel_difference(
|
||||||
|
self,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> Optional[Request[abcs.updates.ChannelDifference]]:
|
||||||
|
for entry in self.getting_diff_for:
|
||||||
|
if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET):
|
||||||
|
id = int(entry)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
packed = chat_hashes.get(id)
|
||||||
|
if packed:
|
||||||
|
assert packed.access_hash is not None
|
||||||
|
channel = types.InputChannel(
|
||||||
|
channel_id=packed.id,
|
||||||
|
access_hash=packed.access_hash,
|
||||||
|
)
|
||||||
|
if state := self.map.get(entry):
|
||||||
|
gd = functions.updates.get_channel_difference(
|
||||||
|
force=False,
|
||||||
|
channel=channel,
|
||||||
|
filter=types.ChannelMessagesFilterEmpty(),
|
||||||
|
pts=state.pts,
|
||||||
|
limit=BOT_CHANNEL_DIFF_LIMIT
|
||||||
|
if chat_hashes.is_self_bot
|
||||||
|
else USER_CHANNEL_DIFF_LIMIT,
|
||||||
|
)
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Requesting channel difference %s", gd)
|
||||||
|
return gd
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Should not try to get difference for an entry without known state"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.end_get_diff(entry)
|
||||||
|
self.map.pop(entry, None)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def apply_channel_difference(
|
||||||
|
self,
|
||||||
|
channel_id: int,
|
||||||
|
diff: abcs.updates.ChannelDifference,
|
||||||
|
chat_hashes: ChatHashCache,
|
||||||
|
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||||
|
entry: Entry = channel_id
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Applying channel difference for %r: %s", entry, diff)
|
||||||
|
|
||||||
|
self.possible_gaps.pop(entry, None)
|
||||||
|
|
||||||
|
if isinstance(diff, types.updates.ChannelDifferenceEmpty):
|
||||||
|
assert diff.final
|
||||||
|
self.end_get_diff(entry)
|
||||||
|
self.map[entry].pts = diff.pts
|
||||||
|
return [], [], []
|
||||||
|
elif isinstance(diff, types.updates.ChannelDifferenceTooLong):
|
||||||
|
chat_hashes.extend(diff.users, diff.chats)
|
||||||
|
|
||||||
|
assert diff.final
|
||||||
|
if isinstance(diff.dialog, types.Dialog):
|
||||||
|
assert diff.dialog.pts is not None
|
||||||
|
self.map[entry].pts = diff.dialog.pts
|
||||||
|
else:
|
||||||
|
raise RuntimeError("unexpected type on ChannelDifferenceTooLong")
|
||||||
|
self.reset_channel_deadline(channel_id, diff.timeout)
|
||||||
|
return [], [], []
|
||||||
|
elif isinstance(diff, types.updates.ChannelDifference):
|
||||||
|
chat_hashes.extend(diff.users, diff.chats)
|
||||||
|
|
||||||
|
if diff.final:
|
||||||
|
self.end_get_diff(entry)
|
||||||
|
|
||||||
|
self.map[entry].pts = diff.pts
|
||||||
|
updates, users, chats = self.process_updates(
|
||||||
|
types.Updates(
|
||||||
|
updates=diff.other_updates,
|
||||||
|
users=diff.users,
|
||||||
|
chats=diff.chats,
|
||||||
|
date=int(epoch().timestamp()),
|
||||||
|
seq=NO_SEQ,
|
||||||
|
),
|
||||||
|
chat_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
updates.extend(
|
||||||
|
types.UpdateNewChannelMessage(
|
||||||
|
message=m,
|
||||||
|
pts=NO_PTS,
|
||||||
|
pts_count=0,
|
||||||
|
)
|
||||||
|
for m in diff.new_messages
|
||||||
|
)
|
||||||
|
self.reset_channel_deadline(channel_id, None)
|
||||||
|
|
||||||
|
return updates, users, chats
|
||||||
|
else:
|
||||||
|
raise RuntimeError("unexpected case")
|
||||||
|
|
||||||
|
def end_channel_difference(
|
||||||
|
self, channel_id: int, reason: PrematureEndReason
|
||||||
|
) -> None:
|
||||||
|
entry: Entry = channel_id
|
||||||
|
if __debug__:
|
||||||
|
self._trace("Ending channel difference for %r because %s", entry, reason)
|
||||||
|
|
||||||
|
if reason == PrematureEndReason.TEMPORARY_SERVER_ISSUES:
|
||||||
|
self.possible_gaps.pop(entry, None)
|
||||||
|
self.end_get_diff(entry)
|
||||||
|
elif reason == PrematureEndReason.BANNED:
|
||||||
|
self.possible_gaps.pop(entry, None)
|
||||||
|
self.end_get_diff(entry)
|
||||||
|
del self.map[entry]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown reason to end channel difference")
|
Loading…
Reference in New Issue
Block a user