mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-26 03:13:45 +03:00
Persist updates.State upon disconnection
This commit is contained in:
parent
e2a0de1913
commit
2a00bcaa12
|
@ -67,6 +67,25 @@ class Session(ABC):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_update_state(self, entity_id):
|
||||||
|
"""
|
||||||
|
Returns the ``UpdateState`` associated with the given `entity_id`.
|
||||||
|
If the `entity_id` is 0, it should return the ``UpdateState`` for
|
||||||
|
no specific channel (the "general" state). If no state is known
|
||||||
|
it should ``return None``.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_update_state(self, entity_id, state):
|
||||||
|
"""
|
||||||
|
Sets the given ``UpdateState`` for the specified `entity_id`, which
|
||||||
|
should be 0 if the ``UpdateState`` is the "general" state (and not
|
||||||
|
for any specific channel).
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -35,6 +35,7 @@ class MemorySession(Session):
|
||||||
|
|
||||||
self._files = {}
|
self._files = {}
|
||||||
self._entities = set()
|
self._entities = set()
|
||||||
|
self._update_states = {}
|
||||||
|
|
||||||
def set_dc(self, dc_id, server_address, port):
|
def set_dc(self, dc_id, server_address, port):
|
||||||
self._dc_id = dc_id or 0
|
self._dc_id = dc_id or 0
|
||||||
|
@ -57,6 +58,12 @@ class MemorySession(Session):
|
||||||
def auth_key(self, value):
|
def auth_key(self, value):
|
||||||
self._auth_key = value
|
self._auth_key = value
|
||||||
|
|
||||||
|
def get_update_state(self, entity_id):
|
||||||
|
return self._update_states.get(entity_id, None)
|
||||||
|
|
||||||
|
def set_update_state(self, entity_id, state):
|
||||||
|
self._update_states[entity_id] = state
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,8 @@ from base64 import b64decode
|
||||||
from os.path import isfile as file_exists
|
from os.path import isfile as file_exists
|
||||||
from threading import Lock, RLock
|
from threading import Lock, RLock
|
||||||
|
|
||||||
|
from telethon.tl import types
|
||||||
|
|
||||||
from .memory import MemorySession, _SentFileType
|
from .memory import MemorySession, _SentFileType
|
||||||
from .. import utils
|
from .. import utils
|
||||||
from ..crypto import AuthKey
|
from ..crypto import AuthKey
|
||||||
|
@ -226,6 +228,22 @@ class SQLiteSession(MemorySession):
|
||||||
))
|
))
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
|
def get_update_state(self, entity_id):
|
||||||
|
c = self._cursor()
|
||||||
|
row = c.execute('select pts, qts, date, seq from update_state '
|
||||||
|
'where id = ?', (entity_id,)).fetchone()
|
||||||
|
c.close()
|
||||||
|
if row:
|
||||||
|
return types.updates.State(*row)
|
||||||
|
|
||||||
|
def set_update_state(self, entity_id, state):
|
||||||
|
with self._db_lock:
|
||||||
|
c = self._cursor()
|
||||||
|
c.execute('insert or replace into update_state values (?,?,?,?,?)',
|
||||||
|
(entity_id, state.pts, state.qts, state.date, state.seq))
|
||||||
|
c.close()
|
||||||
|
self.save()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Saves the current session object as session_user_id.session"""
|
"""Saves the current session object as session_user_id.session"""
|
||||||
with self._db_lock:
|
with self._db_lock:
|
||||||
|
|
|
@ -275,6 +275,7 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# TODO Shall we clear the _exported_sessions, or may be reused?
|
# TODO Shall we clear the _exported_sessions, or may be reused?
|
||||||
self._first_request = True # On reconnect it will be first again
|
self._first_request = True # On reconnect it will be first again
|
||||||
|
self.session.set_update_state(0, self.updates.get_update_state(0))
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
|
||||||
def _reconnect(self, new_dc=None):
|
def _reconnect(self, new_dc=None):
|
||||||
|
|
|
@ -110,6 +110,10 @@ class UpdateState:
|
||||||
# We don't want to crash a worker thread due to any reason
|
# We don't want to crash a worker thread due to any reason
|
||||||
__log__.exception('Unhandled exception on worker %d', wid)
|
__log__.exception('Unhandled exception on worker %d', wid)
|
||||||
|
|
||||||
|
def get_update_state(self, entity_id):
|
||||||
|
"""Gets the updates.State corresponding to the given entity or 0."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
def process(self, update):
|
def process(self, update):
|
||||||
"""Processes an update object. This method is normally called by
|
"""Processes an update object. This method is normally called by
|
||||||
the library itself.
|
the library itself.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user