Begin reworking update handling

Use a fixed-size queue instead of a callback to deal with updates.

Port the message box and entity cache from grammers to start off
with a clean design.

Temporarily get rid of other cruft such as automatic pings or old
catch up implementation.
This commit is contained in:
Lonami Exo 2022-01-18 19:46:19 +01:00
parent 3afabdd7c0
commit f6df5d377c
7 changed files with 704 additions and 315 deletions

View File

@ -11,10 +11,11 @@ import dataclasses
from .. import version, __name__ as __base_name__, _tl from .. import version, __name__ as __base_name__, _tl
from .._crypto import rsa from .._crypto import rsa
from .._misc import markdown, statecache, enums, helpers from .._misc import markdown, enums, helpers
from .._network import MTProtoSender, Connection, transports from .._network import MTProtoSender, Connection, transports
from .._sessions import Session, SQLiteSession, MemorySession from .._sessions import Session, SQLiteSession, MemorySession
from .._sessions.types import DataCenter, SessionState from .._sessions.types import DataCenter, SessionState
from .._updates import EntityCache, MessageBox
DEFAULT_DC_ID = 2 DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -91,6 +92,7 @@ def init(
flood_sleep_threshold: int = 60, flood_sleep_threshold: int = 60,
# Update handling. # Update handling.
receive_updates: bool = True, receive_updates: bool = True,
max_queued_updates: int = 100,
): ):
# Logging. # Logging.
if isinstance(base_logger, str): if isinstance(base_logger, str):
@ -139,6 +141,13 @@ def init(
self._flood_waited_requests = {} # prevent calls that would floodwait entirely self._flood_waited_requests = {} # prevent calls that would floodwait entirely
self._parse_mode = markdown self._parse_mode = markdown
# Update handling.
self._no_updates = not receive_updates
self._updates_queue = asyncio.Queue(maxsize=max_queued_updates)
self._updates_handle = None
self._message_box = MessageBox()
self._entity_cache = EntityCache() # required for proper update handling (to know when to getDifference)
# Connection parameters. # Connection parameters.
if not api_id or not api_hash: if not api_id or not api_hash:
raise ValueError( raise ValueError(
@ -189,16 +198,13 @@ def init(
delay=self._connect_retry_delay, delay=self._connect_retry_delay,
auto_reconnect=self._auto_reconnect, auto_reconnect=self._auto_reconnect,
connect_timeout=self._connect_timeout, connect_timeout=self._connect_timeout,
update_callback=self._handle_update, updates_queue=self._updates_queue,
) )
# Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders. # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders.
self._borrowed_senders = {} self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock() self._borrow_sender_lock = asyncio.Lock()
# Update handling.
self._no_updates = not receive_updates
def get_flood_sleep_threshold(self): def get_flood_sleep_threshold(self):
return self._flood_sleep_threshold return self._flood_sleep_threshold
@ -337,15 +343,6 @@ async def _disconnect_coro(self: 'TelegramClient'):
# If any was borrowed # If any was borrowed
self._borrowed_senders.clear() self._borrowed_senders.clear()
# trio's nurseries would handle this for us, but this is asyncio.
# All tasks spawned in the background should properly be terminated.
if self._dispatching_updates_queue is None and self._updates_queue:
for task in self._updates_queue:
task.cancel()
await asyncio.wait(self._updates_queue)
self._updates_queue.clear()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -355,8 +352,11 @@ async def _disconnect(self: 'TelegramClient'):
their job with the client is complete and we should clean it up all. their job with the client is complete and we should clean it up all.
""" """
await self._sender.disconnect() await self._sender.disconnect()
await helpers._cancel(self._log[__name__], await helpers._cancel(self._log[__name__], updates_handle=self._updates_handle)
updates_handle=self._updates_handle) try:
await self._updates_handle
except asyncio.CancelledError:
pass
async def _switch_dc(self: 'TelegramClient', new_dc): async def _switch_dc(self: 'TelegramClient', new_dc):
""" """

View File

@ -2665,6 +2665,7 @@ class TelegramClient:
flood_sleep_threshold: int = 60, flood_sleep_threshold: int = 60,
# Update handling. # Update handling.
receive_updates: bool = True, receive_updates: bool = True,
max_queued_updates: int = 100,
): ):
telegrambaseclient.init(**locals()) telegrambaseclient.init(**locals())
@ -3509,10 +3510,6 @@ class TelegramClient:
async def _clean_exported_senders(self: 'TelegramClient'): async def _clean_exported_senders(self: 'TelegramClient'):
pass pass
@forward_call(updates._handle_update)
def _handle_update(self: 'TelegramClient', update):
pass
@forward_call(auth._update_session_state) @forward_call(auth._update_session_state)
async def _update_session_state(self, user, *, save=True): async def _update_session_state(self, user, *, save=True):
pass pass

