From 3dd8b7c6d159cceb043242999e632017449c4b25 Mon Sep 17 00:00:00 2001 From: josephbiko Date: Fri, 5 Oct 2018 20:25:49 +0200 Subject: [PATCH] Support async def in sessions (#1013) --- .gitignore | 0 telethon/client/auth.py | 2 +- telethon/client/downloads.py | 4 ++-- telethon/client/messages.py | 2 +- telethon/client/telegrambaseclient.py | 19 ++++++++++--------- telethon/client/updates.py | 10 +++++----- telethon/client/uploads.py | 10 +++++----- telethon/client/users.py | 8 ++++---- telethon/network/mtprotosender.py | 3 ++- telethon/utils.py | 19 +++++++++++++++++++ 10 files changed, 49 insertions(+), 28 deletions(-) mode change 100755 => 100644 .gitignore diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 diff --git a/telethon/client/auth.py b/telethon/client/auth.py index de380de3..7c1ed6b6 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -399,7 +399,7 @@ class AuthMethods(MessageParseMethods, UserMethods): return False await self.disconnect() - self.session.delete() + await self.session.delete() self._authorized = False return True diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 4c0e1239..f8ab4ffc 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -223,9 +223,9 @@ class DownloadMethods(UserMethods): config = await self(functions.help.GetConfigRequest()) for option in config.dc_options: if option.ip_address == self.session.server_address: - self.session.set_dc( + await self.session.set_dc( option.id, option.ip_address, option.port) - self.session.save() + await self.session.save() break # TODO Figure out why the session may have the wrong DC ID diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 88aeeaa7..d9cb3d92 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -662,7 +662,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): reply_markup=self.build_reply_markup(buttons) ) msg = self._get_response_message(request, await self(request), entity) - self._cache_media(msg, file, file_handle) + await self._cache_media(msg, file, file_handle) return msg async def delete_messages(self, entity, message_ids, *, revoke=True): diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 8c99ef00..98e14592 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -14,6 +14,7 @@ from ..network import MTProtoSender, ConnectionTcpFull from ..sessions import Session, SQLiteSession, MemorySession from ..tl import TLObject, functions, types from ..tl.alltlobjects import LAYER +from ..utils import AsyncClassWrapper DEFAULT_DC_ID = 4 DEFAULT_IPV4_IP = '149.154.167.51' @@ -196,7 +197,7 @@ class TelegramBaseClient(abc.ABC): ) self.flood_sleep_threshold = flood_sleep_threshold - self.session = session + self.session = AsyncClassWrapper(session) self.api_id = int(api_id) self.api_hash = api_hash @@ -327,8 +328,8 @@ class TelegramBaseClient(abc.ABC): await self._disconnect() if getattr(self, 'session', None): if getattr(self, '_state', None): - self.session.set_update_state(0, self._state) - self.session.close() + await self.session.set_update_state(0, self._state) + await self.session.close() async def _disconnect(self): """ @@ -366,22 +367,22 @@ class TelegramBaseClient(abc.ABC): __log__.info('Reconnecting to new data center %s', new_dc) dc = await self._get_dc(new_dc) - self.session.set_dc(dc.id, dc.ip_address, dc.port) + await self.session.set_dc(dc.id, dc.ip_address, dc.port) # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self.session.auth_key = None - self.session.save() + await self.session.save() await self._disconnect() return await self.connect() - def _auth_key_callback(self, auth_key): + async def _auth_key_callback(self, auth_key): """ Callback from the sender whenever it needed to generate a new authorization key. This means we are not authorized. """ self._authorized = None self.session.auth_key = auth_key - self.session.save() + await self.session.save() # endregion @@ -469,8 +470,8 @@ class TelegramBaseClient(abc.ABC): session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) - session = self.session.clone() - session.set_dc(dc.id, dc.ip_address, dc.port) + session = await self.session.clone() + await session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session __log__.info('Creating new CDN client') diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 1e40a715..d6400d7d 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -137,7 +137,7 @@ class UpdateMethods(UserMethods): This can also be used to forcibly fetch new updates if there are any. """ - state = self.session.get_update_state(0) + state = await self.session.get_update_state(0) if not state or not state.pts: state = await self(functions.updates.GetStateRequest()) @@ -172,15 +172,15 @@ class UpdateMethods(UserMethods): state.pts = d.pts break finally: - self.session.set_update_state(0, state) + await self.session.set_update_state(0, state) self.session.catching_up = False # endregion # region Private methods - def _handle_update(self, update): - self.session.process_entities(update) + async def _handle_update(self, update): + await self.session.process_entities(update) if isinstance(update, (types.Updates, types.UpdatesCombined)): entities = {utils.get_peer_id(x): x for x in itertools.chain(update.users, update.chats)} @@ -236,7 +236,7 @@ class UpdateMethods(UserMethods): # inserted because this is a rather expensive operation # (default's sqlite3 takes ~0.1s to commit changes). Do # it every minute instead. No-op if there's nothing new. - self.session.save() + await self.session.save() # We need to send some content-related request at least hourly # for Telegram to keep delivering updates, otherwise they will diff --git a/telethon/client/uploads.py b/telethon/client/uploads.py index c1075b07..42707ff8 100644 --- a/telethon/client/uploads.py +++ b/telethon/client/uploads.py @@ -174,7 +174,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods): entities=msg_entities, reply_markup=markup, silent=silent ) msg = self._get_response_message(request, await self(request), entity) - self._cache_media(msg, file, file_handle, force_document=force_document) + await self._cache_media(msg, file, file_handle, force_document=force_document) return msg @@ -211,7 +211,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods): entity, media=types.InputMediaUploadedPhoto(fh) )) input_photo = utils.get_input_photo(r.photo) - self.session.cache_file(fh.md5, fh.size, input_photo) + await self.session.cache_file(fh.md5, fh.size, input_photo) fh = input_photo if captions: @@ -326,7 +326,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods): file = stream.read() hash_md5.update(file) if use_cache: - cached = self.session.get_file( + cached = await self.session.get_file( hash_md5.digest(), file_size, cls=use_cache ) if cached: @@ -446,7 +446,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods): ) return file_handle, media - def _cache_media(self, msg, file, file_handle, + async def _cache_media(self, msg, file, file_handle, force_document=False): if file and msg and isinstance(file_handle, custom.InputSizedFile): @@ -457,6 +457,6 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods): to_cache = utils.get_input_photo(msg.media.photo) else: to_cache = utils.get_input_document(msg.media.document) - self.session.cache_file(md5, size, to_cache) + await self.session.cache_file(md5, size, to_cache) # endregion diff --git a/telethon/client/users.py b/telethon/client/users.py index fa0c3015..b59f7ae7 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -48,7 +48,7 @@ class UserMethods(TelegramBaseClient): exceptions.append(e) results.append(None) continue - self.session.process_entities(result) + await self.session.process_entities(result) exceptions.append(None) results.append(result) request_index += 1 @@ -58,7 +58,7 @@ class UserMethods(TelegramBaseClient): return results else: result = await future - self.session.process_entities(result) + await self.session.process_entities(result) return result except (errors.ServerError, errors.RpcCallFailError) as e: __log__.warning('Telegram is having internal issues %s: %s', @@ -288,7 +288,7 @@ class UserMethods(TelegramBaseClient): try: # First try to get the entity from cache, otherwise figure it out - return self.session.get_input_entity(peer) + return await self.session.get_input_entity(peer) except ValueError: pass @@ -393,7 +393,7 @@ class UserMethods(TelegramBaseClient): try: # Nobody with this username, maybe it's an exact name/title return await self.get_entity( - self.session.get_input_entity(string)) + await self.session.get_input_entity(string)) except ValueError: pass diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 407bc24e..eb868fcd 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -21,6 +21,7 @@ from ..tl.types import ( MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq, MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload ) +from ..utils import AsyncClassWrapper __log__ = logging.getLogger(__name__) @@ -213,7 +214,7 @@ class MTProtoSender: await authenticator.do_authentication(plain) if self._auth_key_callback: - self._auth_key_callback(state.auth_key) + await self._auth_key_callback(state.auth_key) break except (SecurityError, AssertionError) as e: diff --git a/telethon/utils.py b/telethon/utils.py index de0d9900..d6e4b37d 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -16,6 +16,7 @@ from types import GeneratorType from .extensions import markdown, html from .helpers import add_surrogate, del_surrogate from .tl import types +import inspect try: import hachoir @@ -994,3 +995,21 @@ def get_appropriated_part_size(file_size): return 512 raise ValueError('File size too large') + + +class AsyncClassWrapper: + def __init__(self, wrapped): + self.wrapped = wrapped + + def __getattr__(self, item): + w = getattr(self.wrapped, item) + async def wrapper(*args, **kwargs): + val = w(*args, **kwargs) + return await val if inspect.isawaitable(val) else val + + if callable(w): + return wrapper + elif isinstance(w, property): + return w.fget(self.wrapped) + else: + return w