mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-07-30 01:39:47 +03:00
Fix catch_up
updates.py: - Fix catch_up getting stuck in an infinite loop by saving qts - Save the new state to the session after catch_up - Save the initial state immediately if catch_up is run for the first time - If DifferenceTooLong is received, continue fetching with the new pts instead of exiting - Catch up channels - Implement a pts limit (pts_total_limit in chats and limit in channels) - Check whether an update was already processed before processing it telegrambaseclient.py, sessions: - Get the saved state of channels on startup and save the channels' state statecache.py: - Create a separate variable _store in StateCache, because __dict__ also stores the _logger variable - Move has_pts, has_qts, has_date, has_channel_pts out of the function arguments Fixes #3041
This commit is contained in:
parent
3d350c6087
commit
34bb2b8fc3
|
@ -6,6 +6,7 @@ import logging
|
|||
import platform
|
||||
import time
|
||||
import typing
|
||||
from datetime import datetime
|
||||
|
||||
from .. import version, helpers, __name__ as __base_name__
|
||||
from ..crypto import rsa
|
||||
|
@ -401,9 +402,10 @@ class TelegramBaseClient(abc.ABC):
|
|||
self._authorized = None # None = unknown, False = no, True = yes
|
||||
|
||||
# Update state (for catching up after a disconnection)
|
||||
# TODO Get state from channels too
|
||||
self._state_cache = StateCache(
|
||||
self.session.get_update_state(0), self._log)
|
||||
for k, v in self.session.get_channel_pts().items():
|
||||
self._state_cache[k] = v
|
||||
|
||||
# Some further state for subclasses
|
||||
self._event_builders = []
|
||||
|
@ -630,15 +632,23 @@ class TelegramBaseClient(abc.ABC):
|
|||
await asyncio.wait(self._updates_queue)
|
||||
self._updates_queue.clear()
|
||||
|
||||
pts, date = self._state_cache[None]
|
||||
pts, qts, date = self._state_cache[None]
|
||||
if pts and date:
|
||||
self.session.set_update_state(0, types.updates.State(
|
||||
pts=pts,
|
||||
qts=0,
|
||||
qts=qts,
|
||||
date=date,
|
||||
seq=0,
|
||||
unread_count=0
|
||||
))
|
||||
for channel_id, pts in self._state_cache.get_channel_pts().items():
|
||||
self.session.set_update_state(channel_id, types.updates.State(
|
||||
pts=pts,
|
||||
qts=0,
|
||||
date=datetime.fromtimestamp(0),
|
||||
seq=0,
|
||||
unread_count=0
|
||||
))
|
||||
|
||||
self.session.close()
|
||||
|
||||
|
|
|
@ -7,10 +7,12 @@ import time
|
|||
import traceback
|
||||
import typing
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from .. import events, utils, errors
|
||||
from ..events.common import EventBuilder, EventCommon
|
||||
from ..tl import types, functions
|
||||
from ..tl.types import UpdateChannelTooLong
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .telegramclient import TelegramClient
|
||||
|
@ -211,7 +213,45 @@ class UpdateMethods:
|
|||
"""
|
||||
return [(callback, event) for event, callback in self._event_builders]
|
||||
|
||||
async def catch_up(self: 'TelegramClient'):
|
||||
async def _catch_up_channel(self: 'TelegramClient', channel_id: int, pts: int, limit: int = None):
    """
    Fetches the updates missed in a single channel and hands them to
    ``_handle_update``.

    A truthy pts cached for ``channel_id`` takes precedence over the
    ``pts`` argument. ``limit`` is forwarded to
    ``GetChannelDifferenceRequest``; when it is falsy, 100000 is used.
    The pts reached by the loop is saved to the session on the way out,
    even if a request raises.
    """
    cached_pts = self._state_cache[channel_id]
    if cached_pts:
        pts = cached_pts

    if not pts:
        # No pts is known yet, so no difference can be requested.
        # Record the channel's current pts and stop here.
        full = await self(functions.channels.GetFullChannelRequest(channel_id))
        pts = full.full_chat.pts
        self._state_cache[channel_id] = pts
        self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
        return

    try:
        while True:
            diff = await self(functions.updates.GetChannelDifferenceRequest(
                channel=channel_id,
                filter=types.ChannelMessagesFilterEmpty(),
                pts=pts,
                # Without an explicit limit, fetch as many updates as we can
                limit=limit or 100000
            ))

            if isinstance(diff, types.updates.ChannelDifference):
                pts = diff.pts
                # Messages are wrapped with pts=0 and pts_count=0
                wrapped_messages = [
                    types.UpdateNewChannelMessage(message, 0, 0)
                    for message in diff.new_messages
                ]
                self._handle_update(types.Updates(
                    users=diff.users,
                    chats=diff.chats,
                    date=None,
                    seq=0,
                    updates=diff.other_updates + wrapped_messages
                ))
                continue

            if isinstance(diff, (types.updates.ChannelDifferenceTooLong,
                                 types.updates.ChannelDifferenceEmpty)):
                # When there are too many updates (ChannelDifferenceTooLong),
                # they cannot be fetched without raising the limit or using
                # GetHistoryRequest, so stop; an empty difference also ends
                # the loop.
                break
    finally:
        self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
|
||||
|
||||
async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None):
|
||||
"""
|
||||
"Catches up" on the missed updates while the client was offline.
|
||||
You should call this method after registering the event handlers
|
||||
|
@ -224,15 +264,22 @@ class UpdateMethods:
|
|||
|
||||
await client.catch_up()
|
||||
"""
|
||||
pts, date = self._state_cache[None]
|
||||
if not pts:
|
||||
return
|
||||
pts, qts, date = self._state_cache[None]
|
||||
|
||||
self.session.catching_up = True
|
||||
try:
|
||||
if not pts:
|
||||
# Ran first time, get initial pts, qts and date and return
|
||||
result = await self(functions.updates.GetStateRequest())
|
||||
pts, qts, date = result.pts, result.qts, result.date
|
||||
return
|
||||
if not qts:
|
||||
qts = 0
|
||||
|
||||
channels_to_fetch = []
|
||||
while True:
|
||||
d = await self(functions.updates.GetDifferenceRequest(
|
||||
pts, date, 0
|
||||
pts, date, qts, pts_total_limit
|
||||
))
|
||||
if isinstance(d, (types.updates.DifferenceSlice,
|
||||
types.updates.Difference)):
|
||||
|
@ -241,43 +288,39 @@ class UpdateMethods:
|
|||
else:
|
||||
state = d.intermediate_state
|
||||
|
||||
pts, date = state.pts, state.date
|
||||
updates = []
|
||||
for update in d.other_updates:
|
||||
if isinstance(update, UpdateChannelTooLong):
|
||||
channels_to_fetch.append((update.channel_id, update.pts))
|
||||
else:
|
||||
updates.append(update)
|
||||
|
||||
pts, qts, date = state.pts, state.qts, state.date
|
||||
self._handle_update(types.Updates(
|
||||
users=d.users,
|
||||
chats=d.chats,
|
||||
date=state.date,
|
||||
seq=state.seq,
|
||||
updates=d.other_updates + [
|
||||
updates=updates + [
|
||||
types.UpdateNewMessage(m, 0, 0)
|
||||
for m in d.new_messages
|
||||
]
|
||||
))
|
||||
|
||||
# TODO Implement upper limit (max_pts)
|
||||
# We don't want to fetch updates we already know about.
|
||||
#
|
||||
# We may still get duplicates because the Difference
|
||||
# contains a lot of updates and presumably only has
|
||||
# the state for the last one, but at least we don't
|
||||
# unnecessarily fetch too many.
|
||||
#
|
||||
# updates.getDifference's pts_total_limit seems to mean
|
||||
# "how many pts is the request allowed to return", and
|
||||
# if there is more than that, it returns "too long" (so
|
||||
# there would be duplicate updates since we know about
|
||||
# some). This can be used to detect collisions (i.e.
|
||||
# it would return an update we have already seen).
|
||||
else:
|
||||
if isinstance(d, types.updates.DifferenceEmpty):
|
||||
date = d.date
|
||||
elif isinstance(d, types.updates.DifferenceTooLong):
|
||||
pts = d.pts
|
||||
elif isinstance(d, types.updates.DifferenceTooLong):
|
||||
pts = d.pts
|
||||
# If the limit isn't set, fetch all updates that we can
|
||||
if pts_total_limit is not None:
|
||||
break
|
||||
elif isinstance(d, types.updates.DifferenceEmpty):
|
||||
date = d.date
|
||||
break
|
||||
for channel_id, channel_pts in channels_to_fetch:
|
||||
await self._catch_up_channel(channel_id, channel_pts, limit)
|
||||
except (ConnectionError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
# TODO Save new pts to session
|
||||
self._state_cache._pts_date = (pts, date)
|
||||
self._state_cache[None] = (pts, qts, date)
|
||||
self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0))
|
||||
self.session.catching_up = False
|
||||
|
||||
# endregion
|
||||
|
@ -288,6 +331,9 @@ class UpdateMethods:
|
|||
# the order that the updates arrive in to update the pts and date to
|
||||
# be always-increasing. There is also no need to make this async.
|
||||
def _handle_update(self: 'TelegramClient', update):
|
||||
if self._state_cache.update_already_processed(update):
|
||||
return
|
||||
|
||||
self.session.process_entities(update)
|
||||
self._entity_cache.add(update)
|
||||
|
||||
|
@ -304,6 +350,9 @@ class UpdateMethods:
|
|||
self._state_cache.update(update)
|
||||
|
||||
def _process_update(self: 'TelegramClient', update, others, entities=None):
|
||||
if self._state_cache.update_already_processed(update):
|
||||
return
|
||||
|
||||
update._entities = entities or {}
|
||||
|
||||
# This part is somewhat hot so we don't bother patching
|
||||
|
@ -553,7 +602,7 @@ class UpdateMethods:
|
|||
if not pts_date[0]:
|
||||
# First-time, can't get difference. Get pts instead.
|
||||
result = await self(functions.updates.GetStateRequest())
|
||||
self._state_cache[None] = result.pts, result.date
|
||||
self._state_cache[None] = result.pts, result.qts, result.date
|
||||
return
|
||||
|
||||
result = await self(functions.updates.GetDifferenceRequest(
|
||||
|
|
|
@ -88,6 +88,13 @@ class Session(ABC):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def get_channel_pts(self):
    """
    Returns a ``Dict[int, int]`` mapping the ID of every saved
    channel to the pts stored for it.

    Subclasses must override this; the base implementation only raises.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set_update_state(self, entity_id, state):
