Improve TakeoutClient proxy and takeout functionality (#1106)

This commit is contained in:
Dmitry D. Chernov 2019-02-10 20:10:41 +10:00 committed by Lonami
parent 274fa72a8c
commit 9a98d41a2c
6 changed files with 153 additions and 51 deletions

View File

@ -8,45 +8,65 @@ from ..tl import functions, TLRequest
class _TakeoutClient:
"""
Proxy object over the client. `c` is the client, `k` it's class,
`r` is the takeout request, and `t` is the takeout ID.
Proxy object over the client.
"""
def __init__(self, client, request):
# We're a proxy object with __getattribute__overrode so we
# need to set attributes through the super class `object`.
super().__setattr__('c', client)
super().__setattr__('k', client.__class__)
super().__setattr__('r', request)
super().__setattr__('t', None)
__PROXY_INTERFACE = ('__enter__', '__exit__', '__aenter__', '__aexit__')
def __init__(self, finalize, client, request):
# We use the name mangling for attributes to make them inaccessible
# from within the shadowed client object and to distinguish them from
# its own attributes where needed.
self.__finalize = finalize
self.__client = client
self.__request = request
self.__success = None
@property
def success(self):
return self.__success
@success.setter
def success(self, value):
self.__success = value
def __enter__(self):
# We also get self attributes through super()
if super().__getattribute__('c').loop.is_running():
if self.__client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return super().__getattribute__(
'c').loop.run_until_complete(self.__aenter__())
return self.__client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start
cl = super().__getattribute__('c')
super().__setattr__('t', (await cl(super().__getattribute__('r'))).id)
# Enter/Exit behaviour is "overrode", we don't want to call start.
client = self.__client
if client.session.takeout_id is None:
client.session.takeout_id = (await client(self.__request)).id
elif self.__request is not None:
raise ValueError("Can't send a takeout request while another "
"takeout for the current session still not been finished yet.")
return self
def __exit__(self, *args):
return super().__getattribute__(
'c').loop.run_until_complete(self.__aexit__(*args))
return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, *args):
super().__setattr__('t', None)
async def __aexit__(self, exc_type, exc_value, traceback):
if self.__success is None and self.__finalize:
self.__success = exc_type is None
if self.__success is not None:
result = await self(functions.account.FinishTakeoutSessionRequest(
self.__success))
if not result:
raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None
async def __call__(self, request, ordered=False):
takeout_id = super().__getattribute__('t')
takeout_id = self.__client.session.takeout_id
if takeout_id is None:
raise ValueError('Cannot call takeout methods outside of "with"')
raise ValueError('Takeout mode has not been initialized '
'(are you calling outside of "with"?)')
single = not utils.is_list_like(request)
requests = ((request,) if single else request)
@ -57,34 +77,43 @@ class _TakeoutClient:
await r.resolve(self, utils)
wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r))
return await super().__getattribute__('c')(
return await self.__client(
wrapped[0] if single else wrapped, ordered=ordered)
def __getattribute__(self, name):
if name.startswith('__'):
# We want to override special method names
if name == '__class__':
# See https://github.com/LonamiWebs/Telethon/issues/1103.
name = 'k'
return super().__getattribute__(name)
# We access class via type() because __class__ will recurse infinitely.
# Also note that since we've name-mangled our own class attributes,
# they'll be passed to __getattribute__() as already decorated. For
# example, 'self.__client' will be passed as '_TakeoutClient__client'.
# https://docs.python.org/3/tutorial/classes.html#private-variables
if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
raise AttributeError # force call of __getattr__
value = getattr(super().__getattribute__('c'), name)
# Try to access attribute in the proxy object and check for the same
# attribute in the shadowed object (through our __getattr__) if failed.
return super().__getattribute__(name)
def __getattr__(self, name):
value = getattr(self.__client, name)
if inspect.ismethod(value):
# Emulate bound methods behaviour by partially applying
# our proxy class as the self parameter instead of the client
# Emulate bound methods behavior by partially applying our proxy
# class as the self parameter instead of the client.
return functools.partial(
getattr(super().__getattribute__('k'), name), self)
else:
return value
getattr(self.__client.__class__, name), self)
return value
def __setattr__(self, name, value):
setattr(super().__getattribute__('c'), name, value)
if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))):
# This is our own name-mangled attribute, keep calm.
return super().__setattr__(name, value)
return setattr(self.__client, name, value)
class AccountMethods(UserMethods):
def takeout(
self, contacts=None, users=None, chats=None, megagroups=None,
channels=None, files=None, max_file_size=None):
self, finalize=True, *, contacts=None, users=None, chats=None,
megagroups=None, channels=None, files=None, max_file_size=None):
"""
Creates a proxy object over the current :ref:`TelegramClient` through
which making requests will use :tl:`InvokeWithTakeoutRequest` to wrap
@ -105,14 +134,24 @@ class AccountMethods(UserMethods):
to adjust the `wait_time` of methods like `client.iter_messages
<telethon.client.messages.MessageMethods.iter_messages>`.
By default, all parameters are ``False``, and you need to enable
those you plan to use by setting them to ``True``.
By default, all parameters are ``None``, and you need to enable those
you plan to use by setting them to either ``True`` or ``False``.
You should ``except errors.TakeoutInitDelayError as e``, since this
exception will raise depending on the condition of the session. You
can then access ``e.seconds`` to know how long you should wait for
before calling the method again.
There's also a `success` property available in the takeout proxy
object, so from the `with` body you can set the boolean result that
will be sent back to Telegram. But if it's left ``None`` as by
default, then the action is based on the `finalize` parameter. If
it's ``True`` then the takeout will be finished, and if no exception
occurred during it, then ``True`` will be considered as a result.
Otherwise, the takeout will not be finished and its ID will be
preserved for future usage as `client.session.takeout_id
<telethon.sessions.abstract.Session.takeout_id>`.
Args:
contacts (`bool`):
Set to ``True`` if you plan on downloading contacts.
@ -141,7 +180,7 @@ class AccountMethods(UserMethods):
The maximum file size, in bytes, that you plan
to download for each message with media.
"""
return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest(
request_kwargs = dict(
contacts=contacts,
message_users=users,
message_chats=chats,
@ -149,4 +188,27 @@ class AccountMethods(UserMethods):
message_channels=channels,
files=files,
file_max_size=max_file_size
))
)
arg_specified = (arg is not None for arg in request_kwargs.values())
if self.session.takeout_id is None or any(arg_specified):
request = functions.account.InitTakeoutSessionRequest(
**request_kwargs)
else:
request = None
return _TakeoutClient(finalize, self, request)
async def end_takeout(self, success):
"""
Finishes a takeout, with specified result sent back to Telegram.
Returns:
``True`` if the operation was successful, ``False`` otherwise.
"""
try:
async with _TakeoutClient(True, self, None) as takeout:
takeout.success = success
except ValueError:
return False
return True

