mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-31 07:57:38 +03:00 
			
		
		
		
	Persist updates.State upon disconnection
This commit is contained in:
		
							parent
							
								
									e2a0de1913
								
							
						
					
					
						commit
						2a00bcaa12
					
				|  | @ -67,6 +67,25 @@ class Session(ABC): | |||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def get_update_state(self, entity_id): | ||||
|         """ | ||||
|         Returns the ``UpdateState`` associated with the given `entity_id`. | ||||
|         If the `entity_id` is 0, it should return the ``UpdateState`` for | ||||
|         no specific channel (the "general" state). If no state is known | ||||
|         it should ``return None``. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def set_update_state(self, entity_id, state): | ||||
|         """ | ||||
|         Sets the given ``UpdateState`` for the specified `entity_id`, which | ||||
|         should be 0 if the ``UpdateState`` is the "general" state (and not | ||||
|         for any specific channel). | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def close(self): | ||||
|         """ | ||||
|  |  | |||
|  | @ -35,6 +35,7 @@ class MemorySession(Session): | |||
| 
 | ||||
|         self._files = {} | ||||
|         self._entities = set() | ||||
|         self._update_states = {} | ||||
| 
 | ||||
|     def set_dc(self, dc_id, server_address, port): | ||||
|         self._dc_id = dc_id or 0 | ||||
|  | @ -57,6 +58,12 @@ class MemorySession(Session): | |||
|     def auth_key(self, value): | ||||
|         self._auth_key = value | ||||
| 
 | ||||
|     def get_update_state(self, entity_id): | ||||
|         return self._update_states.get(entity_id, None) | ||||
| 
 | ||||
|     def set_update_state(self, entity_id, state): | ||||
|         self._update_states[entity_id] = state | ||||
| 
 | ||||
|     def close(self): | ||||
|         pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,6 +5,8 @@ from base64 import b64decode | |||
| from os.path import isfile as file_exists | ||||
| from threading import Lock, RLock | ||||
| 
 | ||||
| from telethon.tl import types | ||||
| 
 | ||||
| from .memory import MemorySession, _SentFileType | ||||
| from .. import utils | ||||
| from ..crypto import AuthKey | ||||
|  | @ -226,6 +228,22 @@ class SQLiteSession(MemorySession): | |||
|             )) | ||||
|             c.close() | ||||
| 
 | ||||
|     def get_update_state(self, entity_id): | ||||
|         c = self._cursor() | ||||
|         row = c.execute('select pts, qts, date, seq from update_state ' | ||||
|                         'where id = ?', (entity_id,)).fetchone() | ||||
|         c.close() | ||||
|         if row: | ||||
|             return types.updates.State(*row) | ||||
| 
 | ||||
|     def set_update_state(self, entity_id, state): | ||||
|         with self._db_lock: | ||||
|             c = self._cursor() | ||||
|             c.execute('insert or replace into update_state values (?,?,?,?,?)', | ||||
|                       (entity_id, state.pts, state.qts, state.date, state.seq)) | ||||
|             c.close() | ||||
|             self.save() | ||||
| 
 | ||||
|     def save(self): | ||||
|         """Saves the current session object as session_user_id.session""" | ||||
|         with self._db_lock: | ||||
|  |  | |||
|  | @ -275,6 +275,7 @@ class TelegramBareClient: | |||
| 
 | ||||
|         # TODO Shall we clear the _exported_sessions, or may be reused? | ||||
|         self._first_request = True  # On reconnect it will be first again | ||||
|         self.session.set_update_state(0, self.updates.get_update_state(0)) | ||||
|         self.session.close() | ||||
| 
 | ||||
|     def _reconnect(self, new_dc=None): | ||||
|  |  | |||
|  | @ -110,6 +110,10 @@ class UpdateState: | |||
|                 # We don't want to crash a worker thread due to any reason | ||||
|                 __log__.exception('Unhandled exception on worker %d', wid) | ||||
| 
 | ||||
|     def get_update_state(self, entity_id): | ||||
|         """Gets the updates.State corresponding to the given entity or 0.""" | ||||
|         return self._state | ||||
| 
 | ||||
|     def process(self, update): | ||||
|         """Processes an update object. This method is normally called by | ||||
|            the library itself. | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user