Fix catch_up

updates.py:
- Fix catch_up getting stuck in infinite loop by saving qts
- Save new state to session after catch_up
- Save initial state immediately if catch_up was ran first time
- If DifferenceTooLong was received, continue fetching with new pts instead of exiting
- Catch up channels
- Implement pts limit (pts_total_limit in chats and limit in channels)
- Check if update was already processed before processing

telegrambaseclient.py, sessions:
- Get saved state from channels on startup and save channels state

statecache.py:
- Create separate variable _store in StateCache, because __dict__ also stores _logger variable
- Move has_pts, has_qts, has_date, has_channel_pts outside of function arguments

Fixes #3041
This commit is contained in:
vanutp 2021-05-27 11:31:55 +03:00
parent 3d350c6087
commit 34bb2b8fc3
No known key found for this signature in database
GPG Key ID: 1E36ED1DFB6FDE78
6 changed files with 198 additions and 86 deletions

View File

@ -6,6 +6,7 @@ import logging
import platform import platform
import time import time
import typing import typing
from datetime import datetime
from .. import version, helpers, __name__ as __base_name__ from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa from ..crypto import rsa
@ -401,9 +402,10 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection) # Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = StateCache( self._state_cache = StateCache(
self.session.get_update_state(0), self._log) self.session.get_update_state(0), self._log)
for k, v in self.session.get_channel_pts().items():
self._state_cache[k] = v
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -630,15 +632,23 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._updates_queue) await asyncio.wait(self._updates_queue)
self._updates_queue.clear() self._updates_queue.clear()
pts, date = self._state_cache[None] pts, qts, date = self._state_cache[None]
if pts and date: if pts and date:
self.session.set_update_state(0, types.updates.State( self.session.set_update_state(0, types.updates.State(
pts=pts, pts=pts,
qts=0, qts=qts,
date=date, date=date,
seq=0, seq=0,
unread_count=0 unread_count=0
)) ))
for channel_id, pts in self._state_cache.get_channel_pts().items():
self.session.set_update_state(channel_id, types.updates.State(
pts=pts,
qts=0,
date=datetime.fromtimestamp(0),
seq=0,
unread_count=0
))
self.session.close() self.session.close()

View File

