mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-22 01:16:35 +03:00
Attempt to load and save MessageBox state
This commit is contained in:
parent
053a0052c8
commit
a5c3df2743
|
@ -1,2 +1,3 @@
|
|||
from .entitycache import EntityCache
|
||||
from .messagebox import MessageBox, GapError
|
||||
from .session import SessionState, ChannelState, Entity, EntityType
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional, Tuple
|
||||
from enum import IntEnum
|
||||
from ..tl.types import InputPeerUser, InputPeerChat, InputPeerChannel
|
||||
|
||||
|
||||
class SessionState:
|
||||
|
@ -175,3 +176,11 @@ class Entity:
|
|||
|
||||
def __bytes__(self):
|
||||
return struct.pack('<Bqq', self.ty, self.id, self.hash)
|
||||
|
||||
def _as_input_peer(self):
|
||||
if self.is_user:
|
||||
return InputPeerUser(self.id, self.hash)
|
||||
elif self.ty == EntityType.GROUP:
|
||||
return InputPeerChat(self.id)
|
||||
else:
|
||||
return InputPeerChannel(self.id, self.hash)
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
|
|||
from ..sessions import Session, SQLiteSession, MemorySession
|
||||
from ..tl import functions, types
|
||||
from ..tl.alltlobjects import LAYER
|
||||
from .._updates import MessageBox, EntityCache as MbEntityCache
|
||||
from .._updates import MessageBox, EntityCache as MbEntityCache, SessionState, ChannelState, Entity, EntityType
|
||||
|
||||
DEFAULT_DC_ID = 2
|
||||
DEFAULT_IPV4_IP = '149.154.167.51'
|
||||
|
@ -535,6 +535,23 @@ class TelegramBaseClient(abc.ABC):
|
|||
self.session.auth_key = self._sender.auth_key
|
||||
await self.session.save()
|
||||
|
||||
if self._catch_up:
|
||||
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
|
||||
cs = []
|
||||
|
||||
for entity_id, state in await self.session.get_update_states():
|
||||
if entity_id == 0:
|
||||
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
|
||||
ss = SessionState(0, 0, False, state.pts, state.qts, state.date, state.seq, None)
|
||||
else:
|
||||
cs.append(ChannelState(entity_id, state.pts))
|
||||
|
||||
self._message_box.load(ss, cs)
|
||||
for state in cs:
|
||||
entity = await self.session.get_input_entity(state.channel_id)
|
||||
if entity:
|
||||
self._mb_entity_cache.put(Entity(EntityType.CHANNEL, entity.channel_id, entity.access_hash))
|
||||
|
||||
self._init_request.query = functions.help.GetConfigRequest()
|
||||
|
||||
await self._sender.send(functions.InvokeWithLayerRequest(
|
||||
|
@ -642,6 +659,17 @@ class TelegramBaseClient(abc.ABC):
|
|||
await asyncio.wait(self._event_handler_tasks)
|
||||
self._event_handler_tasks.clear()
|
||||
|
||||
entities = self._entity_cache.get_all_entities()
|
||||
|
||||
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
|
||||
# It doesn't matter if we put users in the list of chats.
|
||||
await self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []))
|
||||
|
||||
session_state, channel_states = self._message_box.session_state()
|
||||
await self.session.set_update_state(0, types.updates.State(ss.pts, ss.qts, ss.date, ss.seq, unread_count=0))
|
||||
for channel_id, pts in channel_states.items():
|
||||
await self.session.set_update_state(channel_id, types.updates.State(pts, 0, None, 0, unread_count=0))
|
||||
|
||||
await self.session.close()
|
||||
|
||||
async def _disconnect(self: 'TelegramClient'):
|
||||
|
|
|
@ -266,7 +266,7 @@ class UpdateMethods:
|
|||
self._log[__name__].info('Getting difference for account updates')
|
||||
diff = await self(get_diff)
|
||||
updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache)
|
||||
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
|
||||
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
|
||||
continue
|
||||
|
||||
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
|
||||
|
@ -274,7 +274,7 @@ class UpdateMethods:
|
|||
self._log[__name__].info('Getting difference for channel updates')
|
||||
diff = await self(get_diff)
|
||||
updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache)
|
||||
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
|
||||
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
|
||||
continue
|
||||
|
||||
deadline = self._message_box.check_deadlines()
|
||||
|
@ -293,14 +293,11 @@ class UpdateMethods:
|
|||
except GapError:
|
||||
continue # get(_channel)_difference will start returning requests
|
||||
|
||||
updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats))
|
||||
updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats))
|
||||
except Exception:
|
||||
self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)')
|
||||
|
||||
async def _preprocess_updates(self, updates, users, chats):
|
||||
await self.session.process_entities(update)
|
||||
self._entity_cache.add(update)
|
||||
|
||||
def _preprocess_updates(self, updates, users, chats):
|
||||
self._mb_entity_cache.extend(users, chats)
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(users, chats)}
|
||||
|
|
|
@ -97,6 +97,12 @@ class Session(ABC):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_update_states(self):
|
||||
"""
|
||||
Returns an iterable over all known pairs of ``(entity ID, update state)``.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""
|
||||
|
|
|
@ -77,6 +77,9 @@ class MemorySession(Session):
|
|||
async def set_update_state(self, entity_id, state):
|
||||
self._update_states[entity_id] = state
|
||||
|
||||
async def get_update_states(self):
|
||||
return self._update_states.items()
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
|
|
@ -215,6 +215,14 @@ class SQLiteSession(MemorySession):
|
|||
entity_id, state.pts, state.qts,
|
||||
state.date.timestamp(), state.seq)
|
||||
|
||||
async def get_update_states(self):
|
||||
c = self._cursor()
|
||||
try:
|
||||
rows = c.execute('select id, pts, qts, date, seq from update_state').fetchall()
|
||||
return ((row[0], types.updates.State(*row[1:], unread_count=0)) for row in rows)
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def save(self):
|
||||
"""Saves the current session object as session_user_id.session"""
|
||||
# This is a no-op if there are no changes to commit, so there's
|
||||
|
|
Loading…
Reference in New Issue
Block a user