Add support for storing the takeout ID in the session and revoking it

This commit is contained in:
Dmitry D. Chernov 2019-02-09 06:25:08 +10:00
parent 3ca24d15da
commit 78b6b02342
6 changed files with 112 additions and 25 deletions

View File

@ -14,14 +14,19 @@ class _TakeoutClient:
def __init__(self, client, request): def __init__(self, client, request):
# We use the name mangling for attributes to make them inaccessible # We use the name mangling for attributes to make them inaccessible
# from within the shadowed client object. # from within the shadowed client object and to distinguish them from
# its own attributes where needed.
self.__client = client self.__client = client
self.__request = request self.__request = request
self.__success = True
# After we initialize the proxy object variables, it's necessary to @property
# translate to the client any write-access to our attributes. def success(self):
self.__setattr__ = lambda _, name, value: \ return self.__success
setattr(client, name, value)
@success.setter
def success(self, value):
self.__success = value
def __enter__(self): def __enter__(self):
if self.__client.loop.is_running(): if self.__client.loop.is_running():
@ -34,19 +39,23 @@ class _TakeoutClient:
async def __aenter__(self): async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start. # Enter/Exit behaviour is "overrode", we don't want to call start.
# TODO: Request only if takeout ID isn't set. client = self.__client
self.__client.takeout_id = (await self.__client(self.__request)).id if client.session.takeout_id is None:
client.session.takeout_id = (await client(self.__request)).id
elif self.__request is not None:
raise ValueError("Can't send a takeout request while another "
"takeout for the current session still not been finished yet.")
return self return self
def __exit__(self, *args): def __exit__(self, *args):
return self.__client.loop.run_until_complete(self.__aexit__(*args)) return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, *args): async def __aexit__(self, *args):
# TODO: Reset only if takeout result is set. if self.__success is not None:
self.__client.takeout_id = None await self.__client.end_takeout(self.__success)
async def __call__(self, request, ordered=False): async def __call__(self, request, ordered=False):
takeout_id = self.__client.takeout_id takeout_id = self.__client.session.takeout_id
if takeout_id is None: if takeout_id is None:
raise ValueError('Takeout mode has not been initialized ' raise ValueError('Takeout mode has not been initialized '
'(are you calling outside of "with"?)') '(are you calling outside of "with"?)')
@ -65,6 +74,10 @@ class _TakeoutClient:
def __getattribute__(self, name): def __getattribute__(self, name):
# We access class via type() because __class__ will recurse infinitely. # We access class via type() because __class__ will recurse infinitely.
# Also note that since we've name-mangled our own class attributes,
# they'll be passed to __getattribute__() as already decorated. For
# example, 'self.__client' will be passed as '_TakeoutClient__client'.
# https://docs.python.org/3/tutorial/classes.html#private-variables
if name.startswith('__') and name not in type(self).__PROXY_INTERFACE: if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
raise AttributeError # force call of __getattr__ raise AttributeError # force call of __getattr__
@ -82,10 +95,16 @@ class _TakeoutClient:
return value return value
def __setattr__(self, name, value):
if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))):
# This is our own name-mangled attribute, keep calm.
return super().__setattr__(name, value)
return setattr(self.__client, name, value)
class AccountMethods(UserMethods): class AccountMethods(UserMethods):
def takeout( def takeout(
self, contacts=None, users=None, chats=None, megagroups=None, self, *, contacts=None, users=None, chats=None, megagroups=None,
channels=None, files=None, max_file_size=None): channels=None, files=None, max_file_size=None):
""" """
Creates a proxy object over the current :ref:`TelegramClient` through Creates a proxy object over the current :ref:`TelegramClient` through
@ -107,14 +126,21 @@ class AccountMethods(UserMethods):
to adjust the `wait_time` of methods like `client.iter_messages to adjust the `wait_time` of methods like `client.iter_messages
<telethon.client.messages.MessageMethods.iter_messages>`. <telethon.client.messages.MessageMethods.iter_messages>`.
By default, all parameters are ``False``, and you need to enable By default, all parameters are ``None``, and you need to enable those
those you plan to use by setting them to ``True``. you plan to use by setting them to either ``True`` or ``False``.
You should ``except errors.TakeoutInitDelayError as e``, since this You should ``except errors.TakeoutInitDelayError as e``, since this
exception will raise depending on the condition of the session. You exception will raise depending on the condition of the session. You
can then access ``e.seconds`` to know how long you should wait for can then access ``e.seconds`` to know how long you should wait for
before calling the method again. before calling the method again.
There's also a `success` property available in the takeout proxy
object, so you can set the boolean takeout result that will be sent
back to Telegram, from within the `with` body. You're also able to
set it to ``None`` so takeout will not be finished and its ID will
be preserved for future usage as `client.session.takeout_id
<telethon.sessions.abstract.Session.takeout_id>`. Default is ``True``.
Args: Args:
contacts (`bool`): contacts (`bool`):
Set to ``True`` if you plan on downloading contacts. Set to ``True`` if you plan on downloading contacts.
@ -143,7 +169,7 @@ class AccountMethods(UserMethods):
The maximum file size, in bytes, that you plan The maximum file size, in bytes, that you plan
to download for each message with media. to download for each message with media.
""" """
return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest( request_kwargs = dict(
contacts=contacts, contacts=contacts,
message_users=users, message_users=users,
message_chats=chats, message_chats=chats,
@ -151,4 +177,26 @@ class AccountMethods(UserMethods):
message_channels=channels, message_channels=channels,
files=files, files=files,
file_max_size=max_file_size file_max_size=max_file_size
)) )
arg_specified = (arg is not None for arg in request_kwargs.values())
if self.session.takeout_id is None or any(arg_specified):
request = functions.account.InitTakeoutSessionRequest(
**request_kwargs)
else:
request = None
return _TakeoutClient(self, request)
async def end_takeout(self, success):
"""
Finishes a takeout, with specified result sent back to Telegram.
Returns:
``True`` if the operation was successful, ``False`` otherwise.
"""
client = _TakeoutClient(self, None)
result = await client(
functions.account.FinishTakeoutSessionRequest(success))
self.session.takeout_id = None
return result