View File

@ -262,7 +262,6 @@ class TelegramBaseClient(abc.ABC):
)
)
self._connection = connection
self._sender = MTProtoSender(
self.session.auth_key, self._loop,
loggers=self._log,

View File

@ -53,6 +53,23 @@ class Session(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or ``None`` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
"""
raise NotImplementedError
@abstractmethod
def get_update_state(self, entity_id):
"""

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None
self._port = None
self._auth_key = None
self._takeout_id = None
self._files = {}
self._entities = set()
@ -62,6 +63,14 @@ class MemorySession(Session):
def auth_key(self, value):
self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)

View File

@ -18,7 +18,7 @@ except ImportError:
sqlite3 = None
EXTENSION = '.session'
CURRENT_VERSION = 4 # database version
CURRENT_VERSION = 5 # database version
class SQLiteSession(MemorySession):
@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_
self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key)
c.close()
@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
dc_id integer primary key,
server_address text,
port integer,
auth_key blob
auth_key blob,
takeout_id integer
)"""
,
"""entities (
@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
date integer,
seq integer
)""")
if old == 4:
old += 1
c.execute("alter table sessions add column takeout_id integer")
c.close()
@staticmethod
@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
self._auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self):
c = self._cursor()
# While we can save multiple rows into the sessions table
@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
self._auth_key.key if self._auth_key else b'',
self._takeout_id
))
c.close()

View File

@ -5,12 +5,16 @@ import struct
from .memory import MemorySession
from ..crypto import AuthKey
_STRUCT_PREFORMAT = '>B{}sH256s'
CURRENT_VERSION = '1'
class StringSession(MemorySession):
"""
This minimal session file can be easily saved and loaded as a string.
This session file can be easily saved and loaded as a string. According
to the initial design, it contains only the data that is necessary for
successful connection and authentication, so takeout ID is not stored.
It is thought to be used where you don't want to create any on-disk
files but would still like to be able to save and load existing sessions
@ -33,7 +37,7 @@ class StringSession(MemorySession):
string = string[1:]
ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack(
'>B{}sH256s'.format(ip_len), StringSession.decode(string))
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed
if any(key):
@ -45,7 +49,7 @@ class StringSession(MemorySession):
ip = ipaddress.ip_address(self._server_address).packed
return CURRENT_VERSION + StringSession.encode(struct.pack(
'>B{}sH256s'.format(len(ip)),
_STRUCT_PREFORMAT.format(len(ip)),
self._dc_id,
ip,
self._port,