Fix catch_up

updates.py:
- Fix catch_up getting stuck in infinite loop by saving qts
- Save new state to session after catch_up
- Save initial state immediately if catch_up was ran first time
- If DifferenceTooLong was received, continue fetching with new pts instead of exiting
- Catch up channels
- Implement pts limit (pts_total_limit in chats and limit in channels)
- Check if update was already processed before processing

telegrambaseclient.py, sessions:
- Get saved state from channels on startup and save channels state

statecache.py:
- Create separate variable _store in StateCache, because __dict__ also stores _logger variable
- Move has_pts, has_qts, has_date, has_channel_pts outside of function arguments

Fixes #3041
This commit is contained in:
vanutp 2021-05-27 11:31:55 +03:00
parent 3d350c6087
commit 34bb2b8fc3
No known key found for this signature in database
GPG Key ID: 1E36ED1DFB6FDE78
6 changed files with 198 additions and 86 deletions

View File

@ -6,6 +6,7 @@ import logging
import platform
import time
import typing
from datetime import datetime
from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa
@ -401,9 +402,10 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = StateCache(
self.session.get_update_state(0), self._log)
for k, v in self.session.get_channel_pts().items():
self._state_cache[k] = v
# Some further state for subclasses
self._event_builders = []
@ -630,15 +632,23 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._updates_queue)
self._updates_queue.clear()
pts, date = self._state_cache[None]
pts, qts, date = self._state_cache[None]
if pts and date:
self.session.set_update_state(0, types.updates.State(
pts=pts,
qts=0,
qts=qts,
date=date,
seq=0,
unread_count=0
))
for channel_id, pts in self._state_cache.get_channel_pts().items():
self.session.set_update_state(channel_id, types.updates.State(
pts=pts,
qts=0,
date=datetime.fromtimestamp(0),
seq=0,
unread_count=0
))
self.session.close()

View File