View File

@ -233,7 +233,6 @@ class TelegramBaseClient(abc.ABC):
# With asynchronous sessions, it would need await, # With asynchronous sessions, it would need await,
# and defeats the purpose of properties. # and defeats the purpose of properties.
self.session = session self.session = session
self.takeout_id = None # TODO: Move to session.
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -263,7 +262,6 @@ class TelegramBaseClient(abc.ABC):
) )
) )
self._connection = connection
self._sender = MTProtoSender( self._sender = MTProtoSender(
self.session.auth_key, self._loop, self.session.auth_key, self._loop,
loggers=self._log, loggers=self._log,

View File

@ -53,6 +53,23 @@ class Session(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or ``None`` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
""" """

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None self._server_address = None
self._port = None self._port = None
self._auth_key = None self._auth_key = None
self._takeout_id = None
self._files = {} self._files = {}
self._entities = set() self._entities = set()
@ -62,6 +63,14 @@ class MemorySession(Session):
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None) return self._update_states.get(entity_id, None)

View File

@ -18,7 +18,7 @@ except ImportError:
sqlite3 = None sqlite3 = None
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 4 # database version CURRENT_VERSION = 5 # database version
class SQLiteSession(MemorySession): class SQLiteSession(MemorySession):
@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions') c.execute('select * from sessions')
tuple_ = c.fetchone() tuple_ = c.fetchone()
if tuple_: if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_ self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key) self._auth_key = AuthKey(data=key)
c.close() c.close()
@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
dc_id integer primary key, dc_id integer primary key,
server_address text, server_address text,
port integer, port integer,
auth_key blob auth_key blob,
takeout_id integer
)""" )"""
, ,
"""entities ( """entities (
@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
date integer, date integer,
seq integer seq integer
)""") )""")
if old == 4:
old += 1
c.execute("alter table sessions add column takeout_id integer")
c.close() c.close()
@staticmethod @staticmethod
@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
self._auth_key = value self._auth_key = value
self._update_session_table() self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self): def _update_session_table(self):
c = self._cursor() c = self._cursor()
# While we can save multiple rows into the sessions table # While we can save multiple rows into the sessions table
@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for # some more work before being able to save auth_key's for
# multiple DCs. Probably done differently. # multiple DCs. Probably done differently.
c.execute('delete from sessions') c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', ( c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id, self._dc_id,
self._server_address, self._server_address,
self._port, self._port,
self._auth_key.key if self._auth_key else b'' self._auth_key.key if self._auth_key else b'',
self._takeout_id
)) ))
c.close() c.close()

View File

@ -5,12 +5,16 @@ import struct
from .memory import MemorySession from .memory import MemorySession
from ..crypto import AuthKey from ..crypto import AuthKey
_STRUCT_PREFORMAT = '>B{}sH256s'
CURRENT_VERSION = '1' CURRENT_VERSION = '1'
class StringSession(MemorySession): class StringSession(MemorySession):
""" """
This minimal session file can be easily saved and loaded as a string. This session file can be easily saved and loaded as a string. According
to the initial design, it contains only the data that is necessary for
successful connection and authentication, so takeout ID is not stored.
It is thought to be used where you don't want to create any on-disk It is thought to be used where you don't want to create any on-disk
files but would still like to be able to save and load existing sessions files but would still like to be able to save and load existing sessions
@ -33,7 +37,7 @@ class StringSession(MemorySession):
string = string[1:] string = string[1:]
ip_len = 4 if len(string) == 352 else 16 ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack( self._dc_id, ip, self._port, key = struct.unpack(
'>B{}sH256s'.format(ip_len), StringSession.decode(string)) _STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed self._server_address = ipaddress.ip_address(ip).compressed
if any(key): if any(key):
@ -45,7 +49,7 @@ class StringSession(MemorySession):
ip = ipaddress.ip_address(self._server_address).packed ip = ipaddress.ip_address(self._server_address).packed
return CURRENT_VERSION + StringSession.encode(struct.pack( return CURRENT_VERSION + StringSession.encode(struct.pack(
'>B{}sH256s'.format(len(ip)), _STRUCT_PREFORMAT.format(len(ip)),
self._dc_id, self._dc_id,
ip, ip,
self._port, self._port,