mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-18 04:20:57 +03:00
Support async def in sessions (#1013)
This commit is contained in:
parent
653f3c043d
commit
3dd8b7c6d1
0
.gitignore
vendored
Executable file → Normal file
0
.gitignore
vendored
Executable file → Normal file
|
@ -399,7 +399,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
self.session.delete()
|
await self.session.delete()
|
||||||
self._authorized = False
|
self._authorized = False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -223,9 +223,9 @@ class DownloadMethods(UserMethods):
|
||||||
config = await self(functions.help.GetConfigRequest())
|
config = await self(functions.help.GetConfigRequest())
|
||||||
for option in config.dc_options:
|
for option in config.dc_options:
|
||||||
if option.ip_address == self.session.server_address:
|
if option.ip_address == self.session.server_address:
|
||||||
self.session.set_dc(
|
await self.session.set_dc(
|
||||||
option.id, option.ip_address, option.port)
|
option.id, option.ip_address, option.port)
|
||||||
self.session.save()
|
await self.session.save()
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO Figure out why the session may have the wrong DC ID
|
# TODO Figure out why the session may have the wrong DC ID
|
||||||
|
|
|
@ -662,7 +662,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
reply_markup=self.build_reply_markup(buttons)
|
reply_markup=self.build_reply_markup(buttons)
|
||||||
)
|
)
|
||||||
msg = self._get_response_message(request, await self(request), entity)
|
msg = self._get_response_message(request, await self(request), entity)
|
||||||
self._cache_media(msg, file, file_handle)
|
await self._cache_media(msg, file, file_handle)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def delete_messages(self, entity, message_ids, *, revoke=True):
|
async def delete_messages(self, entity, message_ids, *, revoke=True):
|
||||||
|
|
|
@ -14,6 +14,7 @@ from ..network import MTProtoSender, ConnectionTcpFull
|
||||||
from ..sessions import Session, SQLiteSession, MemorySession
|
from ..sessions import Session, SQLiteSession, MemorySession
|
||||||
from ..tl import TLObject, functions, types
|
from ..tl import TLObject, functions, types
|
||||||
from ..tl.alltlobjects import LAYER
|
from ..tl.alltlobjects import LAYER
|
||||||
|
from ..utils import AsyncClassWrapper
|
||||||
|
|
||||||
DEFAULT_DC_ID = 4
|
DEFAULT_DC_ID = 4
|
||||||
DEFAULT_IPV4_IP = '149.154.167.51'
|
DEFAULT_IPV4_IP = '149.154.167.51'
|
||||||
|
@ -196,7 +197,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.flood_sleep_threshold = flood_sleep_threshold
|
self.flood_sleep_threshold = flood_sleep_threshold
|
||||||
self.session = session
|
self.session = AsyncClassWrapper(session)
|
||||||
self.api_id = int(api_id)
|
self.api_id = int(api_id)
|
||||||
self.api_hash = api_hash
|
self.api_hash = api_hash
|
||||||
|
|
||||||
|
@ -327,8 +328,8 @@ class TelegramBaseClient(abc.ABC):
|
||||||
await self._disconnect()
|
await self._disconnect()
|
||||||
if getattr(self, 'session', None):
|
if getattr(self, 'session', None):
|
||||||
if getattr(self, '_state', None):
|
if getattr(self, '_state', None):
|
||||||
self.session.set_update_state(0, self._state)
|
await self.session.set_update_state(0, self._state)
|
||||||
self.session.close()
|
await self.session.close()
|
||||||
|
|
||||||
async def _disconnect(self):
|
async def _disconnect(self):
|
||||||
"""
|
"""
|
||||||
|
@ -366,22 +367,22 @@ class TelegramBaseClient(abc.ABC):
|
||||||
__log__.info('Reconnecting to new data center %s', new_dc)
|
__log__.info('Reconnecting to new data center %s', new_dc)
|
||||||
dc = await self._get_dc(new_dc)
|
dc = await self._get_dc(new_dc)
|
||||||
|
|
||||||
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
await self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
# auth_key's are associated with a server, which has now changed
|
# auth_key's are associated with a server, which has now changed
|
||||||
# so it's not valid anymore. Set to None to force recreating it.
|
# so it's not valid anymore. Set to None to force recreating it.
|
||||||
self.session.auth_key = None
|
self.session.auth_key = None
|
||||||
self.session.save()
|
await self.session.save()
|
||||||
await self._disconnect()
|
await self._disconnect()
|
||||||
return await self.connect()
|
return await self.connect()
|
||||||
|
|
||||||
def _auth_key_callback(self, auth_key):
|
async def _auth_key_callback(self, auth_key):
|
||||||
"""
|
"""
|
||||||
Callback from the sender whenever it needed to generate a
|
Callback from the sender whenever it needed to generate a
|
||||||
new authorization key. This means we are not authorized.
|
new authorization key. This means we are not authorized.
|
||||||
"""
|
"""
|
||||||
self._authorized = None
|
self._authorized = None
|
||||||
self.session.auth_key = auth_key
|
self.session.auth_key = auth_key
|
||||||
self.session.save()
|
await self.session.save()
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
@ -469,8 +470,8 @@ class TelegramBaseClient(abc.ABC):
|
||||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||||
if not session:
|
if not session:
|
||||||
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||||
session = self.session.clone()
|
session = await self.session.clone()
|
||||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
await session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||||
|
|
||||||
__log__.info('Creating new CDN client')
|
__log__.info('Creating new CDN client')
|
||||||
|
|
|
@ -137,7 +137,7 @@ class UpdateMethods(UserMethods):
|
||||||
|
|
||||||
This can also be used to forcibly fetch new updates if there are any.
|
This can also be used to forcibly fetch new updates if there are any.
|
||||||
"""
|
"""
|
||||||
state = self.session.get_update_state(0)
|
state = await self.session.get_update_state(0)
|
||||||
if not state or not state.pts:
|
if not state or not state.pts:
|
||||||
state = await self(functions.updates.GetStateRequest())
|
state = await self(functions.updates.GetStateRequest())
|
||||||
|
|
||||||
|
@ -172,15 +172,15 @@ class UpdateMethods(UserMethods):
|
||||||
state.pts = d.pts
|
state.pts = d.pts
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
self.session.set_update_state(0, state)
|
await self.session.set_update_state(0, state)
|
||||||
self.session.catching_up = False
|
self.session.catching_up = False
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Private methods
|
# region Private methods
|
||||||
|
|
||||||
def _handle_update(self, update):
|
async def _handle_update(self, update):
|
||||||
self.session.process_entities(update)
|
await self.session.process_entities(update)
|
||||||
if isinstance(update, (types.Updates, types.UpdatesCombined)):
|
if isinstance(update, (types.Updates, types.UpdatesCombined)):
|
||||||
entities = {utils.get_peer_id(x): x for x in
|
entities = {utils.get_peer_id(x): x for x in
|
||||||
itertools.chain(update.users, update.chats)}
|
itertools.chain(update.users, update.chats)}
|
||||||
|
@ -236,7 +236,7 @@ class UpdateMethods(UserMethods):
|
||||||
# inserted because this is a rather expensive operation
|
# inserted because this is a rather expensive operation
|
||||||
# (default's sqlite3 takes ~0.1s to commit changes). Do
|
# (default's sqlite3 takes ~0.1s to commit changes). Do
|
||||||
# it every minute instead. No-op if there's nothing new.
|
# it every minute instead. No-op if there's nothing new.
|
||||||
self.session.save()
|
await self.session.save()
|
||||||
|
|
||||||
# We need to send some content-related request at least hourly
|
# We need to send some content-related request at least hourly
|
||||||
# for Telegram to keep delivering updates, otherwise they will
|
# for Telegram to keep delivering updates, otherwise they will
|
||||||
|
|
|
@ -174,7 +174,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
||||||
entities=msg_entities, reply_markup=markup, silent=silent
|
entities=msg_entities, reply_markup=markup, silent=silent
|
||||||
)
|
)
|
||||||
msg = self._get_response_message(request, await self(request), entity)
|
msg = self._get_response_message(request, await self(request), entity)
|
||||||
self._cache_media(msg, file, file_handle, force_document=force_document)
|
await self._cache_media(msg, file, file_handle, force_document=force_document)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
@ -211,7 +211,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
||||||
entity, media=types.InputMediaUploadedPhoto(fh)
|
entity, media=types.InputMediaUploadedPhoto(fh)
|
||||||
))
|
))
|
||||||
input_photo = utils.get_input_photo(r.photo)
|
input_photo = utils.get_input_photo(r.photo)
|
||||||
self.session.cache_file(fh.md5, fh.size, input_photo)
|
await self.session.cache_file(fh.md5, fh.size, input_photo)
|
||||||
fh = input_photo
|
fh = input_photo
|
||||||
|
|
||||||
if captions:
|
if captions:
|
||||||
|
@ -326,7 +326,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
||||||
file = stream.read()
|
file = stream.read()
|
||||||
hash_md5.update(file)
|
hash_md5.update(file)
|
||||||
if use_cache:
|
if use_cache:
|
||||||
cached = self.session.get_file(
|
cached = await self.session.get_file(
|
||||||
hash_md5.digest(), file_size, cls=use_cache
|
hash_md5.digest(), file_size, cls=use_cache
|
||||||
)
|
)
|
||||||
if cached:
|
if cached:
|
||||||
|
@ -446,7 +446,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
||||||
)
|
)
|
||||||
return file_handle, media
|
return file_handle, media
|
||||||
|
|
||||||
def _cache_media(self, msg, file, file_handle,
|
async def _cache_media(self, msg, file, file_handle,
|
||||||
force_document=False):
|
force_document=False):
|
||||||
if file and msg and isinstance(file_handle,
|
if file and msg and isinstance(file_handle,
|
||||||
custom.InputSizedFile):
|
custom.InputSizedFile):
|
||||||
|
@ -457,6 +457,6 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
||||||
to_cache = utils.get_input_photo(msg.media.photo)
|
to_cache = utils.get_input_photo(msg.media.photo)
|
||||||
else:
|
else:
|
||||||
to_cache = utils.get_input_document(msg.media.document)
|
to_cache = utils.get_input_document(msg.media.document)
|
||||||
self.session.cache_file(md5, size, to_cache)
|
await self.session.cache_file(md5, size, to_cache)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -48,7 +48,7 @@ class UserMethods(TelegramBaseClient):
|
||||||
exceptions.append(e)
|
exceptions.append(e)
|
||||||
results.append(None)
|
results.append(None)
|
||||||
continue
|
continue
|
||||||
self.session.process_entities(result)
|
await self.session.process_entities(result)
|
||||||
exceptions.append(None)
|
exceptions.append(None)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
request_index += 1
|
request_index += 1
|
||||||
|
@ -58,7 +58,7 @@ class UserMethods(TelegramBaseClient):
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
result = await future
|
result = await future
|
||||||
self.session.process_entities(result)
|
await self.session.process_entities(result)
|
||||||
return result
|
return result
|
||||||
except (errors.ServerError, errors.RpcCallFailError) as e:
|
except (errors.ServerError, errors.RpcCallFailError) as e:
|
||||||
__log__.warning('Telegram is having internal issues %s: %s',
|
__log__.warning('Telegram is having internal issues %s: %s',
|
||||||
|
@ -288,7 +288,7 @@ class UserMethods(TelegramBaseClient):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# First try to get the entity from cache, otherwise figure it out
|
# First try to get the entity from cache, otherwise figure it out
|
||||||
return self.session.get_input_entity(peer)
|
return await self.session.get_input_entity(peer)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -393,7 +393,7 @@ class UserMethods(TelegramBaseClient):
|
||||||
try:
|
try:
|
||||||
# Nobody with this username, maybe it's an exact name/title
|
# Nobody with this username, maybe it's an exact name/title
|
||||||
return await self.get_entity(
|
return await self.get_entity(
|
||||||
self.session.get_input_entity(string))
|
await self.session.get_input_entity(string))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from ..tl.types import (
|
||||||
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
|
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
|
||||||
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
|
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
|
||||||
)
|
)
|
||||||
|
from ..utils import AsyncClassWrapper
|
||||||
|
|
||||||
__log__ = logging.getLogger(__name__)
|
__log__ = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -213,7 +214,7 @@ class MTProtoSender:
|
||||||
await authenticator.do_authentication(plain)
|
await authenticator.do_authentication(plain)
|
||||||
|
|
||||||
if self._auth_key_callback:
|
if self._auth_key_callback:
|
||||||
self._auth_key_callback(state.auth_key)
|
await self._auth_key_callback(state.auth_key)
|
||||||
|
|
||||||
break
|
break
|
||||||
except (SecurityError, AssertionError) as e:
|
except (SecurityError, AssertionError) as e:
|
||||||
|
|
|
@ -16,6 +16,7 @@ from types import GeneratorType
|
||||||
from .extensions import markdown, html
|
from .extensions import markdown, html
|
||||||
from .helpers import add_surrogate, del_surrogate
|
from .helpers import add_surrogate, del_surrogate
|
||||||
from .tl import types
|
from .tl import types
|
||||||
|
import inspect
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import hachoir
|
import hachoir
|
||||||
|
@ -994,3 +995,21 @@ def get_appropriated_part_size(file_size):
|
||||||
return 512
|
return 512
|
||||||
|
|
||||||
raise ValueError('File size too large')
|
raise ValueError('File size too large')
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncClassWrapper:
|
||||||
|
def __init__(self, wrapped):
|
||||||
|
self.wrapped = wrapped
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
w = getattr(self.wrapped, item)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
val = w(*args, **kwargs)
|
||||||
|
return await val if inspect.isawaitable(val) else val
|
||||||
|
|
||||||
|
if callable(w):
|
||||||
|
return wrapper
|
||||||
|
elif isinstance(w, property):
|
||||||
|
return w.fget(self.wrapped)
|
||||||
|
else:
|
||||||
|
return w
|
||||||
|
|
Loading…
Reference in New Issue
Block a user