@ -7,10 +7,12 @@ import time
import traceback import traceback
import typing import typing
import logging import logging
from datetime import datetime
from .. import events, utils, errors from .. import events, utils, errors
from ..events.common import EventBuilder, EventCommon from ..events.common import EventBuilder, EventCommon
from ..tl import types, functions from ..tl import types, functions
from ..tl.types import UpdateChannelTooLong
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient from .telegramclient import TelegramClient
@ -211,7 +213,45 @@ class UpdateMethods:
""" """
return [(callback, event) for event, callback in self._event_builders] return [(callback, event) for event, callback in self._event_builders]
async def catch_up(self: 'TelegramClient'): async def _catch_up_channel(self: 'TelegramClient', channel_id: int, pts: int, limit: int = None):
if self._state_cache[channel_id]:
pts = self._state_cache[channel_id]
if not pts:
# First-time, can't get difference. Get pts instead.
result = await self(functions.channels.GetFullChannelRequest(channel_id))
pts = self._state_cache[channel_id] = result.full_chat.pts
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
return
try:
while True:
d = await self(functions.updates.GetChannelDifferenceRequest(
channel=channel_id,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts,
limit=limit or 100000 # If the limit isn't set, fetch all updates that we can
))
if isinstance(d, types.updates.ChannelDifference):
pts = d.pts
self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=None,
seq=0,
updates=d.other_updates + [
types.UpdateNewChannelMessage(m, 0, 0)
for m in d.new_messages
]
))
elif isinstance(d, (types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifferenceEmpty)):
# If there is too much updates (ChannelDifferenceTooLong),
# there is no way to get them without raising limit or GetHistoryRequest, so just break
break
finally:
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None):
""" """
"Catches up" on the missed updates while the client was offline. "Catches up" on the missed updates while the client was offline.
You should call this method after registering the event handlers You should call this method after registering the event handlers
@ -224,15 +264,22 @@ class UpdateMethods:
await client.catch_up() await client.catch_up()
""" """
pts, date = self._state_cache[None] pts, qts, date = self._state_cache[None]
if not pts:
return
self.session.catching_up = True self.session.catching_up = True
try: try:
if not pts:
# Ran first time, get initial pts, qts and date and return
result = await self(functions.updates.GetStateRequest())
pts, qts, date = result.pts, result.qts, result.date
return
if not qts:
qts = 0
channels_to_fetch = []
while True: while True:
d = await self(functions.updates.GetDifferenceRequest( d = await self(functions.updates.GetDifferenceRequest(
pts, date, 0 pts, date, qts, pts_total_limit
)) ))
if isinstance(d, (types.updates.DifferenceSlice, if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)): types.updates.Difference)):
@ -241,43 +288,39 @@ class UpdateMethods:
else: else:
state = d.intermediate_state state = d.intermediate_state
pts, date = state.pts, state.date updates = []
for update in d.other_updates:
if isinstance(update, UpdateChannelTooLong):
channels_to_fetch.append((update.channel_id, update.pts))
else:
updates.append(update)
pts, qts, date = state.pts, state.qts, state.date
self._handle_update(types.Updates( self._handle_update(types.Updates(
users=d.users, users=d.users,
chats=d.chats, chats=d.chats,
date=state.date, date=state.date,
seq=state.seq, seq=state.seq,
updates=d.other_updates + [ updates=updates + [
types.UpdateNewMessage(m, 0, 0) types.UpdateNewMessage(m, 0, 0)
for m in d.new_messages for m in d.new_messages
] ]
)) ))
elif isinstance(d, types.updates.DifferenceTooLong):
# TODO Implement upper limit (max_pts) pts = d.pts
# We don't want to fetch updates we already know about. # If the limit isn't set, fetch all updates that we can
# if pts_total_limit is not None:
# We may still get duplicates because the Difference break
# contains a lot of updates and presumably only has elif isinstance(d, types.updates.DifferenceEmpty):
# the state for the last one, but at least we don't date = d.date
# unnecessarily fetch too many.
#
# updates.getDifference's pts_total_limit seems to mean
# "how many pts is the request allowed to return", and
# if there is more than that, it returns "too long" (so
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
else:
if isinstance(d, types.updates.DifferenceEmpty):
date = d.date
elif isinstance(d, types.updates.DifferenceTooLong):
pts = d.pts
break break
for channel_id, channel_pts in channels_to_fetch:
await self._catch_up_channel(channel_id, channel_pts, limit)
except (ConnectionError, asyncio.CancelledError): except (ConnectionError, asyncio.CancelledError):
pass pass
finally: finally:
# TODO Save new pts to session self._state_cache[None] = (pts, qts, date)
self._state_cache._pts_date = (pts, date) self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0))
self.session.catching_up = False self.session.catching_up = False
# endregion # endregion
@ -288,6 +331,9 @@ class UpdateMethods:
# the order that the updates arrive in to update the pts and date to # the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async. # be always-increasing. There is also no need to make this async.
def _handle_update(self: 'TelegramClient', update): def _handle_update(self: 'TelegramClient', update):
if self._state_cache.update_already_processed(update):
return
self.session.process_entities(update) self.session.process_entities(update)
self._entity_cache.add(update) self._entity_cache.add(update)
@ -304,6 +350,9 @@ class UpdateMethods:
self._state_cache.update(update) self._state_cache.update(update)
def _process_update(self: 'TelegramClient', update, others, entities=None): def _process_update(self: 'TelegramClient', update, others, entities=None):
if self._state_cache.update_already_processed(update):
return
update._entities = entities or {} update._entities = entities or {}
# This part is somewhat hot so we don't bother patching # This part is somewhat hot so we don't bother patching
@ -553,7 +602,7 @@ class UpdateMethods:
if not pts_date[0]: if not pts_date[0]:
# First-time, can't get difference. Get pts instead. # First-time, can't get difference. Get pts instead.
result = await self(functions.updates.GetStateRequest()) result = await self(functions.updates.GetStateRequest())
self._state_cache[None] = result.pts, result.date self._state_cache[None] = result.pts, result.qts, result.date
return return
result = await self(functions.updates.GetDifferenceRequest( result = await self(functions.updates.GetDifferenceRequest(

View File

@ -88,6 +88,13 @@ class Session(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_channel_pts(self):
"""
Returns the ``Dict[int, int]`` with pts for all saved channels
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
""" """

View File

@ -74,6 +74,9 @@ class MemorySession(Session):
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None) return self._update_states.get(entity_id, None)
def get_channel_pts(self):
return {x[0]: x[1] for x in self._update_states.items() if x[0] != 0}
def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state self._update_states[entity_id] = state

View File

@ -210,6 +210,14 @@ class SQLiteSession(MemorySession):
date, tz=datetime.timezone.utc) date, tz=datetime.timezone.utc)
return types.updates.State(pts, qts, date, seq, unread_count=0) return types.updates.State(pts, qts, date, seq, unread_count=0)
def get_channel_pts(self):
c = self._cursor()
try:
rows = c.execute('select id, pts from update_state').fetchall()
finally:
c.close()
return {x[0]: x[1] for x in rows if x[0] != 0}
def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
self._execute('insert or replace into update_state values (?,?,?,?,?)', self._execute('insert or replace into update_state values (?,?,?,?,?)',
entity_id, state.pts, state.qts, entity_id, state.pts, state.qts,

View File

@ -2,7 +2,6 @@ import inspect
from .tl import types from .tl import types
# Which updates have the following fields? # Which updates have the following fields?
_has_channel_id = [] _has_channel_id = []
@ -26,23 +25,70 @@ def _fill():
_fill() _fill()
has_pts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateNewMessage,
types.UpdateDeleteMessages,
types.UpdateReadHistoryInbox,
types.UpdateReadHistoryOutbox,
types.UpdateWebPage,
types.UpdateReadMessagesContents,
types.UpdateEditMessage,
types.updates.State,
types.updates.DifferenceTooLong,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShortSentMessage
))
has_qts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateBotStopped,
types.UpdateNewEncryptedMessage,
types.updates.State
))
has_date = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateUserPhoto,
types.UpdateEncryption,
types.UpdateEncryptedMessagesRead,
types.UpdateChatParticipantAdd,
types.updates.DifferenceEmpty,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShort,
types.UpdatesCombined,
types.Updates,
types.UpdateShortSentMessage,
))
has_channel_pts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateChannelTooLong,
types.UpdateNewChannelMessage,
types.UpdateDeleteChannelMessages,
types.UpdateEditChannelMessage,
types.UpdateChannelWebPage,
types.updates.ChannelDifferenceEmpty,
types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifference
))
class StateCache: class StateCache:
""" """
In-memory update state cache, defaultdict-like behaviour. In-memory update state cache, defaultdict-like behaviour.
""" """
_store: dict
def __init__(self, initial, loggers): def __init__(self, initial, loggers):
# We only care about the pts and the date. By using a tuple which # We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to # is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities. # each update in case they need to fetch missing entities.
self._logger = loggers[__name__] self._logger = loggers[__name__]
self._store = {}
if initial: if initial:
self._pts_date = initial.pts, initial.date self._pts_date = initial.pts, initial.qts, initial.date
else: else:
self._pts_date = None, None self._pts_date = None, None, None
def reset(self): def reset(self):
self.__dict__.clear() self._store.clear()
self._pts_date = None, None self._pts_date = None, None, None
# TODO Call this when receiving responses too...? # TODO Call this when receiving responses too...?
def update( def update(
@ -50,43 +96,6 @@ class StateCache:
update, update,
*, *,
channel_id=None, channel_id=None,
has_pts=frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateNewMessage,
types.UpdateDeleteMessages,
types.UpdateReadHistoryInbox,
types.UpdateReadHistoryOutbox,
types.UpdateWebPage,
types.UpdateReadMessagesContents,
types.UpdateEditMessage,
types.updates.State,
types.updates.DifferenceTooLong,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShortSentMessage
)),
has_date=frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateUserPhoto,
types.UpdateEncryption,
types.UpdateEncryptedMessagesRead,
types.UpdateChatParticipantAdd,
types.updates.DifferenceEmpty,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShort,
types.UpdatesCombined,
types.Updates,
types.UpdateShortSentMessage,
)),
has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateChannelTooLong,
types.UpdateNewChannelMessage,
types.UpdateDeleteChannelMessages,
types.UpdateEditChannelMessage,
types.UpdateChannelWebPage,
types.updates.ChannelDifferenceEmpty,
types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifference
)),
check_only=False check_only=False
): ):
""" """
@ -96,13 +105,20 @@ class StateCache:
if check_only: if check_only:
return cid in has_pts or cid in has_date or cid in has_channel_pts return cid in has_pts or cid in has_date or cid in has_channel_pts
new_pts_date = tuple()
if cid in has_pts: if cid in has_pts:
if cid in has_date: new_pts_date += update.pts,
self._pts_date = update.pts, update.date else:
else: new_pts_date += self._pts_date[0],
self._pts_date = update.pts, self._pts_date[1] if cid in has_qts:
elif cid in has_date: new_pts_date += update.qts,
self._pts_date = self._pts_date[0], update.date else:
new_pts_date += self._pts_date[1],
if cid in has_date:
new_pts_date += update.date,
else:
new_pts_date += self._pts_date[2],
self._pts_date = new_pts_date
if cid in has_channel_pts: if cid in has_channel_pts:
if channel_id is None: if channel_id is None:
@ -112,7 +128,23 @@ class StateCache:
self._logger.info( self._logger.info(
'Failed to retrieve channel_id from %s', update) 'Failed to retrieve channel_id from %s', update)
else: else:
self.__dict__[channel_id] = update.pts self._store[channel_id] = update.pts
def update_already_processed(self, update):
cid = update.CONSTRUCTOR_ID
# If pts == 0, the update is from catch_up
if cid in has_pts and \
update.pts != 0 and \
update.pts >= self._pts_date[0]:
return True
if cid in has_qts and update.qts >= self._pts_date[1]:
return True
if cid in has_channel_pts:
channel_id = self.get_channel_id(update)
if update.pts != 0 and \
self._store.get(channel_id, 0) >= update.pts:
return True
return False
def get_channel_id( def get_channel_id(
self, self,
@ -120,8 +152,8 @@ class StateCache:
has_channel_id=frozenset(_has_channel_id), has_channel_id=frozenset(_has_channel_id),
# Hardcoded because only some with message are for channels # Hardcoded because only some with message are for channels
has_message=frozenset(x.CONSTRUCTOR_ID for x in ( has_message=frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateNewChannelMessage, types.UpdateNewChannelMessage,
types.UpdateEditChannelMessage types.UpdateEditChannelMessage
)) ))
): ):
""" """
@ -155,10 +187,13 @@ class StateCache:
if item is None: if item is None:
return self._pts_date return self._pts_date
else: else:
return self.__dict__.get(item) return self._store.get(item)
def __setitem__(self, where, value): def __setitem__(self, where, value):
if where is None: if where is None:
self._pts_date = value self._pts_date = value
else: else:
self.__dict__[where] = value self._store[where] = value
def get_channel_pts(self):
return self._store