Support async def in sessions (#1013)

This commit is contained in:
josephbiko 2018-10-05 20:25:49 +02:00 committed by Lonami
parent 653f3c043d
commit 3dd8b7c6d1
10 changed files with 49 additions and 28 deletions

0
.gitignore vendored Executable file → Normal file
View File

View File

@ -399,7 +399,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
return False return False
await self.disconnect() await self.disconnect()
self.session.delete() await self.session.delete()
self._authorized = False self._authorized = False
return True return True

View File

@ -223,9 +223,9 @@ class DownloadMethods(UserMethods):
config = await self(functions.help.GetConfigRequest()) config = await self(functions.help.GetConfigRequest())
for option in config.dc_options: for option in config.dc_options:
if option.ip_address == self.session.server_address: if option.ip_address == self.session.server_address:
self.session.set_dc( await self.session.set_dc(
option.id, option.ip_address, option.port) option.id, option.ip_address, option.port)
self.session.save() await self.session.save()
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID

View File

@ -662,7 +662,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
reply_markup=self.build_reply_markup(buttons) reply_markup=self.build_reply_markup(buttons)
) )
msg = self._get_response_message(request, await self(request), entity) msg = self._get_response_message(request, await self(request), entity)
self._cache_media(msg, file, file_handle) await self._cache_media(msg, file, file_handle)
return msg return msg
async def delete_messages(self, entity, message_ids, *, revoke=True): async def delete_messages(self, entity, message_ids, *, revoke=True):

View File

@ -14,6 +14,7 @@ from ..network import MTProtoSender, ConnectionTcpFull
from ..sessions import Session, SQLiteSession, MemorySession from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import TLObject, functions, types from ..tl import TLObject, functions, types
from ..tl.alltlobjects import LAYER from ..tl.alltlobjects import LAYER
from ..utils import AsyncClassWrapper
DEFAULT_DC_ID = 4 DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -196,7 +197,7 @@ class TelegramBaseClient(abc.ABC):
) )
self.flood_sleep_threshold = flood_sleep_threshold self.flood_sleep_threshold = flood_sleep_threshold
self.session = session self.session = AsyncClassWrapper(session)
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -327,8 +328,8 @@ class TelegramBaseClient(abc.ABC):
await self._disconnect() await self._disconnect()
if getattr(self, 'session', None): if getattr(self, 'session', None):
if getattr(self, '_state', None): if getattr(self, '_state', None):
self.session.set_update_state(0, self._state) await self.session.set_update_state(0, self._state)
self.session.close() await self.session.close()
async def _disconnect(self): async def _disconnect(self):
""" """
@ -366,22 +367,22 @@ class TelegramBaseClient(abc.ABC):
__log__.info('Reconnecting to new data center %s', new_dc) __log__.info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc) dc = await self._get_dc(new_dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) await self.session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self.session.auth_key = None self.session.auth_key = None
self.session.save() await self.session.save()
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
def _auth_key_callback(self, auth_key): async def _auth_key_callback(self, auth_key):
""" """
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self._authorized = None self._authorized = None
self.session.auth_key = auth_key self.session.auth_key = auth_key
self.session.save() await self.session.save()
# endregion # endregion
@ -469,8 +470,8 @@ class TelegramBaseClient(abc.ABC):
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = await self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port) await session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
__log__.info('Creating new CDN client') __log__.info('Creating new CDN client')

View File

