Make session, api ID and hash private

This commit is contained in:
Lonami Exo 2019-06-07 21:12:27 +02:00
parent 9bafcdfe0f
commit 80e86e98ff
5 changed files with 45 additions and 33 deletions

View File

@ -54,7 +54,7 @@ class _TakeoutClient:
self.__success)) self.__success))
if not result: if not result:
raise ValueError("Failed to finish the takeout.") raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None self._session.takeout_id = None
__enter__ = helpers._sync_enter __enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit __exit__ = helpers._sync_exit
@ -211,7 +211,7 @@ class AccountMethods(UserMethods):
) )
arg_specified = (arg is not None for arg in request_kwargs.values()) arg_specified = (arg is not None for arg in request_kwargs.values())
if self.session.takeout_id is None or any(arg_specified): if self._session.takeout_id is None or any(arg_specified):
request = functions.account.InitTakeoutSessionRequest( request = functions.account.InitTakeoutSessionRequest(
**request_kwargs) **request_kwargs)
else: else:

View File

@ -330,7 +330,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
elif bot_token: elif bot_token:
result = await self(functions.auth.ImportBotAuthorizationRequest( result = await self(functions.auth.ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token, flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash api_id=self._api_id, api_hash=self._api_hash
)) ))
else: else:
raise ValueError('You must provide a code, password or bot token') raise ValueError('You must provide a code, password or bot token')
@ -465,7 +465,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
if not phone_hash: if not phone_hash:
try: try:
result = await self(functions.auth.SendCodeRequest( result = await self(functions.auth.SendCodeRequest(
phone, self.api_id, self.api_hash, types.CodeSettings())) phone, self._api_id, self._api_hash, types.CodeSettings()))
except errors.AuthRestartError: except errors.AuthRestartError:
return await self.send_code_request(phone, force_sms=force_sms) return await self.send_code_request(phone, force_sms=force_sms)
@ -508,7 +508,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._state_cache.reset() self._state_cache.reset()
await self.disconnect() await self.disconnect()
self.session.delete() self._session.delete()
return True return True
async def edit_2fa( async def edit_2fa(

View File

@ -245,10 +245,10 @@ class TelegramBaseClient(abc.ABC):
# them to disk, and to save additional useful information. # them to disk, and to save additional useful information.
# TODO Session should probably return all cached # TODO Session should probably return all cached
# info of entities, not just the input versions # info of entities, not just the input versions
self.session = session self._session = session
self._entity_cache = EntityCache() self._entity_cache = EntityCache()
self.api_id = int(api_id) self._api_id = int(api_id)
self.api_hash = api_hash self._api_hash = api_hash
self._request_retries = request_retries self._request_retries = request_retries
self._connection_retries = connection_retries self._connection_retries = connection_retries
@ -267,7 +267,7 @@ class TelegramBaseClient(abc.ABC):
system = platform.uname() system = platform.uname()
self._init_with = lambda x: functions.InvokeWithLayerRequest( self._init_with = lambda x: functions.InvokeWithLayerRequest(
LAYER, functions.InitConnectionRequest( LAYER, functions.InitConnectionRequest(
api_id=self.api_id, api_id=self._api_id,
device_model=device_model or system.system or 'Unknown', device_model=device_model or system.system or 'Unknown',
system_version=system_version or system.release or '1.0', system_version=system_version or system.release or '1.0',
app_version=app_version or self.__version__, app_version=app_version or self.__version__,
@ -280,7 +280,7 @@ class TelegramBaseClient(abc.ABC):
) )
self._sender = MTProtoSender( self._sender = MTProtoSender(
self.session.auth_key, self._loop, self._session.auth_key, self._loop,
loggers=self._log, loggers=self._log,
retries=self._connection_retries, retries=self._connection_retries,
delay=self._retry_delay, delay=self._retry_delay,
@ -318,7 +318,7 @@ class TelegramBaseClient(abc.ABC):
# Update state (for catching up after a disconnection) # Update state (for catching up after a disconnection)
# TODO Get state from channels too # TODO Get state from channels too
self._state_cache = StateCache( self._state_cache = StateCache(
self.session.get_update_state(0), self._log) self._session.get_update_state(0), self._log)
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -362,6 +362,18 @@ class TelegramBaseClient(abc.ABC):
""" """
return self._loop return self._loop
@property
def session(self) -> Session:
"""
The ``Session`` instance used by the client.
Example
.. code-block:: python
client.session.set_dc(dc_id, ip, port)
"""
return self._session
@property @property
def disconnected(self: 'TelegramClient') -> asyncio.Future: def disconnected(self: 'TelegramClient') -> asyncio.Future:
""" """
@ -405,15 +417,15 @@ class TelegramBaseClient(abc.ABC):
print('Failed to connect') print('Failed to connect')
""" """
await self._sender.connect(self._connection( await self._sender.connect(self._connection(
self.session.server_address, self._session.server_address,
self.session.port, self._session.port,
self.session.dc_id, self._session.dc_id,
loop=self._loop, loop=self._loop,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy
)) ))
self.session.auth_key = self._sender.auth_key self._session.auth_key = self._sender.auth_key
self.session.save() self._session.save()
await self._sender.send(self._init_with( await self._sender.send(self._init_with(
functions.help.GetConfigRequest())) functions.help.GetConfigRequest()))
@ -476,7 +488,7 @@ class TelegramBaseClient(abc.ABC):
pts, date = self._state_cache[None] pts, date = self._state_cache[None]
if pts and date: if pts and date:
self.session.set_update_state(0, types.updates.State( self._session.set_update_state(0, types.updates.State(
pts=pts, pts=pts,
qts=0, qts=0,
date=date, date=date,
@ -484,7 +496,7 @@ class TelegramBaseClient(abc.ABC):
unread_count=0 unread_count=0
)) ))
self.session.close() self._session.close()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -504,12 +516,12 @@ class TelegramBaseClient(abc.ABC):
self._log[__name__].info('Reconnecting to new data center %s', new_dc) self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc) dc = await self._get_dc(new_dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) self._session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self._session.auth_key = None
self.session.save() self._session.save()
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
@ -518,8 +530,8 @@ class TelegramBaseClient(abc.ABC):
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self._session.auth_key = auth_key
self.session.save() self._session.save()
# endregion # endregion
@ -622,13 +634,13 @@ class TelegramBaseClient(abc.ABC):
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = self._session.clone()
await session.set_dc(dc.id, dc.ip_address, dc.port) await session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self._api_id, self._api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout()
) )

View File

@ -225,7 +225,7 @@ class UpdateMethods(UserMethods):
if not pts: if not pts:
return return
self.session.catching_up = True self._session.catching_up = True
try: try:
while True: while True:
d = await self(functions.updates.GetDifferenceRequest( d = await self(functions.updates.GetDifferenceRequest(
@ -275,7 +275,7 @@ class UpdateMethods(UserMethods):
finally: finally:
# TODO Save new pts to session # TODO Save new pts to session
self._state_cache._pts_date = (pts, date) self._state_cache._pts_date = (pts, date)
self.session.catching_up = False self._session.catching_up = False
# endregion # endregion
@ -285,7 +285,7 @@ class UpdateMethods(UserMethods):
# the order that the updates arrive in to update the pts and date to # the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async. # be always-increasing. There is also no need to make this async.
def _handle_update(self: 'TelegramClient', update): def _handle_update(self: 'TelegramClient', update):
self.session.process_entities(update) self._session.process_entities(update)
self._entity_cache.add(update) self._entity_cache.add(update)
if isinstance(update, (types.Updates, types.UpdatesCombined)): if isinstance(update, (types.Updates, types.UpdatesCombined)):
@ -347,7 +347,7 @@ class UpdateMethods(UserMethods):
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
self.session.save() self._session.save()
# We need to send some content-related request at least hourly # We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will # for Telegram to keep delivering updates, otherwise they will

View File

@ -52,7 +52,7 @@ class UserMethods(TelegramBaseClient):
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) self._session.process_entities(result)
self._entity_cache.add(result) self._entity_cache.add(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
@ -63,7 +63,7 @@ class UserMethods(TelegramBaseClient):
return results return results
else: else:
result = await future result = await future
self.session.process_entities(result) self._session.process_entities(result)
self._entity_cache.add(result) self._entity_cache.add(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
@ -377,7 +377,7 @@ class UserMethods(TelegramBaseClient):
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
return self.session.get_input_entity(peer) return self._session.get_input_entity(peer)
except ValueError: except ValueError:
pass pass
@ -513,7 +513,7 @@ class UserMethods(TelegramBaseClient):
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( return await self.get_entity(
self.session.get_input_entity(string)) self._session.get_input_entity(string))
except ValueError: except ValueError:
pass pass