@ -7,10 +7,12 @@ import time
import traceback
import typing
import logging
from datetime import datetime
from .. import events, utils, errors
from ..events.common import EventBuilder, EventCommon
from ..tl import types, functions
from ..tl.types import UpdateChannelTooLong
if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient
@ -211,7 +213,45 @@ class UpdateMethods:
"""
return [(callback, event) for event, callback in self._event_builders]
async def catch_up(self: 'TelegramClient'):
async def _catch_up_channel(self: 'TelegramClient', channel_id: int, pts: int, limit: int = None):
if self._state_cache[channel_id]:
pts = self._state_cache[channel_id]
if not pts:
# First-time, can't get difference. Get pts instead.
result = await self(functions.channels.GetFullChannelRequest(channel_id))
pts = self._state_cache[channel_id] = result.full_chat.pts
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
return
try:
while True:
d = await self(functions.updates.GetChannelDifferenceRequest(
channel=channel_id,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts,
limit=limit or 100000 # If the limit isn't set, fetch all updates that we can
))
if isinstance(d, types.updates.ChannelDifference):
pts = d.pts
self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=None,
seq=0,
updates=d.other_updates + [
types.UpdateNewChannelMessage(m, 0, 0)
for m in d.new_messages
]
))
elif isinstance(d, (types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifferenceEmpty)):
# If there is too much updates (ChannelDifferenceTooLong),
# there is no way to get them without raising limit or GetHistoryRequest, so just break
break
finally:
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None):
"""
"Catches up" on the missed updates while the client was offline.
You should call this method after registering the event handlers
@ -224,15 +264,22 @@ class UpdateMethods:
await client.catch_up()
"""
pts, date = self._state_cache[None]
if not pts:
return
pts, qts, date = self._state_cache[None]
self.session.catching_up = True
try:
if not pts:
# Ran first time, get initial pts, qts and date and return
result = await self(functions.updates.GetStateRequest())
pts, qts, date = result.pts, result.qts, result.date
return
if not qts:
qts = 0
channels_to_fetch = []
while True:
d = await self(functions.updates.GetDifferenceRequest(
pts, date, 0
pts, date, qts, pts_total_limit
))
if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)):
@ -241,43 +288,39 @@ class UpdateMethods:
else:
state = d.intermediate_state
pts, date = state.pts, state.date
updates = []
for update in d.other_updates:
if isinstance(update, UpdateChannelTooLong):
channels_to_fetch.append((update.channel_id, update.pts))
else:
updates.append(update)
pts, qts, date = state.pts, state.qts, state.date
self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=state.date,
seq=state.seq,
updates=d.other_updates + [
updates=updates + [
types.UpdateNewMessage(m, 0, 0)
for m in d.new_messages
]
))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about.
#
# We may still get duplicates because the Difference
# contains a lot of updates and presumably only has
# the state for the last one, but at least we don't
# unnecessarily fetch too many.
#
# updates.getDifference's pts_total_limit seems to mean
# "how many pts is the request allowed to return", and
# if there is more than that, it returns "too long" (so
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
else:
if isinstance(d, types.updates.DifferenceEmpty):
date = d.date
elif isinstance(d, types.updates.DifferenceTooLong):
pts = d.pts
# If the limit isn't set, fetch all updates that we can
if pts_total_limit is not None:
break
elif isinstance(d, types.updates.DifferenceEmpty):
date = d.date
break
for channel_id, channel_pts in channels_to_fetch:
await self._catch_up_channel(channel_id, channel_pts, limit)
except (ConnectionError, asyncio.CancelledError):
pass
finally:
# TODO Save new pts to session
self._state_cache._pts_date = (pts, date)
self._state_cache[None] = (pts, qts, date)
self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0))
self.session.catching_up = False
# endregion
@ -288,6 +331,9 @@ class UpdateMethods:
# the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async.
def _handle_update(self: 'TelegramClient', update):
if self._state_cache.update_already_processed(update):
return
self.session.process_entities(update)
self._entity_cache.add(update)
@ -304,6 +350,9 @@ class UpdateMethods:
self._state_cache.update(update)
def _process_update(self: 'TelegramClient', update, others, entities=None):
if self._state_cache.update_already_processed(update):
return
update._entities = entities or {}
# This part is somewhat hot so we don't bother patching
@ -553,7 +602,7 @@ class UpdateMethods:
if not pts_date[0]:
# First-time, can't get difference. Get pts instead.
result = await self(functions.updates.GetStateRequest())
self._state_cache[None] = result.pts, result.date
self._state_cache[None] = result.pts, result.qts, result.date
return
result = await self(functions.updates.GetDifferenceRequest(

View File

@ -88,6 +88,13 @@ class Session(ABC):
"""
raise NotImplementedError
@abstractmethod
def get_channel_pts(self):
"""
Returns the ``Dict[int, int]`` with pts for all saved channels
"""
raise NotImplementedError
@abstractmethod
def set_update_state(self, entity_id, state):
"""

View File

@ -74,6 +74,9 @@ class MemorySession(Session):
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)
def get_channel_pts(self):
return {x[0]: x[1] for x in self._update_states.items() if x[0] != 0}
def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state

View File

@ -210,6 +210,14 @@ class SQLiteSession(MemorySession):
date, tz=datetime.timezone.utc)
return types.updates.State(pts, qts, date, seq, unread_count=0)
def get_channel_pts(self):
c = self._cursor()
try:
rows = c.execute('select id, pts from update_state').fetchall()
finally:
c.close()
return {x[0]: x[1] for x in rows if x[0] != 0}
def set_update_state(self, entity_id, state):
self._execute('insert or replace into update_state values (?,?,?,?,?)',
entity_id, state.pts, state.qts,

View File

@ -2,7 +2,6 @@ import inspect
from .tl import types
# Which updates have the following fields?
_has_channel_id = []
@ -26,30 +25,6 @@ def _fill():
_fill()
class StateCache:
"""
In-memory update state cache, defaultdict-like behaviour.
"""
def __init__(self, initial, loggers):
# We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
self._logger = loggers[__name__]
if initial:
self._pts_date = initial.pts, initial.date
else:
self._pts_date = None, None
def reset(self):
self.__dict__.clear()
self._pts_date = None, None
# TODO Call this when receiving responses too...?
def update(
self,
update,
*,
channel_id=None,
has_pts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateNewMessage,
types.UpdateDeleteMessages,
@ -63,7 +38,12 @@ class StateCache:
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShortSentMessage
)),
))
has_qts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateBotStopped,
types.UpdateNewEncryptedMessage,
types.updates.State
))
has_date = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateUserPhoto,
types.UpdateEncryption,
@ -76,7 +56,7 @@ class StateCache:
types.UpdatesCombined,
types.Updates,
types.UpdateShortSentMessage,
)),
))
has_channel_pts = frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateChannelTooLong,
types.UpdateNewChannelMessage,
@ -86,7 +66,36 @@ class StateCache:
types.updates.ChannelDifferenceEmpty,
types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifference
)),
))
class StateCache:
"""
In-memory update state cache, defaultdict-like behaviour.
"""
_store: dict
def __init__(self, initial, loggers):
# We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
self._logger = loggers[__name__]
self._store = {}
if initial:
self._pts_date = initial.pts, initial.qts, initial.date
else:
self._pts_date = None, None, None
def reset(self):
self._store.clear()
self._pts_date = None, None, None
# TODO Call this when receiving responses too...?
def update(
self,
update,
*,
channel_id=None,
check_only=False
):
"""
@ -96,13 +105,20 @@ class StateCache:
if check_only:
return cid in has_pts or cid in has_date or cid in has_channel_pts
new_pts_date = tuple()
if cid in has_pts:
if cid in has_date:
self._pts_date = update.pts, update.date
new_pts_date += update.pts,
else:
self._pts_date = update.pts, self._pts_date[1]
elif cid in has_date:
self._pts_date = self._pts_date[0], update.date
new_pts_date += self._pts_date[0],
if cid in has_qts:
new_pts_date += update.qts,
else:
new_pts_date += self._pts_date[1],
if cid in has_date:
new_pts_date += update.date,
else:
new_pts_date += self._pts_date[2],
self._pts_date = new_pts_date
if cid in has_channel_pts:
if channel_id is None:
@ -112,7 +128,23 @@ class StateCache:
self._logger.info(
'Failed to retrieve channel_id from %s', update)
else:
self.__dict__[channel_id] = update.pts
self._store[channel_id] = update.pts
def update_already_processed(self, update):
cid = update.CONSTRUCTOR_ID
# If pts == 0, the update is from catch_up
if cid in has_pts and \
update.pts != 0 and \
update.pts >= self._pts_date[0]:
return True
if cid in has_qts and update.qts >= self._pts_date[1]:
return True
if cid in has_channel_pts:
channel_id = self.get_channel_id(update)
if update.pts != 0 and \
self._store.get(channel_id, 0) >= update.pts:
return True
return False
def get_channel_id(
self,
@ -155,10 +187,13 @@ class StateCache:
if item is None:
return self._pts_date
else:
return self.__dict__.get(item)
return self._store.get(item)
def __setitem__(self, where, value):
if where is None:
self._pts_date = value
else:
self.__dict__[where] = value
self._store[where] = value
def get_channel_pts(self):
return self._store