Add support for storing the takeout ID in the session and revoking it

This commit is contained in:
Dmitry D. Chernov 2019-02-09 06:25:08 +10:00
parent 3ca24d15da
commit 78b6b02342
6 changed files with 112 additions and 25 deletions

View File

@ -14,14 +14,19 @@ class _TakeoutClient:
def __init__(self, client, request):
# We use the name mangling for attributes to make them inaccessible
# from within the shadowed client object.
# from within the shadowed client object and to distinguish them from
# its own attributes where needed.
self.__client = client
self.__request = request
self.__success = True
# After we initialize the proxy object variables, it's necessary to
# translate to the client any write-access to our attributes.
self.__setattr__ = lambda _, name, value: \
setattr(client, name, value)
@property
def success(self):
return self.__success
@success.setter
def success(self, value):
self.__success = value
def __enter__(self):
if self.__client.loop.is_running():
@ -34,19 +39,23 @@ class _TakeoutClient:
async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start.
# TODO: Request only if takeout ID isn't set.
self.__client.takeout_id = (await self.__client(self.__request)).id
client = self.__client
if client.session.takeout_id is None:
client.session.takeout_id = (await client(self.__request)).id
elif self.__request is not None:
raise ValueError("Can't send a takeout request while another "
"takeout for the current session still not been finished yet.")
return self
def __exit__(self, *args):
return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, *args):
# TODO: Reset only if takeout result is set.
self.__client.takeout_id = None
if self.__success is not None:
await self.__client.end_takeout(self.__success)
async def __call__(self, request, ordered=False):
takeout_id = self.__client.takeout_id
takeout_id = self.__client.session.takeout_id
if takeout_id is None:
raise ValueError('Takeout mode has not been initialized '
'(are you calling outside of "with"?)')
@ -65,6 +74,10 @@ class _TakeoutClient:
def __getattribute__(self, name):
# We access class via type() because __class__ will recurse infinitely.
# Also note that since we've name-mangled our own class attributes,
# they'll be passed to __getattribute__() as already decorated. For
# example, 'self.__client' will be passed as '_TakeoutClient__client'.
# https://docs.python.org/3/tutorial/classes.html#private-variables
if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
raise AttributeError # force call of __getattr__
@ -82,10 +95,16 @@ class _TakeoutClient:
return value
def __setattr__(self, name, value):
if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))):
# This is our own name-mangled attribute, keep calm.
return super().__setattr__(name, value)
return setattr(self.__client, name, value)
class AccountMethods(UserMethods):
def takeout(
self, contacts=None, users=None, chats=None, megagroups=None,
self, *, contacts=None, users=None, chats=None, megagroups=None,
channels=None, files=None, max_file_size=None):
"""
Creates a proxy object over the current :ref:`TelegramClient` through
@ -107,14 +126,21 @@ class AccountMethods(UserMethods):
to adjust the `wait_time` of methods like `client.iter_messages
<telethon.client.messages.MessageMethods.iter_messages>`.
By default, all parameters are ``False``, and you need to enable
those you plan to use by setting them to ``True``.
By default, all parameters are ``None``, and you need to enable those
you plan to use by setting them to either ``True`` or ``False``.
You should ``except errors.TakeoutInitDelayError as e``, since this
exception will raise depending on the condition of the session. You
can then access ``e.seconds`` to know how long you should wait for
before calling the method again.
There's also a `success` property available in the takeout proxy
object, so you can set the boolean takeout result that will be sent
back to Telegram, from within the `with` body. You're also able to
set it to ``None`` so takeout will not be finished and its ID will
be preserved for future usage as `client.session.takeout_id
<telethon.sessions.abstract.Session.takeout_id>`. Default is ``True``.
Args:
contacts (`bool`):
Set to ``True`` if you plan on downloading contacts.
@ -143,7 +169,7 @@ class AccountMethods(UserMethods):
The maximum file size, in bytes, that you plan
to download for each message with media.
"""
return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest(
request_kwargs = dict(
contacts=contacts,
message_users=users,
message_chats=chats,
@ -151,4 +177,26 @@ class AccountMethods(UserMethods):
message_channels=channels,
files=files,
file_max_size=max_file_size
))
)
arg_specified = (arg is not None for arg in request_kwargs.values())
if self.session.takeout_id is None or any(arg_specified):
request = functions.account.InitTakeoutSessionRequest(
**request_kwargs)
else:
request = None
return _TakeoutClient(self, request)
async def end_takeout(self, success):
"""
Finishes a takeout, with specified result sent back to Telegram.
Returns:
``True`` if the operation was successful, ``False`` otherwise.
"""
client = _TakeoutClient(self, None)
result = await client(
functions.account.FinishTakeoutSessionRequest(success))
self.session.takeout_id = None
return result

View File

@ -233,7 +233,6 @@ class TelegramBaseClient(abc.ABC):
# With asynchronous sessions, it would need await,
# and defeats the purpose of properties.
self.session = session
self.takeout_id = None # TODO: Move to session.
self.api_id = int(api_id)
self.api_hash = api_hash
@ -263,7 +262,6 @@ class TelegramBaseClient(abc.ABC):
)
)
self._connection = connection
self._sender = MTProtoSender(
self.session.auth_key, self._loop,
loggers=self._log,

View File

@ -53,6 +53,23 @@ class Session(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or ``None`` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
"""
raise NotImplementedError
@abstractmethod
def get_update_state(self, entity_id):
"""

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None
self._port = None
self._auth_key = None
self._takeout_id = None
self._files = {}
self._entities = set()
@ -62,6 +63,14 @@ class MemorySession(Session):
def auth_key(self, value):
self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)

View File

@ -18,7 +18,7 @@ except ImportError:
sqlite3 = None
EXTENSION = '.session'
CURRENT_VERSION = 4 # database version
CURRENT_VERSION = 5 # database version
class SQLiteSession(MemorySession):
@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_
self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key)
c.close()
@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
dc_id integer primary key,
server_address text,
port integer,
auth_key blob
auth_key blob,
takeout_id integer
)"""
,
"""entities (
@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
date integer,
seq integer
)""")
if old == 4:
old += 1
c.execute("alter table sessions add column takeout_id integer")
c.close()
@staticmethod
@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
self._auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self):
c = self._cursor()
# While we can save multiple rows into the sessions table
@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
self._auth_key.key if self._auth_key else b'',
self._takeout_id
))
c.close()

View File

@ -5,12 +5,16 @@ import struct
from .memory import MemorySession
from ..crypto import AuthKey
_STRUCT_PREFORMAT = '>B{}sH256s'
CURRENT_VERSION = '1'
class StringSession(MemorySession):
"""
This minimal session file can be easily saved and loaded as a string.
This session file can be easily saved and loaded as a string. According
to the initial design, it contains only the data that is necessary for
successful connection and authentication, so takeout ID is not stored.
It is thought to be used where you don't want to create any on-disk
files but would still like to be able to save and load existing sessions
@ -33,7 +37,7 @@ class StringSession(MemorySession):
string = string[1:]
ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack(
'>B{}sH256s'.format(ip_len), StringSession.decode(string))
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed
if any(key):
@ -45,7 +49,7 @@ class StringSession(MemorySession):
ip = ipaddress.ip_address(self._server_address).packed
return CURRENT_VERSION + StringSession.encode(struct.pack(
'>B{}sH256s'.format(len(ip)),
_STRUCT_PREFORMAT.format(len(ip)),
self._dc_id,
ip,
self._port,