mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-04 04:00:18 +03:00
Add async sessions while keeping retro compatibility
This commit is contained in:
parent
01af2fcca3
commit
4742414247
|
@ -552,7 +552,9 @@ class AuthMethods:
|
||||||
self._authorized = False
|
self._authorized = False
|
||||||
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
self.session.delete()
|
delete = self.session.delete()
|
||||||
|
if inspect.isawaitable(delete):
|
||||||
|
await delete
|
||||||
self.session = None
|
self.session = None
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -63,9 +63,12 @@ class _DirectDownloadIter(RequestIter):
|
||||||
config = await self.client(functions.help.GetConfigRequest())
|
config = await self.client(functions.help.GetConfigRequest())
|
||||||
for option in config.dc_options:
|
for option in config.dc_options:
|
||||||
if option.ip_address == self.client.session.server_address:
|
if option.ip_address == self.client.session.server_address:
|
||||||
self.client.session.set_dc(
|
set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port)
|
||||||
option.id, option.ip_address, option.port)
|
if inspect.isawaitable(set_dc):
|
||||||
self.client.session.save()
|
await set_dc
|
||||||
|
save = self.client.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO Figure out why the session may have the wrong DC ID
|
# TODO Figure out why the session may have the wrong DC ID
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import abc
|
import abc
|
||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
|
@ -305,16 +306,6 @@ class TelegramBaseClient(abc.ABC):
|
||||||
'The given session must be a str or a Session instance.'
|
'The given session must be a str or a Session instance.'
|
||||||
)
|
)
|
||||||
|
|
||||||
# ':' in session.server_address is True if it's an IPv6 address
|
|
||||||
if (not session.server_address or
|
|
||||||
(':' in session.server_address) != use_ipv6):
|
|
||||||
session.set_dc(
|
|
||||||
DEFAULT_DC_ID,
|
|
||||||
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
|
|
||||||
DEFAULT_PORT
|
|
||||||
)
|
|
||||||
session.save()
|
|
||||||
|
|
||||||
self.flood_sleep_threshold = flood_sleep_threshold
|
self.flood_sleep_threshold = flood_sleep_threshold
|
||||||
|
|
||||||
# TODO Use AsyncClassWrapper(session)
|
# TODO Use AsyncClassWrapper(session)
|
||||||
|
@ -546,6 +537,20 @@ class TelegramBaseClient(abc.ABC):
|
||||||
elif self._loop != helpers.get_running_loop():
|
elif self._loop != helpers.get_running_loop():
|
||||||
raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)')
|
raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)')
|
||||||
|
|
||||||
|
# ':' in session.server_address is True if it's an IPv6 address
|
||||||
|
if (not self.session.server_address or
|
||||||
|
(':' in self.session.server_address) != self._use_ipv6):
|
||||||
|
set_dc = self.session.set_dc(
|
||||||
|
DEFAULT_DC_ID,
|
||||||
|
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
|
||||||
|
DEFAULT_PORT
|
||||||
|
)
|
||||||
|
if inspect.isawaitable(set_dc):
|
||||||
|
await set_dc
|
||||||
|
save = self.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
|
|
||||||
if not await self._sender.connect(self._connection(
|
if not await self._sender.connect(self._connection(
|
||||||
self.session.server_address,
|
self.session.server_address,
|
||||||
self.session.port,
|
self.session.port,
|
||||||
|
@ -558,12 +563,19 @@ class TelegramBaseClient(abc.ABC):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.session.auth_key = self._sender.auth_key
|
self.session.auth_key = self._sender.auth_key
|
||||||
self.session.save()
|
save = self.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# See comment when saving entities to understand this hack
|
# See comment when saving entities to understand this hack
|
||||||
self_id = self.session.get_input_entity(0).access_hash
|
self_entity = self.session.get_input_entity(0)
|
||||||
|
if inspect.isawaitable(self_entity):
|
||||||
|
self_entity = await self_entity
|
||||||
|
self_id = self_entity.access_hash
|
||||||
self_user = self.session.get_input_entity(self_id)
|
self_user = self.session.get_input_entity(self_id)
|
||||||
|
if inspect.isawaitable(self_user):
|
||||||
|
self_user = await self_user
|
||||||
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
|
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
@ -572,7 +584,10 @@ class TelegramBaseClient(abc.ABC):
|
||||||
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
|
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
|
||||||
cs = []
|
cs = []
|
||||||
|
|
||||||
for entity_id, state in self.session.get_update_states():
|
update_states = self.session.get_update_states()
|
||||||
|
if inspect.isawaitable(update_states):
|
||||||
|
update_states = await update_states
|
||||||
|
for entity_id, state in update_states:
|
||||||
if entity_id == 0:
|
if entity_id == 0:
|
||||||
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
|
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
|
||||||
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
|
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
|
||||||
|
@ -583,6 +598,8 @@ class TelegramBaseClient(abc.ABC):
|
||||||
for state in cs:
|
for state in cs:
|
||||||
try:
|
try:
|
||||||
entity = self.session.get_input_entity(state.channel_id)
|
entity = self.session.get_input_entity(state.channel_id)
|
||||||
|
if inspect.isawaitable(entity):
|
||||||
|
entity = await entity
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self._log[__name__].warning(
|
self._log[__name__].warning(
|
||||||
'No access_hash in cache for channel %s, will not catch up', state.channel_id)
|
'No access_hash in cache for channel %s, will not catch up', state.channel_id)
|
||||||
|
@ -686,19 +703,29 @@ class TelegramBaseClient(abc.ABC):
|
||||||
else:
|
else:
|
||||||
connection._proxy = proxy
|
connection._proxy = proxy
|
||||||
|
|
||||||
def _save_states_and_entities(self: 'TelegramClient'):
|
async def _save_states_and_entities(self: 'TelegramClient'):
|
||||||
# As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``.
|
# As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``.
|
||||||
# This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved.
|
# This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved.
|
||||||
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
|
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
|
||||||
# It doesn't matter if we put users in the list of chats.
|
# It doesn't matter if we put users in the list of chats.
|
||||||
if self._mb_entity_cache.self_id:
|
if self._mb_entity_cache.self_id:
|
||||||
self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []))
|
process_entities = self.session.process_entities(
|
||||||
|
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
|
||||||
|
)
|
||||||
|
if inspect.isawaitable(process_entities):
|
||||||
|
await process_entities
|
||||||
|
|
||||||
ss, cs = self._message_box.session_state()
|
ss, cs = self._message_box.session_state()
|
||||||
self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
|
update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
|
||||||
|
if inspect.isawaitable(update_state):
|
||||||
|
await update_state
|
||||||
now = datetime.datetime.now() # any datetime works; channels don't need it
|
now = datetime.datetime.now() # any datetime works; channels don't need it
|
||||||
for channel_id, pts in cs.items():
|
for channel_id, pts in cs.items():
|
||||||
self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0))
|
update_state = self.session.set_update_state(
|
||||||
|
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
|
||||||
|
)
|
||||||
|
if inspect.isawaitable(update_state):
|
||||||
|
await update_state
|
||||||
|
|
||||||
async def _disconnect_coro(self: 'TelegramClient'):
|
async def _disconnect_coro(self: 'TelegramClient'):
|
||||||
if self.session is None:
|
if self.session is None:
|
||||||
|
@ -730,9 +757,11 @@ class TelegramBaseClient(abc.ABC):
|
||||||
await asyncio.wait(self._event_handler_tasks)
|
await asyncio.wait(self._event_handler_tasks)
|
||||||
self._event_handler_tasks.clear()
|
self._event_handler_tasks.clear()
|
||||||
|
|
||||||
self._save_states_and_entities()
|
await self._save_states_and_entities()
|
||||||
|
|
||||||
self.session.close()
|
close = self.session.close()
|
||||||
|
if inspect.isawaitable(close):
|
||||||
|
await close
|
||||||
|
|
||||||
async def _disconnect(self: 'TelegramClient'):
|
async def _disconnect(self: 'TelegramClient'):
|
||||||
"""
|
"""
|
||||||
|
@ -753,22 +782,28 @@ class TelegramBaseClient(abc.ABC):
|
||||||
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
|
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
|
||||||
dc = await self._get_dc(new_dc)
|
dc = await self._get_dc(new_dc)
|
||||||
|
|
||||||
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
|
if inspect.isawaitable(set_dc):
|
||||||
|
await set_dc
|
||||||
# auth_key's are associated with a server, which has now changed
|
# auth_key's are associated with a server, which has now changed
|
||||||
# so it's not valid anymore. Set to None to force recreating it.
|
# so it's not valid anymore. Set to None to force recreating it.
|
||||||
self._sender.auth_key.key = None
|
self._sender.auth_key.key = None
|
||||||
self.session.auth_key = None
|
self.session.auth_key = None
|
||||||
self.session.save()
|
save = self.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
await self._disconnect()
|
await self._disconnect()
|
||||||
return await self.connect()
|
return await self.connect()
|
||||||
|
|
||||||
def _auth_key_callback(self: 'TelegramClient', auth_key):
|
async def _auth_key_callback(self: 'TelegramClient', auth_key):
|
||||||
"""
|
"""
|
||||||
Callback from the sender whenever it needed to generate a
|
Callback from the sender whenever it needed to generate a
|
||||||
new authorization key. This means we are not authorized.
|
new authorization key. This means we are not authorized.
|
||||||
"""
|
"""
|
||||||
self.session.auth_key = auth_key
|
self.session.auth_key = auth_key
|
||||||
self.session.save()
|
save = self.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
@ -895,7 +930,11 @@ class TelegramBaseClient(abc.ABC):
|
||||||
if not session:
|
if not session:
|
||||||
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||||
session = self.session.clone()
|
session = self.session.clone()
|
||||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
if inspect.isawaitable(session):
|
||||||
|
session = await session
|
||||||
|
set_dc = session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
|
if inspect.isawaitable(set_dc):
|
||||||
|
await set_dc
|
||||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||||
|
|
||||||
self._log[__name__].info('Creating new CDN client')
|
self._log[__name__].info('Creating new CDN client')
|
||||||
|
|
|
@ -289,7 +289,7 @@ class UpdateMethods:
|
||||||
len(self._mb_entity_cache),
|
len(self._mb_entity_cache),
|
||||||
self._entity_cache_limit
|
self._entity_cache_limit
|
||||||
)
|
)
|
||||||
self._save_states_and_entities()
|
await self._save_states_and_entities()
|
||||||
self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map)
|
self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map)
|
||||||
if len(self._mb_entity_cache) >= self._entity_cache_limit:
|
if len(self._mb_entity_cache) >= self._entity_cache_limit:
|
||||||
warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit')
|
warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit')
|
||||||
|
@ -343,7 +343,10 @@ class UpdateMethods:
|
||||||
if updates:
|
if updates:
|
||||||
self._log[__name__].info('Got difference for account updates')
|
self._log[__name__].info('Got difference for account updates')
|
||||||
|
|
||||||
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
|
_preprocess_updates = self._preprocess_updates(updates, users, chats)
|
||||||
|
if inspect.isawaitable(_preprocess_updates):
|
||||||
|
_preprocess_updates = await _preprocess_updates
|
||||||
|
updates_to_dispatch.extend(_preprocess_updates)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
|
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
|
||||||
|
@ -441,7 +444,10 @@ class UpdateMethods:
|
||||||
if updates:
|
if updates:
|
||||||
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
|
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
|
||||||
|
|
||||||
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats))
|
_preprocess_updates = self._preprocess_updates(updates, users, chats)
|
||||||
|
if inspect.isawaitable(_preprocess_updates):
|
||||||
|
_preprocess_updates = await _preprocess_updates
|
||||||
|
updates_to_dispatch.extend(_preprocess_updates)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
deadline = self._message_box.check_deadlines()
|
deadline = self._message_box.check_deadlines()
|
||||||
|
@ -462,7 +468,10 @@ class UpdateMethods:
|
||||||
except GapError:
|
except GapError:
|
||||||
continue # get(_channel)_difference will start returning requests
|
continue # get(_channel)_difference will start returning requests
|
||||||
|
|
||||||
updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats))
|
_preprocess_updates = self._preprocess_updates(processed, users, chats)
|
||||||
|
if inspect.isawaitable(_preprocess_updates):
|
||||||
|
_preprocess_updates = await _preprocess_updates
|
||||||
|
updates_to_dispatch.extend(_preprocess_updates)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -470,9 +479,11 @@ class UpdateMethods:
|
||||||
self._updates_error = e
|
self._updates_error = e
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
def _preprocess_updates(self, updates, users, chats):
|
async def _preprocess_updates(self, updates, users, chats):
|
||||||
self._mb_entity_cache.extend(users, chats)
|
self._mb_entity_cache.extend(users, chats)
|
||||||
self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))
|
process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))
|
||||||
|
if inspect.isawaitable(process_entities):
|
||||||
|
await process_entities
|
||||||
entities = {utils.get_peer_id(x): x
|
entities = {utils.get_peer_id(x): x
|
||||||
for x in itertools.chain(users, chats)}
|
for x in itertools.chain(users, chats)}
|
||||||
for u in updates:
|
for u in updates:
|
||||||
|
@ -515,9 +526,11 @@ class UpdateMethods:
|
||||||
# inserted because this is a rather expensive operation
|
# inserted because this is a rather expensive operation
|
||||||
# (default's sqlite3 takes ~0.1s to commit changes). Do
|
# (default's sqlite3 takes ~0.1s to commit changes). Do
|
||||||
# it every minute instead. No-op if there's nothing new.
|
# it every minute instead. No-op if there's nothing new.
|
||||||
self._save_states_and_entities()
|
await self._save_states_and_entities()
|
||||||
|
|
||||||
self.session.save()
|
save = self.session.save()
|
||||||
|
if inspect.isawaitable(save):
|
||||||
|
await save
|
||||||
|
|
||||||
async def _dispatch_update(self: 'TelegramClient', update):
|
async def _dispatch_update(self: 'TelegramClient', update):
|
||||||
# TODO only used for AlbumHack, and MessageBox is not really designed for this
|
# TODO only used for AlbumHack, and MessageBox is not really designed for this
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
@ -80,7 +81,9 @@ class UserMethods:
|
||||||
exceptions.append(e)
|
exceptions.append(e)
|
||||||
results.append(None)
|
results.append(None)
|
||||||
continue
|
continue
|
||||||
self.session.process_entities(result)
|
process_entities = self.session.process_entities(result)
|
||||||
|
if inspect.isawaitable(process_entities):
|
||||||
|
await process_entities
|
||||||
exceptions.append(None)
|
exceptions.append(None)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
request_index += 1
|
request_index += 1
|
||||||
|
@ -90,7 +93,9 @@ class UserMethods:
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
result = await future
|
result = await future
|
||||||
self.session.process_entities(result)
|
process_entities = self.session.process_entities(result)
|
||||||
|
if inspect.isawaitable(process_entities):
|
||||||
|
await process_entities
|
||||||
return result
|
return result
|
||||||
except (errors.ServerError, errors.RpcCallFailError,
|
except (errors.ServerError, errors.RpcCallFailError,
|
||||||
errors.RpcMcgetFailError, errors.InterdcCallErrorError,
|
errors.RpcMcgetFailError, errors.InterdcCallErrorError,
|
||||||
|
@ -435,7 +440,10 @@ class UserMethods:
|
||||||
|
|
||||||
# No InputPeer, cached peer, or known string. Fetch from disk cache
|
# No InputPeer, cached peer, or known string. Fetch from disk cache
|
||||||
try:
|
try:
|
||||||
return self.session.get_input_entity(peer)
|
input_entity = self.session.get_input_entity(peer)
|
||||||
|
if inspect.isawaitable(input_entity):
|
||||||
|
input_entity = await input_entity
|
||||||
|
return input_entity
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -574,8 +582,10 @@ class UserMethods:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
# Nobody with this username, maybe it's an exact name/title
|
# Nobody with this username, maybe it's an exact name/title
|
||||||
return await self.get_entity(
|
input_entity = self.session.get_input_entity(string)
|
||||||
self.session.get_input_entity(string))
|
if inspect.isawaitable(input_entity):
|
||||||
|
input_entity = await input_entity
|
||||||
|
return await self.get_entity(input_entity)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -302,7 +302,7 @@ class MTProtoSender:
|
||||||
# notify whenever we change it. This is crucial when we
|
# notify whenever we change it. This is crucial when we
|
||||||
# switch to different data centers.
|
# switch to different data centers.
|
||||||
if self._auth_key_callback:
|
if self._auth_key_callback:
|
||||||
self._auth_key_callback(self.auth_key)
|
await self._auth_key_callback(self.auth_key)
|
||||||
|
|
||||||
self._log.debug('auth_key generation success!')
|
self._log.debug('auth_key generation success!')
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in New Issue
Block a user