Attempt to load and save MessageBox state

This commit is contained in:
Lonami Exo 2022-05-13 17:40:03 +02:00
parent 053a0052c8
commit a5c3df2743
7 changed files with 60 additions and 8 deletions

View File

@ -1,2 +1,3 @@
from .entitycache import EntityCache from .entitycache import EntityCache
from .messagebox import MessageBox, GapError from .messagebox import MessageBox, GapError
from .session import SessionState, ChannelState, Entity, EntityType

View File

@ -1,5 +1,6 @@
from typing import Optional, Tuple from typing import Optional, Tuple
from enum import IntEnum from enum import IntEnum
from ..tl.types import InputPeerUser, InputPeerChat, InputPeerChannel
class SessionState: class SessionState:
@ -175,3 +176,11 @@ class Entity:
def __bytes__(self): def __bytes__(self):
return struct.pack('<Bqq', self.ty, self.id, self.hash) return struct.pack('<Bqq', self.ty, self.id, self.hash)
def _as_input_peer(self):
if self.is_user:
return InputPeerUser(self.id, self.hash)
elif self.ty == EntityType.GROUP:
return InputPeerChat(self.id)
else:
return InputPeerChannel(self.id, self.hash)

View File

@ -15,7 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
from ..sessions import Session, SQLiteSession, MemorySession from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import functions, types from ..tl import functions, types
from ..tl.alltlobjects import LAYER from ..tl.alltlobjects import LAYER
from .._updates import MessageBox, EntityCache as MbEntityCache from .._updates import MessageBox, EntityCache as MbEntityCache, SessionState, ChannelState, Entity, EntityType
DEFAULT_DC_ID = 2 DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -535,6 +535,23 @@ class TelegramBaseClient(abc.ABC):
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
await self.session.save() await self.session.save()
if self._catch_up:
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = []
for entity_id, state in await self.session.get_update_states():
if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, state.date, state.seq, None)
else:
cs.append(ChannelState(entity_id, state.pts))
self._message_box.load(ss, cs)
for state in cs:
entity = await self.session.get_input_entity(state.channel_id)
if entity:
self._mb_entity_cache.put(Entity(EntityType.CHANNEL, entity.channel_id, entity.access_hash))
self._init_request.query = functions.help.GetConfigRequest() self._init_request.query = functions.help.GetConfigRequest()
await self._sender.send(functions.InvokeWithLayerRequest( await self._sender.send(functions.InvokeWithLayerRequest(
@ -642,6 +659,17 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._event_handler_tasks) await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear() self._event_handler_tasks.clear()
entities = self._entity_cache.get_all_entities()
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats.
await self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []))
session_state, channel_states = self._message_box.session_state()
await self.session.set_update_state(0, types.updates.State(ss.pts, ss.qts, ss.date, ss.seq, unread_count=0))
for channel_id, pts in channel_states.items():
await self.session.set_update_state(channel_id, types.updates.State(pts, 0, None, 0, unread_count=0))
await self.session.close() await self.session.close()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):

View File

@ -266,7 +266,7 @@ class UpdateMethods:
self._log[__name__].info('Getting difference for account updates') self._log[__name__].info('Getting difference for account updates')
diff = await self(get_diff) diff = await self(get_diff)
updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache) updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats)) updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
continue continue
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
@ -274,7 +274,7 @@ class UpdateMethods:
self._log[__name__].info('Getting difference for channel updates') self._log[__name__].info('Getting difference for channel updates')
diff = await self(get_diff) diff = await self(get_diff)
updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache) updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats)) updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
continue continue
deadline = self._message_box.check_deadlines() deadline = self._message_box.check_deadlines()
@ -293,14 +293,11 @@ class UpdateMethods:
except GapError: except GapError:
continue # get(_channel)_difference will start returning requests continue # get(_channel)_difference will start returning requests
updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats)) updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats))
except Exception: except Exception:
self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)') self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)')
async def _preprocess_updates(self, updates, users, chats): def _preprocess_updates(self, updates, users, chats):
await self.session.process_entities(update)
self._entity_cache.add(update)
self._mb_entity_cache.extend(users, chats) self._mb_entity_cache.extend(users, chats)
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)} for x in itertools.chain(users, chats)}

View File

@ -97,6 +97,12 @@ class Session(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def get_update_states(self):
"""
Returns an iterable over all known pairs of ``(entity ID, update state)``.
"""
@abstractmethod @abstractmethod
async def close(self): async def close(self):
""" """

View File

@ -77,6 +77,9 @@ class MemorySession(Session):
async def set_update_state(self, entity_id, state): async def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state self._update_states[entity_id] = state
async def get_update_states(self):
return self._update_states.items()
async def close(self): async def close(self):
pass pass

View File

@ -215,6 +215,14 @@ class SQLiteSession(MemorySession):
entity_id, state.pts, state.qts, entity_id, state.pts, state.qts,
state.date.timestamp(), state.seq) state.date.timestamp(), state.seq)
async def get_update_states(self):
c = self._cursor()
try:
rows = c.execute('select id, pts, qts, date, seq from update_state').fetchall()
return ((row[0], types.updates.State(*row[1:], unread_count=0)) for row in rows)
finally:
c.close()
async def save(self): async def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
# This is a no-op if there are no changes to commit, so there's # This is a no-op if there are no changes to commit, so there's