mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-31 07:57:38 +03:00 
			
		
		
		
	Make sessions async
SQLiteSession is not updated, don't try to use it
This commit is contained in:
		
							parent
							
								
									43f629f665
								
							
						
					
					
						commit
						d2de0f3aca
					
				|  | @ -597,7 +597,7 @@ class AuthMethods: | |||
|         self._state_cache.reset() | ||||
| 
 | ||||
|         await self.disconnect() | ||||
|         self.session.delete() | ||||
|         await self.session.delete() | ||||
|         return True | ||||
| 
 | ||||
|     async def edit_2fa( | ||||
|  |  | |||
|  | @ -55,7 +55,7 @@ class _DirectDownloadIter(RequestIter): | |||
|                     if option.ip_address == self.client.session.server_address: | ||||
|                         self.client.session.set_dc( | ||||
|                             option.id, option.ip_address, option.port) | ||||
|                         self.client.session.save() | ||||
|                         await self.client.session.save() | ||||
|                         break | ||||
| 
 | ||||
|                 # TODO Figure out why the session may have the wrong DC ID | ||||
|  | @ -402,7 +402,7 @@ class DownloadMethods: | |||
|             if isinstance(message.action, | ||||
|                           types.MessageActionChatEditPhoto): | ||||
|                 media = media.photo | ||||
|         | ||||
| 
 | ||||
|         if isinstance(media, types.MessageMediaWebPage): | ||||
|             if isinstance(media.webpage, types.WebPage): | ||||
|                 media = media.webpage.document or media.webpage.photo | ||||
|  |  | |||
|  | @ -1019,7 +1019,7 @@ class MessageMethods: | |||
|     async def edit_message( | ||||
|             self: 'TelegramClient', | ||||
|             entity: 'typing.Union[hints.EntityLike, types.Message]', | ||||
|             message: 'hints.MessageLike' = None, | ||||
|             message: 'hints.MessageIDLike' = None, | ||||
|             text: str = None, | ||||
|             *, | ||||
|             parse_mode: str = (), | ||||
|  |  | |||
|  | @ -412,10 +412,7 @@ class TelegramBaseClient(abc.ABC): | |||
| 
 | ||||
|         self._authorized = None  # None = unknown, False = no, True = yes | ||||
| 
 | ||||
|         # Update state (for catching up after a disconnection) | ||||
|         # TODO Get state from channels too | ||||
|         self._state_cache = StateCache( | ||||
|             self.session.get_update_state(0), self._log) | ||||
|         self._state_cache = StateCache(None, self._log) | ||||
| 
 | ||||
|         # Some further state for subclasses | ||||
|         self._event_builders = [] | ||||
|  | @ -522,6 +519,11 @@ class TelegramBaseClient(abc.ABC): | |||
|                 except OSError: | ||||
|                     print('Failed to connect') | ||||
|         """ | ||||
|         # Update state (for catching up after a disconnection) | ||||
|         # TODO Get state from channels too | ||||
|         self._state_cache = StateCache( | ||||
|             await self.session.get_update_state(0), self._log) | ||||
| 
 | ||||
|         if not await self._sender.connect(self._connection( | ||||
|             self.session.server_address, | ||||
|             self.session.port, | ||||
|  | @ -534,7 +536,7 @@ class TelegramBaseClient(abc.ABC): | |||
|             return | ||||
| 
 | ||||
|         self.session.auth_key = self._sender.auth_key | ||||
|         self.session.save() | ||||
|         await self.session.save() | ||||
| 
 | ||||
|         self._init_request.query = functions.help.GetConfigRequest() | ||||
| 
 | ||||
|  | @ -644,7 +646,7 @@ class TelegramBaseClient(abc.ABC): | |||
| 
 | ||||
|         pts, date = self._state_cache[None] | ||||
|         if pts and date: | ||||
|             self.session.set_update_state(0, types.updates.State( | ||||
|             await self.session.set_update_state(0, types.updates.State( | ||||
|                 pts=pts, | ||||
|                 qts=0, | ||||
|                 date=date, | ||||
|  | @ -652,7 +654,7 @@ class TelegramBaseClient(abc.ABC): | |||
|                 unread_count=0 | ||||
|             )) | ||||
| 
 | ||||
|         self.session.close() | ||||
|         await self.session.close() | ||||
| 
 | ||||
|     async def _disconnect(self: 'TelegramClient'): | ||||
|         """ | ||||
|  | @ -677,17 +679,17 @@ class TelegramBaseClient(abc.ABC): | |||
|         # so it's not valid anymore. Set to None to force recreating it. | ||||
|         self._sender.auth_key.key = None | ||||
|         self.session.auth_key = None | ||||
|         self.session.save() | ||||
|         await self.session.save() | ||||
|         await self._disconnect() | ||||
|         return await self.connect() | ||||
| 
 | ||||
|     def _auth_key_callback(self: 'TelegramClient', auth_key): | ||||
|     async def _auth_key_callback(self: 'TelegramClient', auth_key): | ||||
|         """ | ||||
|         Callback from the sender whenever it needed to generate a | ||||
|         new authorization key. This means we are not authorized. | ||||
|         """ | ||||
|         self.session.auth_key = auth_key | ||||
|         self.session.save() | ||||
|         await self.session.save() | ||||
| 
 | ||||
|     # endregion | ||||
| 
 | ||||
|  | @ -812,7 +814,7 @@ class TelegramBaseClient(abc.ABC): | |||
|         if not session: | ||||
|             dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) | ||||
|             session = self.session.clone() | ||||
|             await session.set_dc(dc.id, dc.ip_address, dc.port) | ||||
|             session.set_dc(dc.id, dc.ip_address, dc.port) | ||||
|             self._exported_sessions[cdn_redirect.dc_id] = session | ||||
| 
 | ||||
|         self._log[__name__].info('Creating new CDN client') | ||||
|  |  | |||
|  | @ -255,7 +255,7 @@ class UpdateMethods: | |||
|                         state = d.intermediate_state | ||||
| 
 | ||||
|                     pts, date = state.pts, state.date | ||||
|                     self._handle_update(types.Updates( | ||||
|                     await self._handle_update(types.Updates( | ||||
|                         users=d.users, | ||||
|                         chats=d.chats, | ||||
|                         date=state.date, | ||||
|  | @ -300,8 +300,8 @@ class UpdateMethods: | |||
|     # It is important to not make _handle_update async because we rely on | ||||
|     # the order that the updates arrive in to update the pts and date to | ||||
|     # be always-increasing. There is also no need to make this async. | ||||
|     def _handle_update(self: 'TelegramClient', update): | ||||
|         self.session.process_entities(update) | ||||
|     async def _handle_update(self: 'TelegramClient', update): | ||||
|         await self.session.process_entities(update) | ||||
|         self._entity_cache.add(update) | ||||
| 
 | ||||
|         if isinstance(update, (types.Updates, types.UpdatesCombined)): | ||||
|  | @ -372,7 +372,7 @@ class UpdateMethods: | |||
|             # inserted because this is a rather expensive operation | ||||
|             # (default's sqlite3 takes ~0.1s to commit changes). Do | ||||
|             # it every minute instead. No-op if there's nothing new. | ||||
|             self.session.save() | ||||
|             await self.session.save() | ||||
| 
 | ||||
|             # We need to send some content-related request at least hourly | ||||
|             # for Telegram to keep delivering updates, otherwise they will | ||||
|  |  | |||
|  | @ -71,7 +71,7 @@ class UserMethods: | |||
|                             exceptions.append(e) | ||||
|                             results.append(None) | ||||
|                             continue | ||||
|                         self.session.process_entities(result) | ||||
|                         await self.session.process_entities(result) | ||||
|                         self._entity_cache.add(result) | ||||
|                         exceptions.append(None) | ||||
|                         results.append(result) | ||||
|  | @ -82,7 +82,7 @@ class UserMethods: | |||
|                         return results | ||||
|                 else: | ||||
|                     result = await future | ||||
|                     self.session.process_entities(result) | ||||
|                     await self.session.process_entities(result) | ||||
|                     self._entity_cache.add(result) | ||||
|                     return result | ||||
|             except (errors.ServerError, errors.RpcCallFailError, | ||||
|  | @ -427,7 +427,7 @@ class UserMethods: | |||
| 
 | ||||
|         # No InputPeer, cached peer, or known string. Fetch from disk cache | ||||
|         try: | ||||
|             return self.session.get_input_entity(peer) | ||||
|             return await self.session.get_input_entity(peer) | ||||
|         except ValueError: | ||||
|             pass | ||||
| 
 | ||||
|  | @ -567,7 +567,7 @@ class UserMethods: | |||
|             try: | ||||
|                 # Nobody with this username, maybe it's an exact name/title | ||||
|                 return await self.get_entity( | ||||
|                     self.session.get_input_entity(string)) | ||||
|                     await self.session.get_input_entity(string)) | ||||
|             except ValueError: | ||||
|                 pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -295,7 +295,7 @@ class MTProtoSender: | |||
|             # notify whenever we change it. This is crucial when we | ||||
|             # switch to different data centers. | ||||
|             if self._auth_key_callback: | ||||
|                 self._auth_key_callback(self.auth_key) | ||||
|                 await self._auth_key_callback(self.auth_key) | ||||
| 
 | ||||
|             self._log.debug('auth_key generation success!') | ||||
|             return True | ||||
|  | @ -380,7 +380,7 @@ class MTProtoSender: | |||
|                     self._log.info('Broken authorization key; resetting') | ||||
|                     self.auth_key.key = None | ||||
|                     if self._auth_key_callback: | ||||
|                         self._auth_key_callback(None) | ||||
|                         await self._auth_key_callback(None) | ||||
| 
 | ||||
|                     ok = False | ||||
|                     break | ||||
|  | @ -524,7 +524,7 @@ class MTProtoSender: | |||
|                     self._log.info('Broken authorization key; resetting') | ||||
|                     self.auth_key.key = None | ||||
|                     if self._auth_key_callback: | ||||
|                         self._auth_key_callback(None) | ||||
|                         await self._auth_key_callback(None) | ||||
| 
 | ||||
|                     await self._disconnect(error=e) | ||||
|                 else: | ||||
|  | @ -653,7 +653,7 @@ class MTProtoSender: | |||
| 
 | ||||
|         self._log.debug('Handling update %s', message.obj.__class__.__name__) | ||||
|         if self._update_callback: | ||||
|             self._update_callback(message.obj) | ||||
|             await self._update_callback(message.obj) | ||||
| 
 | ||||
|     async def _handle_pong(self, message): | ||||
|         """ | ||||
|  |  | |||
|  | @ -79,7 +79,7 @@ class Session(ABC): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def get_update_state(self, entity_id): | ||||
|     async def get_update_state(self, entity_id): | ||||
|         """ | ||||
|         Returns the ``UpdateState`` associated with the given `entity_id`. | ||||
|         If the `entity_id` is 0, it should return the ``UpdateState`` for | ||||
|  | @ -89,7 +89,7 @@ class Session(ABC): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def set_update_state(self, entity_id, state): | ||||
|     async def set_update_state(self, entity_id, state): | ||||
|         """ | ||||
|         Sets the given ``UpdateState`` for the specified `entity_id`, which | ||||
|         should be 0 if the ``UpdateState`` is the "general" state (and not | ||||
|  | @ -98,14 +98,14 @@ class Session(ABC): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def close(self): | ||||
|     async def close(self): | ||||
|         """ | ||||
|         Called on client disconnection. Should be used to | ||||
|         free any used resources. Can be left empty if none. | ||||
|         """ | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def save(self): | ||||
|     async def save(self): | ||||
|         """ | ||||
|         Called whenever important properties change. It should | ||||
|         make persist the relevant session information to disk. | ||||
|  | @ -113,22 +113,15 @@ class Session(ABC): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def delete(self): | ||||
|     async def delete(self): | ||||
|         """ | ||||
|         Called upon client.log_out(). Should delete the stored | ||||
|         information from disk since it's not valid anymore. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @classmethod | ||||
|     def list_sessions(cls): | ||||
|         """ | ||||
|         Lists available sessions. Not used by the library itself. | ||||
|         """ | ||||
|         return [] | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def process_entities(self, tlo): | ||||
|     async def process_entities(self, tlo): | ||||
|         """ | ||||
|         Processes the input ``TLObject`` or ``list`` and saves | ||||
|         whatever information is relevant (e.g., ID or access hash). | ||||
|  | @ -136,7 +129,7 @@ class Session(ABC): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def get_input_entity(self, key): | ||||
|     async def get_input_entity(self, key): | ||||
|         """ | ||||
|         Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). | ||||
|         The library uses this method whenever an ``InputPeer`` is needed | ||||
|  | @ -144,24 +137,3 @@ class Session(ABC): | |||
|         to use a cached username to avoid extra RPC). | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def cache_file(self, md5_digest, file_size, instance): | ||||
|         """ | ||||
|         Caches the given file information persistently, so that it | ||||
|         doesn't need to be re-uploaded in case the file is used again. | ||||
| 
 | ||||
|         The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``, | ||||
|         both with an ``.id`` and ``.access_hash`` attributes. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def get_file(self, md5_digest, file_size, cls): | ||||
|         """ | ||||
|         Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size`` | ||||
|         match an existing saved record. The class will either be an | ||||
|         ``InputPhoto`` or ``InputDocument``, both with two parameters | ||||
|         ``id`` and ``access_hash`` in that order. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  |  | |||
|  | @ -71,19 +71,19 @@ class MemorySession(Session): | |||
|     def takeout_id(self, value): | ||||
|         self._takeout_id = value | ||||
| 
 | ||||
|     def get_update_state(self, entity_id): | ||||
|     async def get_update_state(self, entity_id): | ||||
|         return self._update_states.get(entity_id, None) | ||||
| 
 | ||||
|     def set_update_state(self, entity_id, state): | ||||
|     async def set_update_state(self, entity_id, state): | ||||
|         self._update_states[entity_id] = state | ||||
| 
 | ||||
|     def close(self): | ||||
|     async def close(self): | ||||
|         pass | ||||
| 
 | ||||
|     def save(self): | ||||
|     async def save(self): | ||||
|         pass | ||||
| 
 | ||||
|     def delete(self): | ||||
|     async def delete(self): | ||||
|         pass | ||||
| 
 | ||||
|     @staticmethod | ||||
|  | @ -144,31 +144,31 @@ class MemorySession(Session): | |||
|                 rows.append(row) | ||||
|         return rows | ||||
| 
 | ||||
|     def process_entities(self, tlo): | ||||
|     async def process_entities(self, tlo): | ||||
|         self._entities |= set(self._entities_to_rows(tlo)) | ||||
| 
 | ||||
|     def get_entity_rows_by_phone(self, phone): | ||||
|     async def get_entity_rows_by_phone(self, phone): | ||||
|         try: | ||||
|             return next((id, hash) for id, hash, _, found_phone, _ | ||||
|                         in self._entities if found_phone == phone) | ||||
|         except StopIteration: | ||||
|             pass | ||||
| 
 | ||||
|     def get_entity_rows_by_username(self, username): | ||||
|     async def get_entity_rows_by_username(self, username): | ||||
|         try: | ||||
|             return next((id, hash) for id, hash, found_username, _, _ | ||||
|                         in self._entities if found_username == username) | ||||
|         except StopIteration: | ||||
|             pass | ||||
| 
 | ||||
|     def get_entity_rows_by_name(self, name): | ||||
|     async def get_entity_rows_by_name(self, name): | ||||
|         try: | ||||
|             return next((id, hash) for id, hash, _, _, found_name | ||||
|                         in self._entities if found_name == name) | ||||
|         except StopIteration: | ||||
|             pass | ||||
| 
 | ||||
|     def get_entity_rows_by_id(self, id, exact=True): | ||||
|     async def get_entity_rows_by_id(self, id, exact=True): | ||||
|         try: | ||||
|             if exact: | ||||
|                 return next((id, hash) for found_id, hash, _, _, _ | ||||
|  | @ -184,7 +184,7 @@ class MemorySession(Session): | |||
|         except StopIteration: | ||||
|             pass | ||||
| 
 | ||||
|     def get_input_entity(self, key): | ||||
|     async def get_input_entity(self, key): | ||||
|         try: | ||||
|             if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): | ||||
|                 # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) | ||||
|  | @ -204,21 +204,21 @@ class MemorySession(Session): | |||
|         if isinstance(key, str): | ||||
|             phone = utils.parse_phone(key) | ||||
|             if phone: | ||||
|                 result = self.get_entity_rows_by_phone(phone) | ||||
|                 result = await self.get_entity_rows_by_phone(phone) | ||||
|             else: | ||||
|                 username, invite = utils.parse_username(key) | ||||
|                 if username and not invite: | ||||
|                     result = self.get_entity_rows_by_username(username) | ||||
|                     result = await self.get_entity_rows_by_username(username) | ||||
|                 else: | ||||
|                     tup = utils.resolve_invite_link(key)[1] | ||||
|                     if tup: | ||||
|                         result = self.get_entity_rows_by_id(tup, exact=False) | ||||
|                         result = await self.get_entity_rows_by_id(tup, exact=False) | ||||
| 
 | ||||
|         elif isinstance(key, int): | ||||
|             result = self.get_entity_rows_by_id(key, exact) | ||||
|             result = await self.get_entity_rows_by_id(key, exact) | ||||
| 
 | ||||
|         if not result and isinstance(key, str): | ||||
|             result = self.get_entity_rows_by_name(key) | ||||
|             result = await self.get_entity_rows_by_name(key) | ||||
| 
 | ||||
|         if result: | ||||
|             entity_id, entity_hash = result  # unpack resulting tuple | ||||
|  | @ -233,14 +233,14 @@ class MemorySession(Session): | |||
|         else: | ||||
|             raise ValueError('Could not find input entity with key ', key) | ||||
| 
 | ||||
|     def cache_file(self, md5_digest, file_size, instance): | ||||
|     async def cache_file(self, md5_digest, file_size, instance): | ||||
|         if not isinstance(instance, (InputDocument, InputPhoto)): | ||||
|             raise TypeError('Cannot cache %s instance' % type(instance)) | ||||
|         key = (md5_digest, file_size, _SentFileType.from_type(type(instance))) | ||||
|         value = (instance.id, instance.access_hash) | ||||
|         self._files[key] = value | ||||
| 
 | ||||
|     def get_file(self, md5_digest, file_size, cls): | ||||
|     async def get_file(self, md5_digest, file_size, cls): | ||||
|         key = (md5_digest, file_size, _SentFileType.from_type(cls)) | ||||
|         try: | ||||
|             return cls(*self._files[key]) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user