diff --git a/telethon/client/account.py b/telethon/client/account.py index cc4952a0..b21b043f 100644 --- a/telethon/client/account.py +++ b/telethon/client/account.py @@ -14,14 +14,19 @@ class _TakeoutClient: def __init__(self, client, request): # We use the name mangling for attributes to make them inaccessible - # from within the shadowed client object. + # from within the shadowed client object and to distinguish them from + # its own attributes where needed. self.__client = client self.__request = request + self.__success = True - # After we initialize the proxy object variables, it's necessary to - # translate to the client any write-access to our attributes. - self.__setattr__ = lambda _, name, value: \ - setattr(client, name, value) + @property + def success(self): + return self.__success + + @success.setter + def success(self, value): + self.__success = value def __enter__(self): if self.__client.loop.is_running(): @@ -34,19 +39,23 @@ class _TakeoutClient: async def __aenter__(self): # Enter/Exit behaviour is "overrode", we don't want to call start. - # TODO: Request only if takeout ID isn't set. - self.__client.takeout_id = (await self.__client(self.__request)).id + client = self.__client + if client.session.takeout_id is None: + client.session.takeout_id = (await client(self.__request)).id + elif self.__request is not None: + raise ValueError("Can't send a takeout request while another " + "takeout for the current session still not been finished yet.") return self def __exit__(self, *args): return self.__client.loop.run_until_complete(self.__aexit__(*args)) async def __aexit__(self, *args): - # TODO: Reset only if takeout result is set. - self.__client.takeout_id = None + if self.__success is not None: + await self.__client.end_takeout(self.__success) async def __call__(self, request, ordered=False): - takeout_id = self.__client.takeout_id + takeout_id = self.__client.session.takeout_id if takeout_id is None: raise ValueError('Takeout mode has not been initialized ' '(are you calling outside of "with"?)') @@ -65,6 +74,10 @@ class _TakeoutClient: def __getattribute__(self, name): # We access class via type() because __class__ will recurse infinitely. + # Also note that since we've name-mangled our own class attributes, + # they'll be passed to __getattribute__() as already decorated. For + # example, 'self.__client' will be passed as '_TakeoutClient__client'. + # https://docs.python.org/3/tutorial/classes.html#private-variables if name.startswith('__') and name not in type(self).__PROXY_INTERFACE: raise AttributeError # force call of __getattr__ @@ -82,10 +95,16 @@ class _TakeoutClient: return value + def __setattr__(self, name, value): + if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))): + # This is our own name-mangled attribute, keep calm. + return super().__setattr__(name, value) + return setattr(self.__client, name, value) + class AccountMethods(UserMethods): def takeout( - self, contacts=None, users=None, chats=None, megagroups=None, + self, *, contacts=None, users=None, chats=None, megagroups=None, channels=None, files=None, max_file_size=None): """ Creates a proxy object over the current :ref:`TelegramClient` through @@ -107,14 +126,21 @@ class AccountMethods(UserMethods): to adjust the `wait_time` of methods like `client.iter_messages `. - By default, all parameters are ``False``, and you need to enable - those you plan to use by setting them to ``True``. + By default, all parameters are ``None``, and you need to enable those + you plan to use by setting them to either ``True`` or ``False``. You should ``except errors.TakeoutInitDelayError as e``, since this exception will raise depending on the condition of the session. You can then access ``e.seconds`` to know how long you should wait for before calling the method again. + There's also a `success` property available in the takeout proxy + object, so you can set the boolean takeout result that will be sent + back to Telegram, from within the `with` body. You're also able to + set it to ``None`` so takeout will not be finished and its ID will + be preserved for future usage as `client.session.takeout_id + `. Default is ``True``. + Args: contacts (`bool`): Set to ``True`` if you plan on downloading contacts. @@ -143,7 +169,7 @@ class AccountMethods(UserMethods): The maximum file size, in bytes, that you plan to download for each message with media. """ - return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest( + request_kwargs = dict( contacts=contacts, message_users=users, message_chats=chats, @@ -151,4 +177,26 @@ class AccountMethods(UserMethods): message_channels=channels, files=files, file_max_size=max_file_size - )) + ) + arg_specified = (arg is not None for arg in request_kwargs.values()) + + if self.session.takeout_id is None or any(arg_specified): + request = functions.account.InitTakeoutSessionRequest( + **request_kwargs) + else: + request = None + + return _TakeoutClient(self, request) + + async def end_takeout(self, success): + """ + Finishes a takeout, with specified result sent back to Telegram. + + Returns: + ``True`` if the operation was successful, ``False`` otherwise. + """ + client = _TakeoutClient(self, None) + result = await client( + functions.account.FinishTakeoutSessionRequest(success)) + self.session.takeout_id = None + return result diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index f7763f1d..56a4ae0f 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -233,7 +233,6 @@ class TelegramBaseClient(abc.ABC): # With asynchronous sessions, it would need await, # and defeats the purpose of properties. self.session = session - self.takeout_id = None # TODO: Move to session. self.api_id = int(api_id) self.api_hash = api_hash @@ -263,7 +262,6 @@ class TelegramBaseClient(abc.ABC): ) ) - self._connection = connection self._sender = MTProtoSender( self.session.auth_key, self._loop, loggers=self._log, diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index b6f86a4d..d40d0e32 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -53,6 +53,23 @@ class Session(ABC): """ raise NotImplementedError + @property + @abstractmethod + def takeout_id(self): + """ + Returns an ID of the takeout process initialized for this session, + or ``None`` if there's no were any unfinished takeout requests. + """ + raise NotImplementedError + + @takeout_id.setter + @abstractmethod + def takeout_id(self, value): + """ + Sets the ID of the unfinished takeout process for this session. + """ + raise NotImplementedError + @abstractmethod def get_update_state(self, entity_id): """ diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 09f029ed..cdc144f2 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -32,6 +32,7 @@ class MemorySession(Session): self._server_address = None self._port = None self._auth_key = None + self._takeout_id = None self._files = {} self._entities = set() @@ -62,6 +63,14 @@ class MemorySession(Session): def auth_key(self, value): self._auth_key = value + @property + def takeout_id(self): + return self._takeout_id + + @takeout_id.setter + def takeout_id(self, value): + self._takeout_id = value + def get_update_state(self, entity_id): return self._update_states.get(entity_id, None) diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index ee68449a..ebb773e3 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -18,7 +18,7 @@ except ImportError: sqlite3 = None EXTENSION = '.session' -CURRENT_VERSION = 4 # database version +CURRENT_VERSION = 5 # database version class SQLiteSession(MemorySession): @@ -65,7 +65,8 @@ class SQLiteSession(MemorySession): c.execute('select * from sessions') tuple_ = c.fetchone() if tuple_: - self._dc_id, self._server_address, self._port, key, = tuple_ + self._dc_id, self._server_address, self._port, key, \ + self._takeout_id = tuple_ self._auth_key = AuthKey(data=key) c.close() @@ -79,7 +80,8 @@ class SQLiteSession(MemorySession): dc_id integer primary key, server_address text, port integer, - auth_key blob + auth_key blob, + takeout_id integer )""" , """entities ( @@ -172,6 +174,9 @@ class SQLiteSession(MemorySession): date integer, seq integer )""") + if old == 4: + old += 1 + c.execute("alter table sessions add column takeout_id integer") c.close() @staticmethod @@ -197,6 +202,11 @@ class SQLiteSession(MemorySession): self._auth_key = value self._update_session_table() + @MemorySession.takeout_id.setter + def takeout_id(self, value): + self._takeout_id = value + self._update_session_table() + def _update_session_table(self): c = self._cursor() # While we can save multiple rows into the sessions table @@ -205,11 +215,12 @@ class SQLiteSession(MemorySession): # some more work before being able to save auth_key's for # multiple DCs. Probably done differently. c.execute('delete from sessions') - c.execute('insert or replace into sessions values (?,?,?,?)', ( + c.execute('insert or replace into sessions values (?,?,?,?,?)', ( self._dc_id, self._server_address, self._port, - self._auth_key.key if self._auth_key else b'' + self._auth_key.key if self._auth_key else b'', + self._takeout_id )) c.close() diff --git a/telethon/sessions/string.py b/telethon/sessions/string.py index d1eedfea..666e177b 100644 --- a/telethon/sessions/string.py +++ b/telethon/sessions/string.py @@ -5,12 +5,16 @@ import struct from .memory import MemorySession from ..crypto import AuthKey +_STRUCT_PREFORMAT = '>B{}sH256s' + CURRENT_VERSION = '1' class StringSession(MemorySession): """ - This minimal session file can be easily saved and loaded as a string. + This session file can be easily saved and loaded as a string. According + to the initial design, it contains only the data that is necessary for + successful connection and authentication, so takeout ID is not stored. It is thought to be used where you don't want to create any on-disk files but would still like to be able to save and load existing sessions @@ -33,7 +37,7 @@ class StringSession(MemorySession): string = string[1:] ip_len = 4 if len(string) == 352 else 16 self._dc_id, ip, self._port, key = struct.unpack( - '>B{}sH256s'.format(ip_len), StringSession.decode(string)) + _STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string)) self._server_address = ipaddress.ip_address(ip).compressed if any(key): @@ -45,7 +49,7 @@ class StringSession(MemorySession): ip = ipaddress.ip_address(self._server_address).packed return CURRENT_VERSION + StringSession.encode(struct.pack( - '>B{}sH256s'.format(len(ip)), + _STRUCT_PREFORMAT.format(len(ip)), self._dc_id, ip, self._port,