Attempt to load and save MessageBox state

This commit is contained in:
Lonami Exo 2022-05-13 17:40:03 +02:00
parent 053a0052c8
commit a5c3df2743
7 changed files with 60 additions and 8 deletions

View File

@ -1,2 +1,3 @@
from .entitycache import EntityCache
from .messagebox import MessageBox, GapError
from .session import SessionState, ChannelState, Entity, EntityType

View File

@ -1,5 +1,6 @@
from typing import Optional, Tuple
from enum import IntEnum
from ..tl.types import InputPeerUser, InputPeerChat, InputPeerChannel
class SessionState:
@ -175,3 +176,11 @@ class Entity:
def __bytes__(self):
return struct.pack('<Bqq', self.ty, self.id, self.hash)
def _as_input_peer(self):
if self.is_user:
return InputPeerUser(self.id, self.hash)
elif self.ty == EntityType.GROUP:
return InputPeerChat(self.id)
else:
return InputPeerChannel(self.id, self.hash)

View File

@ -15,7 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import functions, types
from ..tl.alltlobjects import LAYER
from .._updates import MessageBox, EntityCache as MbEntityCache
from .._updates import MessageBox, EntityCache as MbEntityCache, SessionState, ChannelState, Entity, EntityType
DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51'
@ -535,6 +535,23 @@ class TelegramBaseClient(abc.ABC):
self.session.auth_key = self._sender.auth_key
await self.session.save()
if self._catch_up:
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = []
for entity_id, state in await self.session.get_update_states():
if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, state.date, state.seq, None)
else:
cs.append(ChannelState(entity_id, state.pts))
self._message_box.load(ss, cs)
for state in cs:
entity = await self.session.get_input_entity(state.channel_id)
if entity:
self._mb_entity_cache.put(Entity(EntityType.CHANNEL, entity.channel_id, entity.access_hash))
self._init_request.query = functions.help.GetConfigRequest()
await self._sender.send(functions.InvokeWithLayerRequest(
@ -642,6 +659,17 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear()
entities = self._entity_cache.get_all_entities()
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats.
await self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []))
session_state, channel_states = self._message_box.session_state()
await self.session.set_update_state(0, types.updates.State(ss.pts, ss.qts, ss.date, ss.seq, unread_count=0))
for channel_id, pts in channel_states.items():
await self.session.set_update_state(channel_id, types.updates.State(pts, 0, None, 0, unread_count=0))
await self.session.close()
async def _disconnect(self: 'TelegramClient'):

View File

@ -266,7 +266,7 @@ class UpdateMethods:
self._log[__name__].info('Getting difference for account updates')
diff = await self(get_diff)
updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
continue
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
@ -274,7 +274,7 @@ class UpdateMethods:
self._log[__name__].info('Getting difference for channel updates')
diff = await self(get_diff)
updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
continue
deadline = self._message_box.check_deadlines()
@ -293,14 +293,11 @@ class UpdateMethods:
except GapError:
continue # get(_channel)_difference will start returning requests
updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats))
updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats))
except Exception:
self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)')
async def _preprocess_updates(self, updates, users, chats):
await self.session.process_entities(update)
self._entity_cache.add(update)
def _preprocess_updates(self, updates, users, chats):
self._mb_entity_cache.extend(users, chats)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)}

View File

@ -97,6 +97,12 @@ class Session(ABC):
"""
raise NotImplementedError
@abstractmethod
async def get_update_states(self):
"""
Returns an iterable over all known pairs of ``(entity ID, update state)``.
"""
@abstractmethod
async def close(self):
"""

View File

@ -77,6 +77,9 @@ class MemorySession(Session):
async def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state
async def get_update_states(self):
return self._update_states.items()
async def close(self):
pass

View File

@ -215,6 +215,14 @@ class SQLiteSession(MemorySession):
entity_id, state.pts, state.qts,
state.date.timestamp(), state.seq)
async def get_update_states(self):
c = self._cursor()
try:
rows = c.execute('select id, pts, qts, date, seq from update_state').fetchall()
return ((row[0], types.updates.State(*row[1:], unread_count=0)) for row in rows)
finally:
c.close()
async def save(self):
"""Saves the current session object as session_user_id.session"""
# This is a no-op if there are no changes to commit, so there's