mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-24 07:20:42 +03:00
Implement MessageBox
This commit is contained in:
parent
c77c10b48f
commit
69d7941852
|
@ -1,8 +1,8 @@
|
|||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ...tl import abcs, types
|
||||
from ..chat.hash_cache import ChatHashCache
|
||||
from .defs import ACCOUNT_WIDE, NO_SEQ, SECRET_CHATS, Gap
|
||||
from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo
|
||||
|
||||
|
||||
def updates_(updates: types.Updates) -> types.UpdatesCombined:
|
||||
|
@ -180,64 +180,64 @@ def message_channel_id(message: abcs.Message) -> Optional[int]:
|
|||
return peer.channel_id if isinstance(peer, types.PeerChannel) else None
|
||||
|
||||
|
||||
def pts_info_from_update(update: abcs.Update) -> Optional[Tuple[int | str, int, int]]:
|
||||
def pts_info_from_update(update: abcs.Update) -> Optional[PtsInfo]:
|
||||
if isinstance(update, types.UpdateNewMessage):
|
||||
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateDeleteMessages):
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateNewEncryptedMessage):
|
||||
return SECRET_CHATS, update.qts, 1
|
||||
return PtsInfo(ENTRY_SECRET, update.qts, 1)
|
||||
elif isinstance(update, types.UpdateReadHistoryInbox):
|
||||
assert not isinstance(update.peer, types.PeerChannel)
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateReadHistoryOutbox):
|
||||
assert not isinstance(update.peer, types.PeerChannel)
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateWebPage):
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateReadMessagesContents):
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateChannelTooLong):
|
||||
if update.pts is not None:
|
||||
return update.channel_id, update.pts, 0
|
||||
return PtsInfo(update.channel_id, update.pts, 0)
|
||||
else:
|
||||
return None
|
||||
elif isinstance(update, types.UpdateNewChannelMessage):
|
||||
channel_id = message_channel_id(update.message)
|
||||
if channel_id is not None:
|
||||
return channel_id, update.pts, update.pts_count
|
||||
return PtsInfo(channel_id, update.pts, update.pts_count)
|
||||
else:
|
||||
return None
|
||||
elif isinstance(update, types.UpdateReadChannelInbox):
|
||||
return update.channel_id, update.pts, 0
|
||||
return PtsInfo(update.channel_id, update.pts, 0)
|
||||
elif isinstance(update, types.UpdateDeleteChannelMessages):
|
||||
return update.channel_id, update.pts, update.pts_count
|
||||
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateEditChannelMessage):
|
||||
channel_id = message_channel_id(update.message)
|
||||
if channel_id is not None:
|
||||
return channel_id, update.pts, update.pts_count
|
||||
return PtsInfo(channel_id, update.pts, update.pts_count)
|
||||
else:
|
||||
return None
|
||||
elif isinstance(update, types.UpdateEditMessage):
|
||||
assert not isinstance(message_peer(update.message), types.PeerChannel)
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateChannelWebPage):
|
||||
return update.channel_id, update.pts, update.pts_count
|
||||
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateFolderPeers):
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdatePinnedMessages):
|
||||
assert not isinstance(update.peer, types.PeerChannel)
|
||||
return ACCOUNT_WIDE, update.pts, update.pts_count
|
||||
return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdatePinnedChannelMessages):
|
||||
return update.channel_id, update.pts, update.pts_count
|
||||
return PtsInfo(update.channel_id, update.pts, update.pts_count)
|
||||
elif isinstance(update, types.UpdateChatParticipant):
|
||||
return SECRET_CHATS, update.qts, 0
|
||||
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||
elif isinstance(update, types.UpdateChannelParticipant):
|
||||
return SECRET_CHATS, update.qts, 0
|
||||
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||
elif isinstance(update, types.UpdateBotStopped):
|
||||
return SECRET_CHATS, update.qts, 0
|
||||
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||
elif isinstance(update, types.UpdateBotChatInviteRequester):
|
||||
return SECRET_CHATS, update.qts, 0
|
||||
return PtsInfo(ENTRY_SECRET, update.qts, 0)
|
||||
else:
|
||||
return None
|
||||
|
|
|
@ -1,3 +1,133 @@
|
|||
import logging
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from ...tl import abcs
|
||||
|
||||
|
||||
class DataCenter:
|
||||
__slots__ = ("id", "ip", "port", "auth")
|
||||
|
||||
def __init__(self, *, id: int, ip: str, port: int, auth: Optional[bytes]) -> None:
|
||||
self.id = id
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.auth = auth
|
||||
|
||||
|
||||
class User:
|
||||
__slots__ = ("id", "dc", "bot")
|
||||
|
||||
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
||||
self.id = id
|
||||
self.dc = dc
|
||||
self.bot = bot
|
||||
|
||||
|
||||
class ChannelState:
|
||||
__slots__ = ("id", "pts")
|
||||
|
||||
def __init__(self, *, id: int, pts: int) -> None:
|
||||
self.id = id
|
||||
self.pts = pts
|
||||
|
||||
|
||||
class UpdateState:
|
||||
__slots__ = (
|
||||
"pts",
|
||||
"qts",
|
||||
"date",
|
||||
"seq",
|
||||
"channels",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pts: int,
|
||||
qts: int,
|
||||
date: int,
|
||||
seq: int,
|
||||
channels: List[ChannelState],
|
||||
) -> None:
|
||||
self.pts = pts
|
||||
self.qts = qts
|
||||
self.date = date
|
||||
self.seq = seq
|
||||
self.channels = channels
|
||||
|
||||
|
||||
class Session:
|
||||
__slots__ = ("dcs", "user", "state")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dcs: List[DataCenter],
|
||||
user: Optional[User],
|
||||
state: Optional[UpdateState],
|
||||
):
|
||||
self.dcs = dcs
|
||||
self.user = user
|
||||
self.state = state
|
||||
|
||||
|
||||
class PtsInfo:
|
||||
__slots__ = ("pts", "pts_count", "entry")
|
||||
|
||||
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None:
|
||||
self.pts = pts
|
||||
self.pts_count = pts_count
|
||||
self.entry = entry
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
|
||||
)
|
||||
|
||||
|
||||
class State:
|
||||
__slots__ = ("pts", "deadline")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pts: int,
|
||||
deadline: float,
|
||||
) -> None:
|
||||
self.pts = pts
|
||||
self.deadline = deadline
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"State(pts={self.pts}, deadline={self.deadline})"
|
||||
|
||||
|
||||
class PossibleGap:
|
||||
__slots__ = ("deadline", "updates")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
deadline: float,
|
||||
updates: List[abcs.Update],
|
||||
) -> None:
|
||||
self.deadline = deadline
|
||||
self.updates = updates
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
|
||||
)
|
||||
|
||||
|
||||
class PrematureEndReason(Enum):
|
||||
TEMPORARY_SERVER_ISSUES = "tmp"
|
||||
BANNED = "ban"
|
||||
|
||||
|
||||
class Gap(ValueError):
|
||||
def __repr__(self) -> str:
|
||||
return "Gap()"
|
||||
|
||||
|
||||
NO_SEQ = 0
|
||||
|
||||
NO_PTS = 0
|
||||
|
@ -11,9 +141,10 @@ POSSIBLE_GAP_TIMEOUT = 0.5
|
|||
# https://core.telegram.org/api/updates
|
||||
NO_UPDATES_TIMEOUT = 15 * 60
|
||||
|
||||
ACCOUNT_WIDE = "ACCOUNT"
|
||||
SECRET_CHATS = "SECRET"
|
||||
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
||||
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
||||
Entry = Union[Literal["ACCOUNT"], Literal["SECRET"], int]
|
||||
|
||||
|
||||
class Gap(ValueError):
|
||||
pass
|
||||
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
||||
# We don't define a name for this as libraries shouldn't do that though.
|
||||
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
||||
|
|
637
client/src/telethon/_impl/session/message_box/messagebox.py
Normal file
637
client/src/telethon/_impl/session/message_box/messagebox.py
Normal file
|
@ -0,0 +1,637 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from ...tl import abcs, functions, types
|
||||
from ...tl.core.request import Request
|
||||
from ..chat.hash_cache import ChatHashCache
|
||||
from .adaptor import adapt, pts_info_from_update
|
||||
from .defs import (
|
||||
BOT_CHANNEL_DIFF_LIMIT,
|
||||
ENTRY_ACCOUNT,
|
||||
ENTRY_SECRET,
|
||||
LOG_LEVEL_TRACE,
|
||||
NO_PTS,
|
||||
NO_SEQ,
|
||||
NO_UPDATES_TIMEOUT,
|
||||
POSSIBLE_GAP_TIMEOUT,
|
||||
USER_CHANNEL_DIFF_LIMIT,
|
||||
ChannelState,
|
||||
Entry,
|
||||
Gap,
|
||||
PossibleGap,
|
||||
PrematureEndReason,
|
||||
State,
|
||||
UpdateState,
|
||||
)
|
||||
|
||||
|
||||
def next_updates_deadline() -> float:
|
||||
return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT
|
||||
|
||||
|
||||
def epoch() -> datetime.datetime:
|
||||
return datetime.datetime(*time.gmtime(0)[:6]).replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
# https://core.telegram.org/api/updates#message-related-event-sequences.
|
||||
class MessageBox:
|
||||
__slots__ = (
|
||||
"_log",
|
||||
"map",
|
||||
"date",
|
||||
"seq",
|
||||
"possible_gaps",
|
||||
"getting_diff_for",
|
||||
"next_deadline",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
log: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._log = log or logging.getLogger("telethon.messagebox")
|
||||
self.map: Dict[Entry, State] = {}
|
||||
self.date = epoch()
|
||||
self.seq = NO_SEQ
|
||||
self.possible_gaps: Dict[Entry, PossibleGap] = {}
|
||||
self.getting_diff_for: Set[Entry] = set()
|
||||
self.next_deadline: Optional[Entry] = None
|
||||
|
||||
if __debug__:
|
||||
self._trace("MessageBox initialized")
|
||||
|
||||
def _trace(self, msg: str, *args: object) -> None:
|
||||
# Calls to trace can't really be removed beforehand without some dark magic.
|
||||
# So every call to trace is prefixed with `if __debug__`` instead, to remove
|
||||
# it when using `python -O`. Probably unnecessary, but it's nice to avoid
|
||||
# paying the cost for something that is not used.
|
||||
self._log.log(
|
||||
LOG_LEVEL_TRACE,
|
||||
"Current MessageBox state: seq = %r, date = %s, map = %r",
|
||||
self.seq,
|
||||
self.date.isoformat(),
|
||||
self.map,
|
||||
)
|
||||
self._log.log(LOG_LEVEL_TRACE, msg, *args)
|
||||
|
||||
def load(self, state: UpdateState) -> None:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Loading MessageBox with state = %r",
|
||||
state,
|
||||
)
|
||||
|
||||
deadline = next_updates_deadline()
|
||||
|
||||
self.map.clear()
|
||||
if state.pts != NO_SEQ:
|
||||
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
|
||||
if state.qts != NO_SEQ:
|
||||
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
|
||||
self.map.update(
|
||||
(s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
|
||||
)
|
||||
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.seq = state.seq
|
||||
self.possible_gaps.clear()
|
||||
self.getting_diff_for.clear()
|
||||
self.next_deadline = ENTRY_ACCOUNT
|
||||
|
||||
def session_state(self) -> UpdateState:
|
||||
return UpdateState(
|
||||
pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else NO_PTS,
|
||||
qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_PTS,
|
||||
date=int(self.date.timestamp()),
|
||||
seq=self.seq,
|
||||
channels=[
|
||||
ChannelState(id=int(entry), pts=state.pts)
|
||||
for entry, state in self.map.items()
|
||||
if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET)
|
||||
],
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (self.map.get(ENTRY_ACCOUNT) or NO_PTS) == NO_PTS
|
||||
|
||||
def check_deadlines(self) -> float:
|
||||
now = asyncio.get_running_loop().time()
|
||||
|
||||
if self.getting_diff_for:
|
||||
return now
|
||||
|
||||
default_deadline = next_updates_deadline()
|
||||
|
||||
if self.possible_gaps:
|
||||
deadline = min(
|
||||
default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
|
||||
)
|
||||
elif self.next_deadline in self.map:
|
||||
deadline = min(default_deadline, self.map[self.next_deadline].deadline)
|
||||
else:
|
||||
deadline = default_deadline
|
||||
|
||||
if now >= deadline:
|
||||
self.getting_diff_for.update(
|
||||
entry
|
||||
for entry, gap in self.possible_gaps.items()
|
||||
if now >= gap.deadline
|
||||
)
|
||||
self.getting_diff_for.update(
|
||||
entry for entry, state in self.map.items() if now >= state.deadline
|
||||
)
|
||||
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Deadlines met, now getting diff for %r", self.getting_diff_for
|
||||
)
|
||||
|
||||
for entry in self.getting_diff_for:
|
||||
self.possible_gaps.pop(entry, None)
|
||||
|
||||
return deadline
|
||||
|
||||
def reset_deadlines(self, entries: Set[Entry], deadline: float) -> None:
|
||||
if not entries:
|
||||
return
|
||||
|
||||
for entry in entries:
|
||||
if entry not in self.map:
|
||||
raise RuntimeError(
|
||||
"Called reset_deadline on an entry for which we do not have state"
|
||||
)
|
||||
self.map[entry].deadline = deadline
|
||||
|
||||
if self.next_deadline in entries:
|
||||
self.next_deadline = min(
|
||||
self.map.items(), key=lambda entry_state: entry_state[1].deadline
|
||||
)[0]
|
||||
elif (
|
||||
self.next_deadline in self.map
|
||||
and deadline < self.map[self.next_deadline].deadline
|
||||
):
|
||||
self.next_deadline = entry
|
||||
|
||||
def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
|
||||
self.reset_deadlines(
|
||||
{channel_id},
|
||||
asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT),
|
||||
)
|
||||
|
||||
def set_state(self, state: abcs.updates.State) -> None:
|
||||
if __debug__:
|
||||
self._trace("Setting state %s", state)
|
||||
|
||||
deadline = next_updates_deadline()
|
||||
assert isinstance(state, types.updates.State)
|
||||
self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
|
||||
self.map[ENTRY_SECRET] = State(state.qts, deadline)
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.seq = state.seq
|
||||
|
||||
def try_set_channel_state(self, id: int, pts: int) -> None:
|
||||
if __debug__:
|
||||
self._trace("Trying to set channel state for %r: %r", id, pts)
|
||||
|
||||
if id not in self.map:
|
||||
self.map[id] = State(pts=pts, deadline=next_updates_deadline())
|
||||
|
||||
def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
|
||||
if entry not in self.map:
|
||||
if entry in self.possible_gaps:
|
||||
raise RuntimeError(
|
||||
"Should not have a possible_gap for an entry not in the state map"
|
||||
)
|
||||
return
|
||||
|
||||
if __debug__:
|
||||
self._trace("Marking %r as needing difference because %s", entry, reason)
|
||||
self.getting_diff_for.add(entry)
|
||||
self.possible_gaps.pop(entry, None)
|
||||
|
||||
def end_get_diff(self, entry: Entry) -> None:
|
||||
try:
|
||||
self.getting_diff_for.remove(entry)
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Called end_get_diff on an entry which was not getting diff for"
|
||||
)
|
||||
|
||||
self.reset_deadlines({entry}, next_updates_deadline())
|
||||
assert (
|
||||
entry not in self.possible_gaps
|
||||
), "gaps shouldn't be created while getting difference"
|
||||
|
||||
def ensure_known_peer_hashes(
|
||||
self,
|
||||
updates: abcs.Updates,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> None:
|
||||
if not chat_hashes.extend_from_updates(updates):
|
||||
can_recover = (
|
||||
not isinstance(updates, types.UpdateShort)
|
||||
or pts_info_from_update(updates.update) is not None
|
||||
)
|
||||
if can_recover:
|
||||
raise Gap
|
||||
|
||||
# https://core.telegram.org/api/updates
|
||||
def process_updates(
|
||||
self,
|
||||
updates: abcs.Updates,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
result: List[abcs.Update] = []
|
||||
combined = adapt(updates, chat_hashes)
|
||||
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Processing updates with seq = %r, seq_start = %r, date = %r: %s",
|
||||
combined.seq,
|
||||
combined.seq_start,
|
||||
combined.date,
|
||||
updates,
|
||||
)
|
||||
|
||||
if combined.seq_start != NO_SEQ:
|
||||
if self.seq + 1 > combined.seq_start:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Skipping updates as they should have already been handled"
|
||||
)
|
||||
return result, combined.users, combined.chats
|
||||
elif self.seq + 1 < combined.seq_start:
|
||||
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
|
||||
raise Gap
|
||||
|
||||
def update_sort_key(update: abcs.Update) -> int:
|
||||
pts = pts_info_from_update(update)
|
||||
return pts.pts - pts.pts_count if pts else 0
|
||||
|
||||
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
|
||||
|
||||
any_pts_applied = False
|
||||
reset_deadlines_for = set()
|
||||
for update in sorted_updates:
|
||||
entry, applied = self.apply_pts_info(update)
|
||||
if entry is not None:
|
||||
reset_deadlines_for.add(entry)
|
||||
if applied is not None:
|
||||
result.append(applied)
|
||||
any_pts_applied |= entry is not None
|
||||
|
||||
self.reset_deadlines(reset_deadlines_for, next_updates_deadline())
|
||||
|
||||
if any_pts_applied:
|
||||
if __debug__:
|
||||
self._trace("Updating seq as local pts was updated too")
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
combined.date, tz=datetime.timezone.utc
|
||||
)
|
||||
if combined.seq != NO_SEQ:
|
||||
self.seq = combined.seq
|
||||
|
||||
if self.possible_gaps:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Trying to re-apply %r possible gaps", len(self.possible_gaps)
|
||||
)
|
||||
|
||||
for key in list(self.possible_gaps.keys()):
|
||||
self.possible_gaps[key].updates.sort(key=update_sort_key)
|
||||
|
||||
for _ in range(len(self.possible_gaps[key].updates)):
|
||||
update = self.possible_gaps[key].updates.pop(0)
|
||||
_, applied = self.apply_pts_info(update)
|
||||
if applied is not None:
|
||||
result.append(applied)
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Resolved gap with %r: %s",
|
||||
pts_info_from_update(applied),
|
||||
applied,
|
||||
)
|
||||
|
||||
self.possible_gaps = {
|
||||
entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
|
||||
}
|
||||
|
||||
return result, combined.users, combined.chats
|
||||
|
||||
def apply_pts_info(
|
||||
self,
|
||||
update: abcs.Update,
|
||||
) -> Tuple[Optional[Entry], Optional[abcs.Update]]:
|
||||
if isinstance(update, types.UpdateChannelTooLong):
|
||||
self.try_begin_get_diff(update.channel_id, "received updateChannelTooLong")
|
||||
return None, None
|
||||
|
||||
pts = pts_info_from_update(update)
|
||||
if not pts:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"No pts in update, so it can be applied in any order: %s", update
|
||||
)
|
||||
return None, update
|
||||
|
||||
if pts.entry in self.getting_diff_for:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Skipping update with %r as its difference is being fetched", pts
|
||||
)
|
||||
return pts.entry, None
|
||||
|
||||
if state := self.map.get(pts.entry):
|
||||
local_pts = state.pts
|
||||
if local_pts + pts.pts_count > pts.pts:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Skipping update since local pts %r > %r: %s",
|
||||
local_pts,
|
||||
pts,
|
||||
update,
|
||||
)
|
||||
return pts.entry, None
|
||||
elif local_pts + pts.pts_count < pts.pts:
|
||||
# TODO store chats too?
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Possible gap since local pts %r < %r: %s",
|
||||
local_pts,
|
||||
pts,
|
||||
update,
|
||||
)
|
||||
if pts.entry not in self.possible_gaps:
|
||||
self.possible_gaps[pts.entry] = PossibleGap(
|
||||
deadline=asyncio.get_running_loop().time()
|
||||
+ POSSIBLE_GAP_TIMEOUT,
|
||||
updates=[],
|
||||
)
|
||||
|
||||
self.possible_gaps[pts.entry].updates.append(update)
|
||||
return pts.entry, None
|
||||
else:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"Applying update pts since local pts %r = %r: %s",
|
||||
local_pts,
|
||||
pts,
|
||||
update,
|
||||
)
|
||||
|
||||
if pts.entry not in self.map:
|
||||
self.map[pts.entry] = State(
|
||||
pts=0,
|
||||
deadline=next_updates_deadline(),
|
||||
)
|
||||
self.map[pts.entry].pts = pts.pts
|
||||
|
||||
return pts.entry, update
|
||||
|
||||
def get_difference(self) -> Optional[Request[abcs.updates.Difference]]:
|
||||
for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
|
||||
if entry in self.getting_diff_for:
|
||||
if entry not in self.map:
|
||||
raise RuntimeError(
|
||||
"Should not try to get difference for an entry without known state"
|
||||
)
|
||||
|
||||
gd = functions.updates.get_difference(
|
||||
pts=self.map[ENTRY_ACCOUNT].pts,
|
||||
pts_total_limit=None,
|
||||
date=int(self.date.timestamp()),
|
||||
qts=self.map[ENTRY_SECRET].pts
|
||||
if ENTRY_SECRET in self.map
|
||||
else NO_SEQ,
|
||||
)
|
||||
if __debug__:
|
||||
self._trace("Requesting account difference %s", gd)
|
||||
return gd
|
||||
|
||||
return None
|
||||
|
||||
def apply_difference(
|
||||
self,
|
||||
diff: abcs.updates.Difference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
if __debug__:
|
||||
self._trace("Applying account difference %s", diff)
|
||||
|
||||
finish: bool
|
||||
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
|
||||
if isinstance(diff, types.updates.DifferenceEmpty):
|
||||
finish = True
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
diff.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.seq = diff.seq
|
||||
result = [], [], []
|
||||
elif isinstance(diff, types.updates.Difference):
|
||||
chat_hashes.extend(diff.users, diff.chats)
|
||||
finish = True
|
||||
result = self.apply_difference_type(diff, chat_hashes)
|
||||
elif isinstance(diff, types.updates.DifferenceSlice):
|
||||
chat_hashes.extend(diff.users, diff.chats)
|
||||
finish = False
|
||||
result = self.apply_difference_type(
|
||||
types.updates.Difference(
|
||||
new_messages=diff.new_messages,
|
||||
new_encrypted_messages=diff.new_encrypted_messages,
|
||||
other_updates=diff.other_updates,
|
||||
chats=diff.chats,
|
||||
users=diff.users,
|
||||
state=diff.intermediate_state,
|
||||
),
|
||||
chat_hashes,
|
||||
)
|
||||
elif isinstance(diff, types.updates.DifferenceTooLong):
|
||||
finish = True
|
||||
self.map[ENTRY_ACCOUNT].pts = diff.pts
|
||||
result = [], [], []
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
if finish:
|
||||
account = ENTRY_ACCOUNT in self.getting_diff_for
|
||||
secret = ENTRY_SECRET in self.getting_diff_for
|
||||
|
||||
if not account and not secret:
|
||||
raise RuntimeError(
|
||||
"Should not be applying the difference when neither account or secret was diff was active"
|
||||
)
|
||||
|
||||
if account:
|
||||
self.end_get_diff(ENTRY_ACCOUNT)
|
||||
if secret:
|
||||
self.end_get_diff(ENTRY_SECRET)
|
||||
|
||||
return result
|
||||
|
||||
def apply_difference_type(
|
||||
self,
|
||||
diff: types.updates.Difference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
state = diff.state
|
||||
assert isinstance(state, types.updates.State)
|
||||
self.map[ENTRY_ACCOUNT].pts = state.pts
|
||||
self.map[ENTRY_SECRET].pts = state.qts
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.seq = state.seq
|
||||
|
||||
updates, users, chats = self.process_updates(
|
||||
types.Updates(
|
||||
updates=diff.other_updates,
|
||||
users=diff.users,
|
||||
chats=diff.chats,
|
||||
date=int(epoch().timestamp()),
|
||||
seq=NO_SEQ,
|
||||
),
|
||||
chat_hashes,
|
||||
)
|
||||
|
||||
updates.extend(
|
||||
types.UpdateNewMessage(
|
||||
message=m,
|
||||
pts=NO_PTS,
|
||||
pts_count=0,
|
||||
)
|
||||
for m in diff.new_messages
|
||||
)
|
||||
updates.extend(
|
||||
types.UpdateNewEncryptedMessage(
|
||||
message=m,
|
||||
qts=NO_PTS,
|
||||
)
|
||||
for m in diff.new_encrypted_messages
|
||||
)
|
||||
|
||||
return updates, users, chats
|
||||
|
||||
def get_channel_difference(
|
||||
self,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Optional[Request[abcs.updates.ChannelDifference]]:
|
||||
for entry in self.getting_diff_for:
|
||||
if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET):
|
||||
id = int(entry)
|
||||
break
|
||||
else:
|
||||
return None
|
||||
|
||||
packed = chat_hashes.get(id)
|
||||
if packed:
|
||||
assert packed.access_hash is not None
|
||||
channel = types.InputChannel(
|
||||
channel_id=packed.id,
|
||||
access_hash=packed.access_hash,
|
||||
)
|
||||
if state := self.map.get(entry):
|
||||
gd = functions.updates.get_channel_difference(
|
||||
force=False,
|
||||
channel=channel,
|
||||
filter=types.ChannelMessagesFilterEmpty(),
|
||||
pts=state.pts,
|
||||
limit=BOT_CHANNEL_DIFF_LIMIT
|
||||
if chat_hashes.is_self_bot
|
||||
else USER_CHANNEL_DIFF_LIMIT,
|
||||
)
|
||||
if __debug__:
|
||||
self._trace("Requesting channel difference %s", gd)
|
||||
return gd
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Should not try to get difference for an entry without known state"
|
||||
)
|
||||
else:
|
||||
self.end_get_diff(entry)
|
||||
self.map.pop(entry, None)
|
||||
return None
|
||||
|
||||
def apply_channel_difference(
|
||||
self,
|
||||
channel_id: int,
|
||||
diff: abcs.updates.ChannelDifference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
entry: Entry = channel_id
|
||||
if __debug__:
|
||||
self._trace("Applying channel difference for %r: %s", entry, diff)
|
||||
|
||||
self.possible_gaps.pop(entry, None)
|
||||
|
||||
if isinstance(diff, types.updates.ChannelDifferenceEmpty):
|
||||
assert diff.final
|
||||
self.end_get_diff(entry)
|
||||
self.map[entry].pts = diff.pts
|
||||
return [], [], []
|
||||
elif isinstance(diff, types.updates.ChannelDifferenceTooLong):
|
||||
chat_hashes.extend(diff.users, diff.chats)
|
||||
|
||||
assert diff.final
|
||||
if isinstance(diff.dialog, types.Dialog):
|
||||
assert diff.dialog.pts is not None
|
||||
self.map[entry].pts = diff.dialog.pts
|
||||
else:
|
||||
raise RuntimeError("unexpected type on ChannelDifferenceTooLong")
|
||||
self.reset_channel_deadline(channel_id, diff.timeout)
|
||||
return [], [], []
|
||||
elif isinstance(diff, types.updates.ChannelDifference):
|
||||
chat_hashes.extend(diff.users, diff.chats)
|
||||
|
||||
if diff.final:
|
||||
self.end_get_diff(entry)
|
||||
|
||||
self.map[entry].pts = diff.pts
|
||||
updates, users, chats = self.process_updates(
|
||||
types.Updates(
|
||||
updates=diff.other_updates,
|
||||
users=diff.users,
|
||||
chats=diff.chats,
|
||||
date=int(epoch().timestamp()),
|
||||
seq=NO_SEQ,
|
||||
),
|
||||
chat_hashes,
|
||||
)
|
||||
|
||||
updates.extend(
|
||||
types.UpdateNewChannelMessage(
|
||||
message=m,
|
||||
pts=NO_PTS,
|
||||
pts_count=0,
|
||||
)
|
||||
for m in diff.new_messages
|
||||
)
|
||||
self.reset_channel_deadline(channel_id, None)
|
||||
|
||||
return updates, users, chats
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
def end_channel_difference(
|
||||
self, channel_id: int, reason: PrematureEndReason
|
||||
) -> None:
|
||||
entry: Entry = channel_id
|
||||
if __debug__:
|
||||
self._trace("Ending channel difference for %r because %s", entry, reason)
|
||||
|
||||
if reason == PrematureEndReason.TEMPORARY_SERVER_ISSUES:
|
||||
self.possible_gaps.pop(entry, None)
|
||||
self.end_get_diff(entry)
|
||||
elif reason == PrematureEndReason.BANNED:
|
||||
self.possible_gaps.pop(entry, None)
|
||||
self.end_get_diff(entry)
|
||||
del self.map[entry]
|
||||
else:
|
||||
raise RuntimeError("Unknown reason to end channel difference")
|
Loading…
Reference in New Issue
Block a user