|
||||
"""
|
||||
|
|
|
@ -74,6 +74,9 @@ class MemorySession(Session):
|
|||
def get_update_state(self, entity_id):
    """Return the state saved for ``entity_id``, or ``None`` if there is none."""
    saved_states = self._update_states
    return saved_states.get(entity_id)
|
||||
|
||||
def get_channel_pts(self):
    """
    Returns a ``Dict[int, int]`` with the pts of every saved channel.

    ``set_update_state`` stores whole state objects, so the ``pts``
    attribute is extracted from each one here. The entry saved under
    ID 0 is left out, as it is not a channel entry.
    """
    return {
        entity_id: state.pts
        for entity_id, state in self._update_states.items()
        if entity_id != 0
    }
|
||||
|
||||
def set_update_state(self, entity_id, state):
    """Keep ``state`` in memory under ``entity_id``, replacing any previous value."""
    self._update_states[entity_id] = state
|
||||
|
||||
|
|
|
@ -210,6 +210,14 @@ class SQLiteSession(MemorySession):
|
|||
date, tz=datetime.timezone.utc)
|
||||
return types.updates.State(pts, qts, date, seq, unread_count=0)
|
||||
|
||||
def get_channel_pts(self):
    """
    Returns a ``Dict[int, int]`` with the pts saved in the
    ``update_state`` table for every ID other than 0.
    """
    cursor = self._cursor()
    try:
        cursor.execute('select id, pts from update_state')
        saved = cursor.fetchall()
    finally:
        # Release the cursor whether or not the query succeeded
        cursor.close()

    return {entity_id: pts for entity_id, pts in saved if entity_id != 0}
|
||||
|
||||
def set_update_state(self, entity_id, state):
|
||||
self._execute('insert or replace into update_state values (?,?,?,?,?)',
|
||||
entity_id, state.pts, state.qts,
|
||||
|
|
|
@ -2,7 +2,6 @@ import inspect
|
|||
|
||||
from .tl import types
|
||||
|
||||
|
||||
# Which updates have the following fields?
|
||||
_has_channel_id = []
|
||||
|
||||
|
@ -26,23 +25,70 @@ def _fill():
|
|||
_fill()
|
||||
|
||||
|
||||
# Constructor IDs of the types that carry a ``pts`` field.
has_pts = frozenset(tl_type.CONSTRUCTOR_ID for tl_type in (
    types.UpdateNewMessage,
    types.UpdateDeleteMessages,
    types.UpdateReadHistoryInbox,
    types.UpdateReadHistoryOutbox,
    types.UpdateWebPage,
    types.UpdateReadMessagesContents,
    types.UpdateEditMessage,
    types.updates.State,
    types.updates.DifferenceTooLong,
    types.UpdateShortMessage,
    types.UpdateShortChatMessage,
    types.UpdateShortSentMessage
))

# Constructor IDs of the types that carry a ``qts`` field.
has_qts = frozenset(tl_type.CONSTRUCTOR_ID for tl_type in (
    types.UpdateBotStopped,
    types.UpdateNewEncryptedMessage,
    types.updates.State
))

# Constructor IDs of the types that carry a ``date`` field.
has_date = frozenset(tl_type.CONSTRUCTOR_ID for tl_type in (
    types.UpdateUserPhoto,
    types.UpdateEncryption,
    types.UpdateEncryptedMessagesRead,
    types.UpdateChatParticipantAdd,
    types.updates.DifferenceEmpty,
    types.UpdateShortMessage,
    types.UpdateShortChatMessage,
    types.UpdateShort,
    types.UpdatesCombined,
    types.Updates,
    types.UpdateShortSentMessage,
))

# Constructor IDs of the types whose ``pts`` is tracked per channel.
has_channel_pts = frozenset(tl_type.CONSTRUCTOR_ID for tl_type in (
    types.UpdateChannelTooLong,
    types.UpdateNewChannelMessage,
    types.UpdateDeleteChannelMessages,
    types.UpdateEditChannelMessage,
    types.UpdateChannelWebPage,
    types.updates.ChannelDifferenceEmpty,
    types.updates.ChannelDifferenceTooLong,
    types.updates.ChannelDifference
))
|
||||
|
||||
|
||||
class StateCache:
|
||||
"""
|
||||
In-memory update state cache, defaultdict-like behaviour.
|
||||
"""
|
||||
_store: dict
|
||||
|
||||
def __init__(self, initial, loggers):
    """
    Sets up the cache.

    ``initial`` is the saved common state (anything exposing ``pts``,
    ``qts`` and ``date``) or a falsy value when there is none.
    ``loggers`` is indexed with this module's ``__name__`` to obtain
    the logger.
    """
    # The common state is kept as a (pts, qts, date) tuple. Being
    # lightweight and immutable, it can easily be copied around to
    # each update in case they need to fetch missing entities.
    self._logger = loggers[__name__]

    # Per-channel pts, kept apart from the instance ``__dict__``
    # because that one also holds ``_logger``.
    self._store = {}

    if initial:
        self._pts_date = initial.pts, initial.qts, initial.date
    else:
        self._pts_date = None, None, None
|
||||
|
||||
def reset(self):
    """
    Forgets every channel pts and the common (pts, qts, date) state.

    Only ``_store`` is emptied rather than the whole instance
    ``__dict__``, so other attributes such as ``_logger`` survive.
    """
    self._store.clear()
    self._pts_date = None, None, None
|
||||
|
||||
# TODO Call this when receiving responses too...?
|
||||
def update(
|
||||
|
@ -50,43 +96,6 @@ class StateCache:
|
|||
update,
|
||||
*,
|
||||
channel_id=None,
|
||||
has_pts=frozenset(x.CONSTRUCTOR_ID for x in (
|
||||
types.UpdateNewMessage,
|
||||
types.UpdateDeleteMessages,
|
||||
types.UpdateReadHistoryInbox,
|
||||
types.UpdateReadHistoryOutbox,
|
||||
types.UpdateWebPage,
|
||||
types.UpdateReadMessagesContents,
|
||||
types.UpdateEditMessage,
|
||||
types.updates.State,
|
||||
types.updates.DifferenceTooLong,
|
||||
types.UpdateShortMessage,
|
||||
types.UpdateShortChatMessage,
|
||||
types.UpdateShortSentMessage
|
||||
)),
|
||||
has_date=frozenset(x.CONSTRUCTOR_ID for x in (
|
||||
types.UpdateUserPhoto,
|
||||
types.UpdateEncryption,
|
||||
types.UpdateEncryptedMessagesRead,
|
||||
types.UpdateChatParticipantAdd,
|
||||
types.updates.DifferenceEmpty,
|
||||
types.UpdateShortMessage,
|
||||
types.UpdateShortChatMessage,
|
||||
types.UpdateShort,
|
||||
types.UpdatesCombined,
|
||||
types.Updates,
|
||||
types.UpdateShortSentMessage,
|
||||
)),
|
||||
has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in (
|
||||
types.UpdateChannelTooLong,
|
||||
types.UpdateNewChannelMessage,
|
||||
types.UpdateDeleteChannelMessages,
|
||||
types.UpdateEditChannelMessage,
|
||||
types.UpdateChannelWebPage,
|
||||
types.updates.ChannelDifferenceEmpty,
|
||||
types.updates.ChannelDifferenceTooLong,
|
||||
types.updates.ChannelDifference
|
||||
)),
|
||||
check_only=False
|
||||
):
|
||||
"""
|
||||
|
@ -96,13 +105,20 @@ class StateCache:
|
|||
if check_only:
|
||||
return cid in has_pts or cid in has_date or cid in has_channel_pts
|
||||
|
||||
new_pts_date = tuple()
|
||||
if cid in has_pts:
|
||||
if cid in has_date:
|
||||
self._pts_date = update.pts, update.date
|
||||
else:
|
||||
self._pts_date = update.pts, self._pts_date[1]
|
||||
elif cid in has_date:
|
||||
self._pts_date = self._pts_date[0], update.date
|
||||
new_pts_date += update.pts,
|
||||
else:
|
||||
new_pts_date += self._pts_date[0],
|
||||
if cid in has_qts:
|
||||
new_pts_date += update.qts,
|
||||
else:
|
||||
new_pts_date += self._pts_date[1],
|
||||
if cid in has_date:
|
||||
new_pts_date += update.date,
|
||||
else:
|
||||
new_pts_date += self._pts_date[2],
|
||||
self._pts_date = new_pts_date
|
||||
|
||||
if cid in has_channel_pts:
|
||||
if channel_id is None:
|
||||
|
@ -112,7 +128,23 @@ class StateCache:
|
|||
self._logger.info(
|
||||
'Failed to retrieve channel_id from %s', update)
|
||||
else:
|
||||
self.__dict__[channel_id] = update.pts
|
||||
self._store[channel_id] = update.pts
|
||||
|
||||
def update_already_processed(self, update):
    """
    Returns ``True`` if the cached state shows that ``update`` was
    already applied, ``False`` otherwise.

    An update counts as processed when its pts/qts is not ahead of
    the cached value. While no state is cached yet (``None``),
    nothing is considered processed.
    """
    cid = update.CONSTRUCTOR_ID
    known_pts = self._pts_date[0]
    known_qts = self._pts_date[1]

    # If pts == 0, the update is from catch_up and must always go through
    if cid in has_pts and update.pts != 0 \
            and known_pts is not None and update.pts <= known_pts:
        return True

    if cid in has_qts and known_qts is not None and update.qts <= known_qts:
        return True

    if cid in has_channel_pts:
        channel_id = self.get_channel_id(update)
        # Channels we know nothing about count as being at pts 0
        if update.pts != 0 and (self._store.get(channel_id) or 0) >= update.pts:
            return True

    return False