@ -137,7 +137,7 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any. This can also be used to forcibly fetch new updates if there are any.
""" """
state = self.session.get_update_state(0) state = await self.session.get_update_state(0)
if not state or not state.pts: if not state or not state.pts:
state = await self(functions.updates.GetStateRequest()) state = await self(functions.updates.GetStateRequest())
@ -172,15 +172,15 @@ class UpdateMethods(UserMethods):
state.pts = d.pts state.pts = d.pts
break break
finally: finally:
self.session.set_update_state(0, state) await self.session.set_update_state(0, state)
self.session.catching_up = False self.session.catching_up = False
# endregion # endregion
# region Private methods # region Private methods
def _handle_update(self, update): async def _handle_update(self, update):
self.session.process_entities(update) await self.session.process_entities(update)
if isinstance(update, (types.Updates, types.UpdatesCombined)): if isinstance(update, (types.Updates, types.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}
@ -236,7 +236,7 @@ class UpdateMethods(UserMethods):
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
self.session.save() await self.session.save()
# We need to send some content-related request at least hourly # We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will # for Telegram to keep delivering updates, otherwise they will

View File

@ -174,7 +174,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
entities=msg_entities, reply_markup=markup, silent=silent entities=msg_entities, reply_markup=markup, silent=silent
) )
msg = self._get_response_message(request, await self(request), entity) msg = self._get_response_message(request, await self(request), entity)
self._cache_media(msg, file, file_handle, force_document=force_document) await self._cache_media(msg, file, file_handle, force_document=force_document)
return msg return msg
@ -211,7 +211,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
entity, media=types.InputMediaUploadedPhoto(fh) entity, media=types.InputMediaUploadedPhoto(fh)
)) ))
input_photo = utils.get_input_photo(r.photo) input_photo = utils.get_input_photo(r.photo)
self.session.cache_file(fh.md5, fh.size, input_photo) await self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo fh = input_photo
if captions: if captions:
@ -326,7 +326,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
file = stream.read() file = stream.read()
hash_md5.update(file) hash_md5.update(file)
if use_cache: if use_cache:
cached = self.session.get_file( cached = await self.session.get_file(
hash_md5.digest(), file_size, cls=use_cache hash_md5.digest(), file_size, cls=use_cache
) )
if cached: if cached:
@ -446,7 +446,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
) )
return file_handle, media return file_handle, media
def _cache_media(self, msg, file, file_handle, async def _cache_media(self, msg, file, file_handle,
force_document=False): force_document=False):
if file and msg and isinstance(file_handle, if file and msg and isinstance(file_handle,
custom.InputSizedFile): custom.InputSizedFile):
@ -457,6 +457,6 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
to_cache = utils.get_input_photo(msg.media.photo) to_cache = utils.get_input_photo(msg.media.photo)
else: else:
to_cache = utils.get_input_document(msg.media.document) to_cache = utils.get_input_document(msg.media.document)
self.session.cache_file(md5, size, to_cache) await self.session.cache_file(md5, size, to_cache)
# endregion # endregion

View File

@ -48,7 +48,7 @@ class UserMethods(TelegramBaseClient):
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) await self.session.process_entities(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -58,7 +58,7 @@ class UserMethods(TelegramBaseClient):
return results return results
else: else:
result = await future result = await future
self.session.process_entities(result) await self.session.process_entities(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError) as e: except (errors.ServerError, errors.RpcCallFailError) as e:
__log__.warning('Telegram is having internal issues %s: %s', __log__.warning('Telegram is having internal issues %s: %s',
@ -288,7 +288,7 @@ class UserMethods(TelegramBaseClient):
try: try:
# First try to get the entity from cache, otherwise figure it out # First try to get the entity from cache, otherwise figure it out
return self.session.get_input_entity(peer) return await self.session.get_input_entity(peer)
except ValueError: except ValueError:
pass pass
@ -393,7 +393,7 @@ class UserMethods(TelegramBaseClient):
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( return await self.get_entity(
self.session.get_input_entity(string)) await self.session.get_input_entity(string))
except ValueError: except ValueError:
pass pass

View File

@ -21,6 +21,7 @@ from ..tl.types import (
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq, MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
) )
from ..utils import AsyncClassWrapper
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -213,7 +214,7 @@ class MTProtoSender:
await authenticator.do_authentication(plain) await authenticator.do_authentication(plain)
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(state.auth_key) await self._auth_key_callback(state.auth_key)
break break
except (SecurityError, AssertionError) as e: except (SecurityError, AssertionError) as e:

View File

@ -16,6 +16,7 @@ from types import GeneratorType
from .extensions import markdown, html from .extensions import markdown, html
from .helpers import add_surrogate, del_surrogate from .helpers import add_surrogate, del_surrogate
from .tl import types from .tl import types
import inspect
try: try:
import hachoir import hachoir
@ -994,3 +995,21 @@ def get_appropriated_part_size(file_size):
return 512 return 512
raise ValueError('File size too large') raise ValueError('File size too large')
class AsyncClassWrapper:
def __init__(self, wrapped):
self.wrapped = wrapped
def __getattr__(self, item):
w = getattr(self.wrapped, item)
async def wrapper(*args, **kwargs):
val = w(*args, **kwargs)
return await val if inspect.isawaitable(val) else val
if callable(w):
return wrapper
elif isinstance(w, property):
return w.fget(self.wrapped)
else:
return w