View File

@ -79,295 +79,9 @@ def list_event_handlers(self: 'TelegramClient')\
return [(callback, event) for event, callback in self._event_builders] return [(callback, event) for event, callback in self._event_builders]
async def catch_up(self: 'TelegramClient'): async def catch_up(self: 'TelegramClient'):
return
self._catching_up = True
try:
while True:
d = await self(_tl.fn.updates.GetDifference(
pts, date, 0
))
if isinstance(d, (_tl.updates.DifferenceSlice,
_tl.updates.Difference)):
if isinstance(d, _tl.updates.Difference):
state = d.state
else:
state = d.intermediate_state
pts, date = state.pts, state.date
_handle_update(self, _tl.Updates(
users=d.users,
chats=d.chats,
date=state.date,
seq=state.seq,
updates=d.other_updates + [
_tl.UpdateNewMessage(m, 0, 0)
for m in d.new_messages
]
))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about.
#
# We may still get duplicates because the Difference
# contains a lot of updates and presumably only has
# the state for the last one, but at least we don't
# unnecessarily fetch too many.
#
# updates.getDifference's pts_total_limit seems to mean
# "how many pts is the request allowed to return", and
# if there is more than that, it returns "too long" (so
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
else:
if isinstance(d, _tl.updates.DifferenceEmpty):
date = d.date
elif isinstance(d, _tl.updates.DifferenceTooLong):
pts = d.pts
break
except (ConnectionError, asyncio.CancelledError):
pass pass
finally:
self._catching_up = False
# It is important to not make _handle_update async because we rely on
# the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async.
def _handle_update(self: 'TelegramClient', update):
if isinstance(update, (_tl.Updates, _tl.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
_process_update(self, u, entities, update.updates)
elif isinstance(update, _tl.UpdateShort):
_process_update(self, update.update, {}, None)
else:
_process_update(self, update, {}, None)
def _process_update(self: 'TelegramClient', update, entities, others):
# This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass
# arguments which is faster.
args = (update, entities, others, channel_id, None)
if self._dispatching_updates_queue is None:
task = asyncio.create_task(_dispatch_update(self, *args))
self._updates_queue.add(task)
task.add_done_callback(lambda _: self._updates_queue.discard(task))
else:
self._updates_queue.put_nowait(args)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
asyncio.create_task(_dispatch_queue_updates(self))
async def _update_loop(self: 'TelegramClient'): async def _update_loop(self: 'TelegramClient'):
# Pings' ID don't really need to be secure, just "random"
rnd = lambda: random.randrange(-2**63, 2**63)
while self.is_connected(): while self.is_connected():
try: updates = await self._updates_queue.get()
await asyncio.wait_for(self.run_until_disconnected(), timeout=60) updates, users, chats = self._message_box.process_updates(updates, self._entity_cache)
continue # We actually just want to act upon timeout
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
return
except Exception as e:
# Any disconnected exception should be ignored (or it may hint at
# another problem, leading to an infinite loop, hence the logging call)
self._log[__name__].info('Exception waiting on a disconnect: %s', e)
continue
# Check if we have any exported senders to clean-up periodically
await self._clean_exported_senders()
# Don't bother sending pings until the low-level connection is
# ready, otherwise a lot of pings will be batched to be sent upon
# reconnect, when we really don't care about that.
if not self._sender._transport_connected():
continue
# We also don't really care about their result.
# Just send them periodically.
try:
self._sender._keepalive_ping(rnd())
except (ConnectionError, asyncio.CancelledError):
return
# Entities are not saved when they are inserted because this is a rather expensive
# operation (default's sqlite3 takes ~0.1s to commit changes). Do it every minute
# instead. No-op if there's nothing new.
await self._session.save()
# We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will
# just stop even if we're connected. Do so every 30 minutes.
#
# TODO Call getDifference instead since it's more relevant
if time.time() - self._last_request > 30 * 60:
if not await self.is_user_authorized():
# What can be the user doing for so
# long without being logged in...?
continue
try:
await self(_tl.fn.updates.GetState())
except (ConnectionError, asyncio.CancelledError):
return
async def _dispatch_queue_updates(self: 'TelegramClient'):
while not self._updates_queue.empty():
await _dispatch_update(self, *self._updates_queue.get_nowait())
self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, entities, others, channel_id, pts_date):
built = EventBuilderDict(self, update, entities, others)
for builder, callback in self._event_builders:
event = built[type(builder)]
if not event:
continue
if not builder.resolved:
await builder.resolve(self)
filter = builder.filter(event)
if inspect.isawaitable(filter):
filter = await filter
if not filter:
continue
try:
await callback(event)
except StopPropagation:
name = getattr(callback, '__name__', repr(callback))
self._log[__name__].debug(
'Event handler "%s" stopped chain of propagation '
'for event %s.', name, type(event).__name__
)
break
except Exception as e:
if not isinstance(e, asyncio.CancelledError) or self.is_connected():
name = getattr(callback, '__name__', repr(callback))
self._log[__name__].exception('Unhandled exception on %s', name)
async def _dispatch_event(self: 'TelegramClient', event):
"""
Dispatches a single, out-of-order event. Used by `AlbumHack`.
"""
# We're duplicating a most logic from `_dispatch_update`, but all in
# the name of speed; we don't want to make it worse for all updates
# just because albums may need it.
for builder, callback in self._event_builders:
if isinstance(builder, Raw):
continue
if not isinstance(event, builder.Event):
continue
if not builder.resolved:
await builder.resolve(self)
filter = builder.filter(event)
if inspect.isawaitable(filter):
filter = await filter
if not filter:
continue
try:
await callback(event)
except StopPropagation:
name = getattr(callback, '__name__', repr(callback))
self._log[__name__].debug(
'Event handler "%s" stopped chain of propagation '
'for event %s.', name, type(event).__name__
)
break
except Exception as e:
if not isinstance(e, asyncio.CancelledError) or self.is_connected():
name = getattr(callback, '__name__', repr(callback))
self._log[__name__].exception('Unhandled exception on %s', name)
async def _get_difference(self: 'TelegramClient', update, entities, channel_id, pts_date):
"""
Get the difference for this `channel_id` if any, then load entities.
Calls :tl:`updates.getDifference`, which fills the entities cache
(always done by `__call__`) and lets us know about the full entities.
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full, including its entities.
self._log[__name__].debug('Getting difference for entities '
'for %r', update.__class__)
if channel_id:
# There are reports where we somehow call get channel difference
# with `InputPeerEmpty`. Check our assumptions to better debug
# this when it happens.
assert isinstance(channel_id, int), 'channel_id was {}, not int in {}'.format(type(channel_id), update)
try:
# Wrap the ID inside a peer to ensure we get a channel back.
where = await self.get_input_entity(_tl.PeerChannel(channel_id))
except ValueError:
# There's a high chance that this fails, since
# we are getting the difference to fetch entities.
return
if not pts_date:
# First-time, can't get difference. Get pts instead.
result = await self(_tl.fn.channels.GetFullChannel(
utils.get_input_channel(where)
))
return
result = await self(_tl.fn.updates.GetChannelDifference(
channel=where,
filter=_tl.ChannelMessagesFilterEmpty(),
pts=pts_date, # just pts
limit=100,
force=True
))
else:
if not pts_date[0]:
# First-time, can't get difference. Get pts instead.
result = await self(_tl.fn.updates.GetState())
return
result = await self(_tl.fn.updates.GetDifference(
pts=pts_date[0],
date=pts_date[1],
qts=0
))
if isinstance(result, (_tl.updates.Difference,
_tl.updates.DifferenceSlice,
_tl.updates.ChannelDifference,
_tl.updates.ChannelDifferenceTooLong)):
entities.update({
utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats)
})
class EventBuilderDict:
"""
Helper "dictionary" to return events from types and cache them.
"""
def __init__(self, client: 'TelegramClient', update, entities, others):
self.client = client
self.update = update
self.entities = entities
self.others = others
def __getitem__(self, builder):
try:
return self.__dict__[builder]
except KeyError:
event = self.__dict__[builder] = builder.build(
self.update, self.others, self.client._session_state.user_id, self.entities or {}, self.client)
if isinstance(event, EventCommon):
# TODO eww
event.original_update = self.update
event._entities = self.entities or {}
event._set_client(self.client)
return event

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import collections import collections
import struct import struct
import logging
from . import authenticator from . import authenticator
from .._misc.messagepacker import MessagePacker from .._misc.messagepacker import MessagePacker
@ -20,6 +21,9 @@ from .._misc import helpers, utils
from .. import _tl from .. import _tl
UPDATE_BUFFER_FULL_WARN_DELAY = 15 * 60
class MTProtoSender: class MTProtoSender:
""" """
MTProto Mobile Protocol sender MTProto Mobile Protocol sender
@ -35,9 +39,8 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other A new authorization key will be generated on connection if no other
key exists yet. key exists yet.
""" """
def __init__(self, *, loggers, def __init__(self, *, loggers, updates_queue,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None, retries=5, delay=1, auto_reconnect=True, connect_timeout=None,):
update_callback=None):
self._connection = None self._connection = None
self._loggers = loggers self._loggers = loggers
self._log = loggers[__name__] self._log = loggers[__name__]
@ -45,7 +48,7 @@ class MTProtoSender:
self._delay = delay self._delay = delay
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
self._update_callback = update_callback self._updates_queue = updates_queue
self._connect_lock = asyncio.Lock() self._connect_lock = asyncio.Lock()
self._ping = None self._ping = None
@ -83,6 +86,9 @@ class MTProtoSender:
# is received, but we may still need to resend their state on bad salts. # is received, but we may still need to resend their state on bad salts.
self._last_acks = collections.deque(maxlen=10) self._last_acks = collections.deque(maxlen=10)
# Last time we warned about the update buffer being full
self._last_update_warn = -UPDATE_BUFFER_FULL_WARN_DELAY
# Jump table from response ID to method that handles it # Jump table from response ID to method that handles it
self._handlers = { self._handlers = {
RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result, RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result,
@ -629,8 +635,16 @@ class MTProtoSender:
return return
self._log.debug('Handling update %s', message.obj.__class__.__name__) self._log.debug('Handling update %s', message.obj.__class__.__name__)
if self._update_callback: try:
self._update_callback(message.obj) self._updates_queue.put_nowait(message.obj)
except asyncio.QueueFull:
now = asyncio.get_running_loop().time()
if now - self._last_update_warn >= UPDATE_BUFFER_FULL_WARN_DELAY:
self._log.warning(
'Cannot dispatch update because the buffer capacity of %d was reached',
self._updates_queue.maxsize
)
self._last_update_warn = now
async def _handle_pong(self, message): async def _handle_pong(self, message):
""" """

View File

@ -0,0 +1,2 @@
from .entitycache import EntityCache, PackedChat
from .messagebox import MessageBox

View File

@ -0,0 +1,97 @@
import inspect
import itertools
from dataclasses import dataclass, field
from collections import namedtuple
from .._misc import utils
from .. import _tl
from .._sessions.types import EntityType, Entity
class PackedChat(namedtuple('PackedChat', 'ty id hash')):
__slots__ = ()
@property
def is_user(self):
return self.ty in (EntityType.USER, EntityType.BOT)
@property
def is_chat(self):
return self.ty in (EntityType.GROUP,)
@property
def is_channel(self):
return self.ty in (EntityType.CHANNEL, EntityType.MEGAGROUP, EntityType.GIGAGROUP)
def to_peer(self):
if self.is_user:
return _tl.PeerUser(user_id=self.id)
elif self.is_chat:
return _tl.PeerChat(chat_id=self.id)
elif self.is_channel:
return _tl.PeerChannel(channel_id=self.id)
def to_input_peer(self):
if self.is_user:
return _tl.InputPeerUser(user_id=self.id, access_hash=self.hash)
elif self.is_chat:
return _tl.InputPeerChat(chat_id=self.id)
elif self.is_channel:
return _tl.InputPeerChannel(channel_id=self.id, access_hash=self.hash)
def try_to_input_user(self):
if self.is_user:
return _tl.InputUser(user_id=self.id, access_hash=self.hash)
else:
return None
def try_to_chat_id(self):
if self.is_chat:
return self.id
else:
return None
def try_to_input_channel(self):
if self.is_channel:
return _tl.InputChannel(channel_id=self.id, access_hash=self.hash)
else:
return None
def __str__(self):
return f'{chr(self.ty.value)}.{self.id}.{self.hash}'
@dataclass
class EntityCache:
hash_map: dict = field(default_factory=dict) # id -> (hash, ty)
self_id: int = None
self_bot: bool = False
def set_self_user(self, id, bot):
self.self_id = id
self.self_bot = bot
def get(self, id):
value = self.hash_map.get(id)
return PackedChat(ty=value[1], id=id, hash=value[0]) if value else None
def extend(self, users, chats):
# See https://core.telegram.org/api/min for "issues" with "min constructors".
self.hash_map.update(
(u.id, (
u.access_hash,
EntityType.BOT if u.bot else EntityType.USER,
))
for u in users
if getattr(u, 'access_hash', None) and not u.min
)
self.hash_map.update(
(c.id, (
c.access_hash,
EntityType.MEGAGROUP if c.megagroup else (
EntityType.GIGAGROUP if getattr(c, 'gigagroup', None) else EntityType.CHANNEL
),
))
for c in chats
if getattr(c, 'access_hash', None) and not getattr(c, 'min', None)
)

View File

@ -0,0 +1,565 @@
"""
This module deals with correct handling of updates, including gaps, and knowing when the code
should "get difference" (the set of updates that the client should know by now minus the set
of updates that it actually knows).
Each chat has its own [`Entry`] in the [`MessageBox`] (this `struct` is the "entry point").
At any given time, the message box may be either getting difference for them (entry is in
[`MessageBox::getting_diff_for`]) or not. If not getting difference, a possible gap may be
found for the updates (entry is in [`MessageBox::possible_gaps`]). Otherwise, the entry is
on its happy path.
Gaps are cleared when they are either resolved on their own (by waiting for a short time)
or because we got the difference for the corresponding entry.
While there are entries for which their difference must be fetched,
[`MessageBox::check_deadlines`] will always return [`Instant::now`], since "now" is the time
to get the difference.
"""
import asyncio
from dataclasses import dataclass, field
from .._sessions.types import SessionState, ChannelState
# Telegram sends `seq` equal to `0` when "it doesn't matter", so we use that value too.
NO_SEQ = 0
# See https://core.telegram.org/method/updates.getChannelDifference.
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
# > It may be useful to wait up to 0.5 seconds
POSSIBLE_GAP_TIMEOUT = 0.5
# After how long without updates the client will "timeout".
#
# When this timeout occurs, the client will attempt to fetch updates by itself, ignoring all the
# updates that arrive in the meantime. After all updates are fetched when this happens, the
# client will resume normal operation, and the timeout will reset.
#
# Documentation recommends 15 minutes without updates (https://core.telegram.org/api/updates).
NO_UPDATES_TIMEOUT = 15 * 60
# Entry "enum".
# Account-wide `pts` includes private conversations (one-to-one) and small group chats.
ENTRY_ACCOUNT = object()
# Account-wide `qts` includes only "secret" one-to-one chats.
ENTRY_SECRET = object()
# Integers will be Channel-specific `pts`, and includes "megagroup", "broadcast" and "supergroup" channels.
def next_updates_deadline():
return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT
class GapError(ValueError):
pass
# Represents the information needed to correctly handle a specific `tl::enums::Update`.
@dataclass
class PtsInfo:
pts: int
pts_count: int
entry: object
@classmethod
def from_update(cls, update):
pts = getattr(update, 'pts', None)
if pts:
pts_count = getattr(update, 'pts_count', None) or 0
entry = getattr(update, 'channel_id', None) or ENTRY_ACCOUNT
return cls(pts=pts, pts_count=pts_count, entry=entry)
qts = getattr(update, 'qts', None)
if qts:
pts_count = 1 if isinstance(update, _tl.UpdateNewEncryptedMessage) else 0
return cls(pts=qts, pts_count=pts_count, entry=ENTRY_SECRET)
return None
# The state of a particular entry in the message box.
@dataclass
class State:
# Current local persistent timestamp.
pts: int
# Next instant when we would get the update difference if no updates arrived before then.
deadline: float
# > ### Recovering gaps
# > […] Manually obtaining updates is also required in the following situations:
# > • Loss of sync: a gap was found in `seq` / `pts` / `qts` (as described above).
# > It may be useful to wait up to 0.5 seconds in this situation and abort the sync in case a new update
# > arrives, that fills the gap.
#
# This is really easy to trigger by spamming messages in a channel (with as little as 3 members works), because
# the updates produced by the RPC request take a while to arrive (whereas the read update comes faster alone).
@dataclass
class PossibleGap:
deadline: float
# Pending updates (those with a larger PTS, producing the gap which may later be filled).
updates: list # of updates
# Represents a "message box" (event `pts` for a specific entry).
#
# See https://core.telegram.org/api/updates#message-related-event-sequences.
@dataclass
class MessageBox:
# Map each entry to their current state.
map: dict = field(default_factory=dict) # entry -> state
# Additional fields beyond PTS needed by `ENTRY_ACCOUNT`.
date: int = 1
seq: int = 0
# Holds the entry with the closest deadline (optimization to avoid recalculating the minimum deadline).
next_deadline: object = None # entry
# Which entries have a gap and may soon trigger a need to get difference.
#
# If a gap is found, stores the required information to resolve it (when should it timeout and what updates
# should be held in case the gap is resolved on its own).
#
# Not stored directly in `map` as an optimization (else we would need another way of knowing which entries have
# a gap in them).
possible_gaps: dict = field(default_factory=dict) # entry -> possiblegap
# For which entries are we currently getting difference.
getting_diff_for: set = field(default_factory=set) # entry
# Temporarily stores which entries should have their update deadline reset.
# Stored in the message box in order to reuse the allocation.
reset_deadlines_for: set = field(default_factory=set) # entry
# region Creation, querying, and setting base state.
@classmethod
def load(cls, session_state, channel_states):
"""
Create a [`MessageBox`] from a previously known update state.
"""
deadline = next_updates_deadline()
return cls(
map={
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
**{s.channel_id: s.pts for s in channel_states}
},
date=session_state.date,
seq=session_state.seq,
next_deadline=ENTRY_ACCOUNT,
)
@classmethod
def session_state(self):
"""
Return the current state in a format that sessions understand.
This should be used for persisting the state.
"""
return SessionState(
user_id=0,
dc_id=0,
bot=False,
pts=self.map.get(ENTRY_ACCOUNT, 0),
qts=self.map.get(ENTRY_SECRET, 0),
date=self.date,
seq=self.seq,
takeout_id=None,
), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)]
def is_empty(self) -> bool:
"""
Return true if the message box is empty and has no state yet.
"""
return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ
def check_deadlines(self):
"""
Return the next deadline when receiving updates should timeout.
If a deadline expired, the corresponding entries will be marked as needing to get its difference.
While there are entries pending of getting their difference, this method returns the current instant.
"""
now = asyncio.get_running_loop().time()
if self.getting_diff_for:
return now
deadline = next_updates_deadline()
# Most of the time there will be zero or one gap in flight so finding the minimum is cheap.
if self.possible_gaps:
deadline = min(deadline, *self.possible_gaps.values())
elif self.next_deadline in self.map:
deadline = min(deadline, self.map[self.next_deadline])
if now > deadline:
# Check all expired entries and add them to the list that needs getting difference.
self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now > gap.deadline)
self.getting_diff_for.update(entry for entry, state in self.map.items() if now > state.deadline)
# When extending `getting_diff_for`, it's important to have the moral equivalent of
# `begin_get_diff` (that is, clear possible gaps if we're now getting difference).
for entry in self.getting_diff_for:
self.possible_gaps.pop(entry, None)
return deadline
# Reset the deadline for the periods without updates for a given entry.
#
# It also updates the next deadline time to reflect the new closest deadline.
def reset_deadline(self, entry, deadline):
if entry in self.map:
self.map[entry].deadline = deadline
# TODO figure out why not in map may happen
if self.next_deadline == entry:
# If the updated deadline was the closest one, recalculate the new minimum.
self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0]
elif deadline < self.map.get(self.next_deadline, 0):
# If the updated deadline is smaller than the next deadline, change the next deadline to be the new one.
self.next_deadline = entry
# else an unrelated deadline was updated, so the closest one remains unchanged.
# Convenience to reset a channel's deadline, with optional timeout.
def reset_channel_deadline(self, channel_id, timeout):
self.reset_deadlines(channel_id, asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT))
# Reset all the deadlines in `reset_deadlines_for` and then empty the set.
def apply_deadlines_reset(self):
next_deadline = next_updates_deadline()
reset_deadlines_for = self.reset_deadlines_for
self.reset_deadlines_for = set() # "move" the set to avoid self.reset_deadline() from touching it during iter
for entry in reset_deadlines_for:
self.reset_deadline(entry, next_deadline)
reset_deadlines_for.clear() # reuse allocation, the other empty set was a temporary dummy value
self.reset_deadlines_for = reset_deadlines_for
# Sets the update state.
#
# Should be called right after login if [`MessageBox::new`] was used, otherwise undesirable
# updates will be fetched.
def set_state(self, state):
deadline = next_updates_deadline()
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
self.date = state.date
self.seq = state.seq
# Like [`MessageBox::set_state`], but for channels. Useful when getting dialogs.
#
# The update state will only be updated if no entry was known previously.
def try_set_channel_state(self, id, pts):
if id not in self.map:
self.map[id] = State(pts=pts, deadline=next_updates_deadline())
# Begin getting difference for the given entry.
#
# Clears any previous gaps.
def begin_get_diff(self, entry):
self.getting_diff_for.add(entry)
self.possible_gaps.pop(entry, None)
# Finish getting difference for the given entry.
#
# It also resets the deadline.
def end_get_diff(self, entry):
self.getting_diff_for.pop(entry, None)
self.reset_deadline(entry, next_updates_deadline())
assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference"
# endregion Creation, querying, and setting base state.
# region "Normal" updates flow (processing and detection of gaps).
# Process an update and return what should be done with it.
#
# Updates corresponding to entries for which their difference is currently being fetched
# will be ignored. While according to the [updates' documentation]:
#
# > Implementations [have] to postpone updates received via the socket while
# > filling gaps in the event and `Update` sequences, as well as avoid filling
# > gaps in the same sequence.
#
# In practice, these updates should have also been retrieved through getting difference.
#
# [updates documentation] https://core.telegram.org/api/updates
def process_updates(
self,
updates,
chat_hashes,
result, # out list of updates; returns list of user, chat, or raise if gap
):
# XXX adapt updates and chat hashes into updatescombined, raise gap on too long
date = updates.date
seq_start = updates.seq_start
seq = updates.seq
updates = updates.updates
users = updates.users
chats = updates.chats
# > For all the other [not `updates` or `updatesCombined`] `Updates` type constructors
# > there is no need to check `seq` or change a local state.
if updates.seq_start != NO_SEQ:
if self.seq + 1 > updates.seq_start:
# Skipping updates that were already handled
return (updates.users, updates.chats)
elif self.seq + 1 < updates.seq_start:
# Gap detected
self.begin_get_diff(ENTRY_ACCOUNT)
raise GapError
# else apply
self.date = updates.date
if updates.seq != NO_SEQ:
self.seq = updates.seq
result.extend(filter(None, (self.apply_pts_info(u, reset_deadline=True) for u in updates.updates)))
self.apply_deadlines_reset()
def _sort_gaps(update):
pts = PtsInfo.from_update(u)
return pts.pts - pts.pts_count if pts else 0
if self.possible_gaps:
# For each update in possible gaps, see if the gap has been resolved already.
for key in list(self.possible_gaps.keys()):
self.possible_gaps[key].updates.sort(key=_sort_gaps)
for _ in range(len(self.possible_gaps[key].updates)):
update = self.possible_gaps[key].updates.pop(0)
# If this fails to apply, it will get re-inserted at the end.
# All should fail, so the order will be preserved (it would've cycled once).
update = self.apply_pts_info(update, reset_deadline=False)
if update:
result.append(update)
# Clear now-empty gaps.
self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps if gap.updates}
return (updates.users, updates.chats)
# Tries to apply the input update if its `PtsInfo` follows the correct order.
#
# If the update can be applied, it is returned; otherwise, the update is stored in a
# possible gap (unless it was already handled or would be handled through getting
# difference) and `None` is returned.
def apply_pts_info(
self,
update,
*,
reset_deadline,
):
pts = PtsInfo.from_update(update)
if not pts:
# No pts means that the update can be applied in any order.
return update
# As soon as we receive an update of any form related to messages (has `PtsInfo`),
# the "no updates" period for that entry is reset.
#
# Build the `HashSet` to avoid calling `reset_deadline` more than once for the same entry.
if reset_deadline:
self.reset_deadlines_for.insert(pts.entry)
if pts.entry in self.getting_diff_for:
# Note: early returning here also prevents gap from being inserted (which they should
# not be while getting difference).
return None
if pts.entry in self.map:
local_pts = self.map[pts.entry].pts
if local_pts + pts.pts_count > pts.pts:
# Ignore
return None
elif local_pts + pts.pts_count < pts.pts:
# Possible gap
# TODO store chats too?
if pts.entry not in self.possible_gaps:
self.possible_gaps[pts.entry] = PossibleGap(
deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT,
updates=[]
)
self.possible_gaps[pts.entry].updates.append(update)
return None
else:
# Apply
pass
else:
# No previous `pts` known, and because this update has to be "right" (it's the first one) our
# `local_pts` must be one less.
local_pts = pts.pts - 1
# For example, when we're in a channel, we immediately receive:
# * ReadChannelInbox (pts = X)
# * NewChannelMessage (pts = X, pts_count = 1)
#
# Notice how both `pts` are the same. If we stored the one from the first, then the second one would
# be considered "already handled" and ignored, which is not desirable. Instead, advance local `pts`
# by `pts_count` (which is 0 for updates not directly related to messages, like reading inbox).
if pts.entry in self.map:
self.map[pts.entry].pts = local_pts + pts.pts_count
else:
self.map[pts.entry] = State(pts=local_pts + pts.pts_count, deadline=next_updates_deadline())
return update
# endregion "Normal" updates flow (processing and detection of gaps).
# region Getting and applying account difference.
# Return the request that needs to be made to get the difference, if any.
def get_difference(self):
entry = ENTRY_ACCOUNT
if entry in self.getting_diff_for:
if entry in self.map:
return _tl.fn.updates.GetDifference(
pts=state.pts,
pts_total_limit=None,
date=self.date,
qts=self.map[ENTRY_SECRET].pts,
)
else:
# TODO investigate when/why/if this can happen
self.end_get_diff(entry)
return None
# Similar to [`MessageBox::process_updates`], but using the result from getting difference.
def apply_difference(
self,
diff,
chat_hashes,
):
if isinstance(diff, _tl.updates.DifferenceEmpty):
self.date = diff.date
self.seq = diff.seq
self.end_get_diff(ENTRY_ACCOUNT)
return [], [], []
elif isinstance(diff, _tl.updates.Difference):
self.end_get_diff(ENTRY_ACCOUNT)
chat_hashes.extend(diff.users, diff.chats)
return self.apply_difference_type(diff)
elif isinstance(diff, _tl.updates.DifferenceSlice):
chat_hashes.extend(diff.users, diff.chats)
return self.apply_difference_type(diff)
elif isinstance(diff, _tl.updates.DifferenceTooLong):
# TODO when are deadlines reset if we update the map??
self.map[ENTRY_ACCOUNT].pts = diff.pts
self.end_get_diff(ENTRY_ACCOUNT)
return [], [], []
def apply_difference_type(
self,
diff,
):
state = getattr(diff, 'intermediate_state', None) or diff.state
self.map[ENTRY_ACCOUNT].pts = state.pts
self.map[ENTRY_SECRET].pts = state.qts
self.date = state.date
self.seq = state.seq
for u in diff.updates:
if isinstance(u, _tl.UpdateChannelTooLong):
self.begin_get_diff(u.channel_id)
updates.extend(_tl.UpdateNewMessage(
message=m,
pts=NO_SEQ,
pts_count=NO_SEQ,
) for m in diff.new_messages)
updates.extend(_tl.UpdateNewEncryptedMessage(
message=m,
qts=NO_SEQ,
) for m in diff.new_encrypted_messages)
return diff.updates, diff.users, diff.chats
# endregion Getting and applying account difference.
# region Getting and applying channel difference.
# Return the request that needs to be made to get a channel's difference, if any.
def get_channel_difference(
self,
chat_hashes,
):
entry = next((id for id in self.getting_diff_for if isinstance(id, int)), None)
if not entry:
return None
packed = chat_hashes.get(entry)
if not packed:
# Cannot get channel difference as we're missing its hash
self.end_get_diff(entry)
# Remove the outdated `pts` entry from the map so that the next update can correct
# it. Otherwise, it will spam that the access hash is missing.
self.map.pop(entry, None)
return None
state = self.map.get(entry)
if not state:
# TODO investigate when/why/if this can happen
# Cannot get channel difference as we're missing its pts
self.end_get_diff(entry)
return None
return _tl.fn.updates.GetChannelDifference(
force=False,
channel=channel,
filter=_tl.ChannelMessagesFilterEmpty(),
pts=state.pts,
limit=BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot() else USER_CHANNEL_DIFF_LIMIT
)
# Similar to [`MessageBox::process_updates`], but using the result from getting difference.
def apply_channel_difference(
self,
request,
diff,
chat_hashes,
):
entry = request.channel.channel_id
self.possible_gaps.remove(entry)
if isinstance(diff, _tl.updates.ChannelDifferenceEmpty):
assert diff.final
self.end_get_diff(entry)
self.map[entry].pts = diff.pts
return [], [], []
elif isinstance(diff, _tl.updates.ChannelDifferenceTooLong):
assert diff.final
self.map[entry].pts = diff.dialog.pts
chat_hashes.extend(diff.users, diff.chats)
self.reset_channel_deadline(channel_id, diff.timeout)
# This `diff` has the "latest messages and corresponding chats", but it would
# be strange to give the user only partial changes of these when they would
# expect all updates to be fetched. Instead, nothing is returned.
return [], [], []
elif isinstance(diff, _tl.updates.ChannelDifference):
if diff.final:
self.end_get_diff(entry)
self.map[entry].pts = pts
updates.extend(_tl.UpdateNewMessage(
message=m,
pts=NO_SEQ,
pts_count=NO_SEQ,
) for m in diff.new_messages)
chat_hashes.extend(diff.users, diff.chats);
self.reset_channel_deadline(channel_id, timeout)
(diff.updates, diff.users, diff.chats)
# endregion Getting and applying channel difference.