|
||||
|
||||
def get_channel_id(
|
||||
self,
|
||||
|
@ -120,8 +152,8 @@ class StateCache:
|
|||
has_channel_id=frozenset(_has_channel_id),
|
||||
# Hardcoded because only some with message are for channels
|
||||
has_message=frozenset(x.CONSTRUCTOR_ID for x in (
|
||||
types.UpdateNewChannelMessage,
|
||||
types.UpdateEditChannelMessage
|
||||
types.UpdateNewChannelMessage,
|
||||
types.UpdateEditChannelMessage
|
||||
))
|
||||
):
|
||||
"""
|
||||
|
@ -155,10 +187,13 @@ class StateCache:
|
|||
if item is None:
|
||||
return self._pts_date
|
||||
else:
|
||||
return self.__dict__.get(item)
|
||||
return self._store.get(item)
|
||||
|
||||
def __setitem__(self, where, value):
    """
    Stores ``value`` as the common state when ``where`` is ``None``,
    or as the pts of the channel with ID ``where`` otherwise.
    """
    if where is None:
        self._pts_date = value
    else:
        # Channel pts live only in ``_store``; the instance ``__dict__``
        # is reserved for real attributes.
        self._store[where] = value
|
||||
|
||||
def get_channel_pts(self):
    """Return the channel ID to pts mapping (the internal dict itself, not a copy)."""
    return self._store
|
||||
|
|
Loading…
Reference in New Issue
Block a user