Persist updates.State upon disconnection

This commit is contained in:
Lonami Exo 2018-04-25 13:37:29 +02:00
parent e2a0de1913
commit 2a00bcaa12
5 changed files with 49 additions and 0 deletions

View File

@ -67,6 +67,25 @@ class Session(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_update_state(self, entity_id):
"""
Returns the ``UpdateState`` associated with the given `entity_id`.
If the `entity_id` is 0, it should return the ``UpdateState`` for
no specific channel (the "general" state). If no state is known
it should ``return None``.
"""
raise NotImplementedError
@abstractmethod
def set_update_state(self, entity_id, state):
"""
Sets the given ``UpdateState`` for the specified `entity_id`, which
should be 0 if the ``UpdateState`` is the "general" state (and not
for any specific channel).
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def close(self): def close(self):
""" """

View File

@ -35,6 +35,7 @@ class MemorySession(Session):
self._files = {} self._files = {}
self._entities = set() self._entities = set()
self._update_states = {}
def set_dc(self, dc_id, server_address, port): def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id or 0 self._dc_id = dc_id or 0
@ -57,6 +58,12 @@ class MemorySession(Session):
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)
def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state
def close(self): def close(self):
pass pass

View File

@ -5,6 +5,8 @@ from base64 import b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock, RLock from threading import Lock, RLock
from telethon.tl import types
from .memory import MemorySession, _SentFileType from .memory import MemorySession, _SentFileType
from .. import utils from .. import utils
from ..crypto import AuthKey from ..crypto import AuthKey
@ -226,6 +228,22 @@ class SQLiteSession(MemorySession):
)) ))
c.close() c.close()
def get_update_state(self, entity_id):
c = self._cursor()
row = c.execute('select pts, qts, date, seq from update_state '
'where id = ?', (entity_id,)).fetchone()
c.close()
if row:
return types.updates.State(*row)
def set_update_state(self, entity_id, state):
with self._db_lock:
c = self._cursor()
c.execute('insert or replace into update_state values (?,?,?,?,?)',
(entity_id, state.pts, state.qts, state.date, state.seq))
c.close()
self.save()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
with self._db_lock: with self._db_lock:

View File

@ -275,6 +275,7 @@ class TelegramBareClient:
# TODO Shall we clear the _exported_sessions, or may be reused? # TODO Shall we clear the _exported_sessions, or may be reused?
self._first_request = True # On reconnect it will be first again self._first_request = True # On reconnect it will be first again
self.session.set_update_state(0, self.updates.get_update_state(0))
self.session.close() self.session.close()
def _reconnect(self, new_dc=None): def _reconnect(self, new_dc=None):

View File

@ -110,6 +110,10 @@ class UpdateState:
# We don't want to crash a worker thread due to any reason # We don't want to crash a worker thread due to any reason
__log__.exception('Unhandled exception on worker %d', wid) __log__.exception('Unhandled exception on worker %d', wid)
def get_update_state(self, entity_id):
"""Gets the updates.State corresponding to the given entity or 0."""
return self._state
def process(self, update): def process(self, update):
"""Processes an update object. This method is normally called by """Processes an update object. This method is normally called by
the library itself. the library itself.