diff --git a/README.rst b/README.rst index a2e0d3de..73aafccb 100755 --- a/README.rst +++ b/README.rst @@ -30,6 +30,7 @@ Creating a client .. code:: python + import asyncio from telethon import TelegramClient # These example values won't work. You must get your own api_id and @@ -38,22 +39,28 @@ Creating a client api_hash = '0123456789abcdef0123456789abcdef' client = TelegramClient('session_name', api_id, api_hash) - client.start() + async def main(): + await client.start() + asyncio.get_event_loop().run_until_complete(main()) Doing stuff ----------- +Note that this assumes you're inside an "async def" method. Check out the +`Python documentation `_ +if you're new with ``asyncio``. + .. code:: python print(client.get_me().stringify()) - client.send_message('username', 'Hello! Talking to you from Telethon') - client.send_file('username', '/home/myself/Pictures/holidays.jpg') + await client.send_message('username', 'Hello! Talking to you from Telethon') + await client.send_file('username', '/home/myself/Pictures/holidays.jpg') - client.download_profile_photo('me') - messages = client.get_message_history('username') - client.download_media(messages[0]) + await client.download_profile_photo('me') + messages = await client.get_message_history('username') + await client.download_media(messages[0]) Next steps @@ -61,5 +68,7 @@ Next steps Do you like how Telethon looks? Check out `Read The Docs `_ -for a more in-depth explanation, with examples, -troubleshooting issues, and more useful information. +for a more in-depth explanation, with examples, troubleshooting issues, +and more useful information. Note that the examples there are written for +the threaded version, not the one using asyncio. However, you just need to +await every remote call. diff --git a/telethon/crypto/cdn_decrypter.py b/telethon/crypto/cdn_decrypter.py index 24a4bb49..a7f3efdf 100644 --- a/telethon/crypto/cdn_decrypter.py +++ b/telethon/crypto/cdn_decrypter.py @@ -30,7 +30,7 @@ class CdnDecrypter: self.cdn_file_hashes = cdn_file_hashes @staticmethod - def prepare_decrypter(client, cdn_client, cdn_redirect): + async def prepare_decrypter(client, cdn_client, cdn_redirect): """ Prepares a new CDN decrypter. @@ -52,14 +52,14 @@ class CdnDecrypter: cdn_aes, cdn_redirect.cdn_file_hashes ) - cdn_file = cdn_client(GetCdnFileRequest( + cdn_file = await cdn_client(GetCdnFileRequest( file_token=cdn_redirect.file_token, offset=cdn_redirect.cdn_file_hashes[0].offset, limit=cdn_redirect.cdn_file_hashes[0].limit )) if isinstance(cdn_file, CdnFileReuploadNeeded): # We need to use the original client here - client(ReuploadCdnFileRequest( + await client(ReuploadCdnFileRequest( file_token=cdn_redirect.file_token, request_token=cdn_file.request_token )) @@ -73,7 +73,7 @@ class CdnDecrypter: return decrypter, cdn_file - def get_file(self): + async def get_file(self): """ Calls GetCdnFileRequest and decrypts its bytes. Also ensures that the file hasn't been tampered. @@ -82,7 +82,7 @@ class CdnDecrypter: """ if self.cdn_file_hashes: cdn_hash = self.cdn_file_hashes.pop(0) - cdn_file = self.client(GetCdnFileRequest( + cdn_file = await self.client(GetCdnFileRequest( self.file_token, cdn_hash.offset, cdn_hash.limit )) cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes) diff --git a/telethon/events/__init__.py b/telethon/events/__init__.py index 5966a120..6c909180 100644 --- a/telethon/events/__init__.py +++ b/telethon/events/__init__.py @@ -9,7 +9,7 @@ from ..extensions import markdown from ..tl import types, functions -def _into_id_set(client, chats): +async def _into_id_set(client, chats): """Helper util to turn the input chat or chats into a set of IDs.""" if chats is None: return None @@ -19,9 +19,9 @@ def _into_id_set(client, chats): result = set() for chat in chats: - chat = client.get_input_entity(chat) + chat = await client.get_input_entity(chat) if isinstance(chat, types.InputPeerSelf): - chat = client.get_me(input_peer=True) + chat = await client.get_me(input_peer=True) result.add(utils.get_peer_id(chat)) return result @@ -48,10 +48,10 @@ class _EventBuilder(abc.ABC): def build(self, update): """Builds an event for the given update if possible, or returns None""" - def resolve(self, client): + async def resolve(self, client): """Helper method to allow event builders to be resolved before usage""" - self.chats = _into_id_set(client, self.chats) - self._self_id = client.get_me(input_peer=True).user_id + self.chats = await _into_id_set(client, self.chats) + self._self_id = (await client.get_me(input_peer=True)).user_id def _filter_event(self, event): """ @@ -86,7 +86,7 @@ class _EventCommon(abc.ABC): ) self.is_channel = isinstance(chat_peer, types.PeerChannel) - def _get_input_entity(self, msg_id, entity_id, chat=None): + async def _get_input_entity(self, msg_id, entity_id, chat=None): """ Helper function to call GetMessages on the give msg_id and return the input entity whose ID is the given entity ID. @@ -95,11 +95,11 @@ class _EventCommon(abc.ABC): """ try: if isinstance(chat, types.InputPeerChannel): - result = self._client( + result = await self._client( functions.channels.GetMessagesRequest(chat, [msg_id]) ) else: - result = self._client( + result = await self._client( functions.messages.GetMessagesRequest([msg_id]) ) except RPCError: @@ -113,7 +113,7 @@ class _EventCommon(abc.ABC): return utils.get_input_peer(entity) @property - def input_chat(self): + async def input_chat(self): """ The (:obj:`InputPeer`) (group, megagroup or channel) on which the event occurred. This doesn't have the title or anything, @@ -125,7 +125,7 @@ class _EventCommon(abc.ABC): if self._input_chat is None and self._chat_peer is not None: try: - self._input_chat = self._client.get_input_entity( + self._input_chat = await self._client.get_input_entity( self._chat_peer ) except (ValueError, TypeError): @@ -134,22 +134,22 @@ class _EventCommon(abc.ABC): # TODO For channels, getDifference? Maybe looking # in the dialogs (which is already done) is enough. if self._message_id is not None: - self._input_chat = self._get_input_entity( + self._input_chat = await self._get_input_entity( self._message_id, utils.get_peer_id(self._chat_peer) ) return self._input_chat @property - def chat(self): + async def chat(self): """ The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which the event occurred. This property will make an API call the first time to get the most up to date version of the chat, so use with care as there is no caching besides local caching yet. """ - if self._chat is None and self.input_chat: - self._chat = self._client.get_entity(self._input_chat) + if self._chat is None and await self.input_chat: + self._chat = await self._client.get_entity(await self._input_chat) return self._chat @@ -157,7 +157,7 @@ class Raw(_EventBuilder): """ Represents a raw event. The event is the update itself. """ - def resolve(self, client): + async def resolve(self, client): pass def build(self, update): @@ -304,23 +304,23 @@ class NewMessage(_EventBuilder): self.is_reply = bool(message.reply_to_msg_id) self._reply_message = None - def respond(self, *args, **kwargs): + async def respond(self, *args, **kwargs): """ Responds to the message (not as a reply). This is a shorthand for ``client.send_message(event.chat, ...)``. """ - return self._client.send_message(self.input_chat, *args, **kwargs) + return await self._client.send_message(await self.input_chat, *args, **kwargs) - def reply(self, *args, **kwargs): + async def reply(self, *args, **kwargs): """ Replies to the message (as a reply). This is a shorthand for ``client.send_message(event.chat, ..., reply_to=event.message.id)``. """ - return self._client.send_message(self.input_chat, - reply_to=self.message.id, - *args, **kwargs) + return await self._client.send_message(await self.input_chat, + reply_to=self.message.id, + *args, **kwargs) - def edit(self, *args, **kwargs): + async def edit(self, *args, **kwargs): """ Edits the message iff it's outgoing. This is a shorthand for ``client.edit_message(event.chat, event.message, ...)``. @@ -331,27 +331,27 @@ class NewMessage(_EventBuilder): if not self.message.out: if not isinstance(self.message.to_id, types.PeerUser): return None - me = self._client.get_me(input_peer=True) + me = await self._client.get_me(input_peer=True) if self.message.to_id.user_id != me.user_id: return None - return self._client.edit_message(self.input_chat, - self.message, - *args, **kwargs) + return await self._client.edit_message(await self.input_chat, + self.message, + *args, **kwargs) - def delete(self, *args, **kwargs): + async def delete(self, *args, **kwargs): """ Deletes the message. You're responsible for checking whether you have the permission to do so, or to except the error otherwise. This is a shorthand for ``client.delete_messages(event.chat, event.message, ...)``. """ - return self._client.delete_messages(self.input_chat, - [self.message], - *args, **kwargs) + return await self._client.delete_messages(await self.input_chat, + [self.message], + *args, **kwargs) @property - def input_sender(self): + async def input_sender(self): """ This (:obj:`InputPeer`) is the input version of the user who sent the message. Similarly to ``input_chat``, this doesn't have @@ -365,21 +365,21 @@ class NewMessage(_EventBuilder): return None try: - self._input_sender = self._client.get_input_entity( + self._input_sender = await self._client.get_input_entity( self.message.from_id ) except (ValueError, TypeError): # We can rely on self.input_chat for this - self._input_sender = self._get_input_entity( + self._input_sender = await self._get_input_entity( self.message.id, self.message.from_id, - chat=self.input_chat + chat=await self.input_chat ) return self._input_sender @property - def sender(self): + async def sender(self): """ This (:obj:`User`) will make an API call the first time to get the most up to date version of the sender, so use with care as @@ -387,8 +387,8 @@ class NewMessage(_EventBuilder): ``input_sender`` needs to be available (often the case). """ - if self._sender is None and self.input_sender: - self._sender = self._client.get_entity(self._input_sender) + if self._sender is None and await self.input_sender: + self._sender = await self._client.get_entity(self._input_sender) return self._sender @property @@ -411,7 +411,7 @@ class NewMessage(_EventBuilder): return self.message.message @property - def reply_message(self): + async def reply_message(self): """ This (:obj:`Message`, optional) will make an API call the first time to get the full ``Message`` object that one was replying to, @@ -421,12 +421,12 @@ class NewMessage(_EventBuilder): return None if self._reply_message is None: - if isinstance(self.input_chat, types.InputPeerChannel): - r = self._client(functions.channels.GetMessagesRequest( - self.input_chat, [self.message.reply_to_msg_id] + if isinstance(await self.input_chat, types.InputPeerChannel): + r = await self._client(functions.channels.GetMessagesRequest( + await self.input_chat, [self.message.reply_to_msg_id] )) else: - r = self._client(functions.messages.GetMessagesRequest( + r = await self._client(functions.messages.GetMessagesRequest( [self.message.reply_to_msg_id] )) if not isinstance(r, types.messages.MessagesNotModified): @@ -610,7 +610,7 @@ class ChatAction(_EventBuilder): self.new_title = new_title @property - def pinned_message(self): + async def pinned_message(self): """ If ``new_pin`` is ``True``, this returns the (:obj:`Message`) object that was pinned. @@ -618,8 +618,8 @@ class ChatAction(_EventBuilder): if self._pinned_message == 0: return None - if isinstance(self._pinned_message, int) and self.input_chat: - r = self._client(functions.channels.GetMessagesRequest( + if isinstance(self._pinned_message, int) and await self.input_chat: + r = await self._client(functions.channels.GetMessagesRequest( self._input_chat, [self._pinned_message] )) try: @@ -635,25 +635,25 @@ class ChatAction(_EventBuilder): return self._pinned_message @property - def added_by(self): + async def added_by(self): """ The user who added ``users``, if applicable (``None`` otherwise). """ if self._added_by and not isinstance(self._added_by, types.User): - self._added_by = self._client.get_entity(self._added_by) + self._added_by = await self._client.get_entity(self._added_by) return self._added_by @property - def kicked_by(self): + async def kicked_by(self): """ The user who kicked ``users``, if applicable (``None`` otherwise). """ if self._kicked_by and not isinstance(self._kicked_by, types.User): - self._kicked_by = self._client.get_entity(self._kicked_by) + self._kicked_by = await self._client.get_entity(self._kicked_by) return self._kicked_by @property - def user(self): + async def user(self): """ The single user that takes part in this action (e.g. joined). @@ -661,12 +661,12 @@ class ChatAction(_EventBuilder): there is no user taking part. """ try: - return next(self.users) + return next(await self.users) except (StopIteration, TypeError): return None @property - def users(self): + async def users(self): """ A list of users that take part in this action (e.g. joined). @@ -675,7 +675,7 @@ class ChatAction(_EventBuilder): """ if self._users is None and self._user_peers: try: - self._users = self._client.get_entity(self._user_peers) + self._users = await self._client.get_entity(self._user_peers) except (TypeError, ValueError): self._users = [] diff --git a/telethon/extensions/markdown.py b/telethon/extensions/markdown.py index a5dde5c6..15d3b66e 100644 --- a/telethon/extensions/markdown.py +++ b/telethon/extensions/markdown.py @@ -194,9 +194,9 @@ def get_inner_text(text, entity): """ if isinstance(entity, TLObject): entity = (entity,) - multiple = True - else: multiple = False + else: + multiple = True text = _add_surrogate(text) result = [] diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index d335e57a..3babfb3d 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -1,13 +1,20 @@ """ This module holds a rough implementation of the C# TCP client. """ +# Python rough implementation of a C# TCP client +import asyncio import errno import logging import socket import time from datetime import timedelta from io import BytesIO, BufferedWriter -from threading import Lock + +MAX_TIMEOUT = 15 # in seconds +CONN_RESET_ERRNOS = { + errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, + errno.EINVAL, errno.ENOTCONN +} try: import socks @@ -25,7 +32,7 @@ __log__ = logging.getLogger(__name__) class TcpClient: """A simple TCP client to ease the work with sockets and proxies.""" - def __init__(self, proxy=None, timeout=timedelta(seconds=5)): + def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): """ Initializes the TCP client. @@ -34,7 +41,7 @@ class TcpClient: """ self.proxy = proxy self._socket = None - self._closing_lock = Lock() + self._loop = loop if loop else asyncio.get_event_loop() if isinstance(timeout, timedelta): self.timeout = timeout.seconds @@ -54,9 +61,9 @@ class TcpClient: else: # tuple, list, etc. self._socket.set_proxy(*self.proxy) - self._socket.settimeout(self.timeout) + self._socket.setblocking(False) - def connect(self, ip, port): + async def connect(self, ip, port): """ Tries connecting forever to IP:port unless an OSError is raised. @@ -72,11 +79,15 @@ class TcpClient: timeout = 1 while True: try: - while not self._socket: + if not self._socket: self._recreate_socket(mode) - self._socket.connect(address) + await self._loop.sock_connect(self._socket, address) break # Successful connection, stop retrying to connect + except ConnectionError: + self._socket = None + await asyncio.sleep(timeout) + timeout = min(timeout * 2, MAX_TIMEOUT) except OSError as e: __log__.info('OSError "%s" raised while connecting', e) # Stop retrying to connect if proxy connection error occurred @@ -90,7 +101,7 @@ class TcpClient: # Bad file descriptor, i.e. socket was closed, set it # to none to recreate it on the next iteration self._socket = None - time.sleep(timeout) + await asyncio.sleep(timeout) timeout = min(timeout * 2, MAX_TIMEOUT) else: raise @@ -103,21 +114,16 @@ class TcpClient: def close(self): """Closes the connection.""" - if self._closing_lock.locked(): - # Already closing, no need to close again (avoid None.close()) - return + try: + if self._socket is not None: + self._socket.shutdown(socket.SHUT_RDWR) + self._socket.close() + except OSError: + pass # Ignore ENOTCONN, EBADF, and any other error when closing + finally: + self._socket = None - with self._closing_lock: - try: - if self._socket is not None: - self._socket.shutdown(socket.SHUT_RDWR) - self._socket.close() - except OSError: - pass # Ignore ENOTCONN, EBADF, and any other error when closing - finally: - self._socket = None - - def write(self, data): + async def write(self, data): """ Writes (sends) the specified bytes to the connected peer. @@ -126,11 +132,13 @@ class TcpClient: if self._socket is None: self._raise_connection_reset(None) - # TODO Timeout may be an issue when sending the data, Changed in v3.5: - # The socket timeout is now the maximum total duration to send all data. try: - self._socket.sendall(data) - except socket.timeout as e: + await asyncio.wait_for( + self.sock_sendall(data), + timeout=self.timeout, + loop=self._loop + ) + except asyncio.TimeoutError as e: __log__.debug('socket.timeout "%s" while writing data', e) raise TimeoutError() from e except ConnectionError as e: @@ -143,7 +151,7 @@ class TcpClient: else: raise - def read(self, size): + async def read(self, size): """ Reads (receives) a whole block of size bytes from the connected peer. @@ -153,13 +161,18 @@ class TcpClient: if self._socket is None: self._raise_connection_reset(None) - # TODO Remove the timeout from this method, always use previous one with BufferedWriter(BytesIO(), buffer_size=size) as buffer: bytes_left = size while bytes_left != 0: try: - partial = self._socket.recv(bytes_left) - except socket.timeout as e: + if self._socket is None: + self._raise_connection_reset() + partial = await asyncio.wait_for( + self.sock_recv(bytes_left), + timeout=self.timeout, + loop=self._loop + ) + except asyncio.TimeoutError as e: # These are somewhat common if the server has nothing # to send to us, so use a lower logging priority. __log__.debug('socket.timeout "%s" while reading data', e) @@ -168,7 +181,7 @@ class TcpClient: __log__.info('ConnectionError "%s" while reading data', e) self._raise_connection_reset(e) except OSError as e: - if e.errno != errno.EBADF and self._closing_lock.locked(): + if e.errno != errno.EBADF: # Ignore bad file descriptor while closing __log__.info('OSError "%s" while reading data', e) @@ -190,5 +203,56 @@ class TcpClient: def _raise_connection_reset(self, original): """Disconnects the client and raises ConnectionResetError.""" self.close() # Connection reset -> flag as socket closed - raise ConnectionResetError('The server has closed the connection.')\ - from original + raise ConnectionResetError('The server has closed the connection.') from original + + # due to new https://github.com/python/cpython/pull/4386 + def sock_recv(self, n): + fut = self._loop.create_future() + self._sock_recv(fut, None, n) + return fut + + def _sock_recv(self, fut, registered_fd, n): + if registered_fd is not None: + self._loop.remove_reader(registered_fd) + if fut.cancelled(): + return + + try: + data = self._socket.recv(n) + except (BlockingIOError, InterruptedError): + fd = self._socket.fileno() + self._loop.add_reader(fd, self._sock_recv, fut, fd, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, data): + fut = self._loop.create_future() + if data: + self._sock_sendall(fut, None, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered_fd, data): + if registered_fd: + self._loop.remove_writer(registered_fd) + if fut.cancelled(): + return + + try: + n = self._socket.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + fd = self._socket.fileno() + self._loop.add_writer(fd, self._sock_sendall, fut, fd, data) diff --git a/telethon/network/authenticator.py b/telethon/network/authenticator.py index 32413551..f905f0fc 100644 --- a/telethon/network/authenticator.py +++ b/telethon/network/authenticator.py @@ -21,7 +21,7 @@ from ..tl.functions import ( ) -def do_authentication(connection, retries=5): +async def do_authentication(connection, retries=5): """ Performs the authentication steps on the given connection. Raises an error if all attempts fail. @@ -36,14 +36,14 @@ def do_authentication(connection, retries=5): last_error = None while retries: try: - return _do_authentication(connection) + return await _do_authentication(connection) except (SecurityError, AssertionError, NotImplementedError) as e: last_error = e retries -= 1 raise last_error -def _do_authentication(connection): +async def _do_authentication(connection): """ Executes the authentication process with the Telegram servers. @@ -56,8 +56,8 @@ def _do_authentication(connection): req_pq_request = ReqPqMultiRequest( nonce=int.from_bytes(os.urandom(16), 'big', signed=True) ) - sender.send(bytes(req_pq_request)) - with BinaryReader(sender.receive()) as reader: + await sender.send(bytes(req_pq_request)) + with BinaryReader(await sender.receive()) as reader: req_pq_request.on_response(reader) res_pq = req_pq_request.result @@ -104,10 +104,10 @@ def _do_authentication(connection): public_key_fingerprint=target_fingerprint, encrypted_data=cipher_text ) - sender.send(bytes(req_dh_params)) + await sender.send(bytes(req_dh_params)) # Step 2 response: DH Exchange - with BinaryReader(sender.receive()) as reader: + with BinaryReader(await sender.receive()) as reader: req_dh_params.on_response(reader) server_dh_params = req_dh_params.result @@ -174,10 +174,10 @@ def _do_authentication(connection): server_nonce=res_pq.server_nonce, encrypted_data=client_dh_encrypted, ) - sender.send(bytes(set_client_dh)) + await sender.send(bytes(set_client_dh)) # Step 3 response: Complete DH Exchange - with BinaryReader(sender.receive()) as reader: + with BinaryReader(await sender.receive()) as reader: set_client_dh.on_response(reader) dh_gen = set_client_dh.result diff --git a/telethon/network/connection.py b/telethon/network/connection.py index 0adaf98a..7efc6a9b 100644 --- a/telethon/network/connection.py +++ b/telethon/network/connection.py @@ -2,18 +2,17 @@ This module holds both the Connection class and the ConnectionMode enum, which specifies the protocol to be used by the Connection. """ +import errno import logging import os import struct from datetime import timedelta -from zlib import crc32 from enum import Enum - -import errno +from zlib import crc32 from ..crypto import AESModeCTR -from ..extensions import TcpClient from ..errors import InvalidChecksumError +from ..extensions import TcpClient __log__ = logging.getLogger(__name__) @@ -52,7 +51,7 @@ class Connection: """ def __init__(self, mode=ConnectionMode.TCP_FULL, - proxy=None, timeout=timedelta(seconds=5)): + proxy=None, timeout=timedelta(seconds=5), loop=None): """ Initializes a new connection. @@ -65,7 +64,7 @@ class Connection: self._aes_encrypt, self._aes_decrypt = None, None # TODO Rename "TcpClient" as some sort of generic socket? - self.conn = TcpClient(proxy=proxy, timeout=timeout) + self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop) # Sending messages if mode == ConnectionMode.TCP_FULL: @@ -89,7 +88,7 @@ class Connection: setattr(self, 'write', self._write_plain) setattr(self, 'read', self._read_plain) - def connect(self, ip, port): + async def connect(self, ip, port): """ Estabilishes a connection to IP:port. @@ -97,7 +96,7 @@ class Connection: :param port: the port to connect to. """ try: - self.conn.connect(ip, port) + await self.conn.connect(ip, port) except OSError as e: if e.errno == errno.EISCONN: return # Already connected, no need to re-set everything up @@ -106,17 +105,17 @@ class Connection: self._send_counter = 0 if self._mode == ConnectionMode.TCP_ABRIDGED: - self.conn.write(b'\xef') + await self.conn.write(b'\xef') elif self._mode == ConnectionMode.TCP_INTERMEDIATE: - self.conn.write(b'\xee\xee\xee\xee') + await self.conn.write(b'\xee\xee\xee\xee') elif self._mode == ConnectionMode.TCP_OBFUSCATED: - self._setup_obfuscation() + await self._setup_obfuscation() def get_timeout(self): """Returns the timeout used by the connection.""" return self.conn.timeout - def _setup_obfuscation(self): + async def _setup_obfuscation(self): """ Sets up the obfuscated protocol. """ @@ -144,7 +143,7 @@ class Connection: self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] - self.conn.write(bytes(random)) + await self.conn.write(bytes(random)) def is_connected(self): """ @@ -166,12 +165,12 @@ class Connection: # region Receive message implementations - def recv(self): + async def recv(self): """Receives and unpacks a message""" # Default implementation is just an error raise ValueError('Invalid connection mode specified: ' + str(self._mode)) - def _recv_tcp_full(self): + async def _recv_tcp_full(self): """ Receives a message from the network, internally encoded using the TCP full protocol. @@ -181,7 +180,10 @@ class Connection: :return: the read message payload. """ - packet_len_seq = self.read(8) # 4 and 4 + # TODO We don't want another call to this method that could + # potentially await on another self.read(n). Is this guaranteed + # by asyncio? + packet_len_seq = await self.read(8) # 4 and 4 packet_len, seq = struct.unpack('= 127: - length = struct.unpack(' SQL entities = self._check_migrate_json() @@ -189,42 +183,37 @@ class SQLiteSession(MemorySession): self._update_session_table() def _update_session_table(self): - with self._db_lock: - c = self._cursor() - # While we can save multiple rows into the sessions table - # currently we only want to keep ONE as the tables don't - # tell us which auth_key's are usable and will work. Needs - # some more work before being able to save auth_key's for - # multiple DCs. Probably done differently. - c.execute('delete from sessions') - c.execute('insert or replace into sessions values (?,?,?,?)', ( - self._dc_id, - self._server_address, - self._port, - self._auth_key.key if self._auth_key else b'' - )) - c.close() + c = self._cursor() + # While we can save multiple rows into the sessions table + # currently we only want to keep ONE as the tables don't + # tell us which auth_key's are usable and will work. Needs + # some more work before being able to save auth_key's for + # multiple DCs. Probably done differently. + c.execute('delete from sessions') + c.execute('insert or replace into sessions values (?,?,?,?)', ( + self._dc_id, + self._server_address, + self._port, + self._auth_key.key if self._auth_key else b'' + )) + c.close() def save(self): """Saves the current session object as session_user_id.session""" - with self._db_lock: - self._conn.commit() + self._conn.commit() def _cursor(self): """Asserts that the connection is open and returns a cursor""" - with self._db_lock: - if self._conn is None: - self._conn = sqlite3.connect(self.filename, - check_same_thread=False) - return self._conn.cursor() + if self._conn is None: + self._conn = sqlite3.connect(self.filename) + return self._conn.cursor() def close(self): """Closes the connection unless we're working in-memory""" if self.filename != ':memory:': - with self._db_lock: - if self._conn is not None: - self._conn.close() - self._conn = None + if self._conn is not None: + self._conn.close() + self._conn = None def delete(self): """Deletes the current session file""" @@ -259,11 +248,10 @@ class SQLiteSession(MemorySession): if not rows: return - with self._db_lock: - self._cursor().executemany( - 'insert or replace into entities values (?,?,?,?,?)', rows - ) - self.save() + self._cursor().executemany( + 'insert or replace into entities values (?,?,?,?,?)', rows + ) + self.save() def _fetchone_entity(self, query, args): c = self._cursor() @@ -302,11 +290,10 @@ class SQLiteSession(MemorySession): if not isinstance(instance, (InputDocument, InputPhoto)): raise TypeError('Cannot cache %s instance' % type(instance)) - with self._db_lock: - self._cursor().execute( - 'insert or replace into sent_files values (?,?,?,?,?)', ( - md5_digest, file_size, - _SentFileType.from_type(type(instance)).value, - instance.id, instance.access_hash - )) - self.save() + self._cursor().execute( + 'insert or replace into sent_files values (?,?,?,?,?)', ( + md5_digest, file_size, + _SentFileType.from_type(type(instance)).value, + instance.id, instance.access_hash + )) + self.save() diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index bf33a7dc..dd043c32 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -1,11 +1,9 @@ +import asyncio import logging import os +from asyncio import Lock +from datetime import timedelta import platform -import threading -from datetime import timedelta, datetime -from signal import signal, SIGINT, SIGTERM, SIGABRT -from threading import Lock -from time import sleep from . import version, utils from .crypto import rsa from .errors import ( @@ -70,8 +68,6 @@ class TelegramBareClient: connection_mode=ConnectionMode.TCP_FULL, use_ipv6=False, proxy=None, - update_workers=None, - spawn_read_thread=False, timeout=timedelta(seconds=5), loop=None, device_model=None, @@ -95,6 +91,8 @@ class TelegramBareClient: 'The given session must be a str or a Session instance.' ) + self._loop = loop if loop else asyncio.get_event_loop() + # ':' in session.server_address is True if it's an IPv6 address if (not session.server_address or (':' in session.server_address) != use_ipv6): @@ -112,13 +110,15 @@ class TelegramBareClient: # that calls .connect(). Every other thread will spawn a new # temporary connection. The connection on this one is always # kept open so Telegram can send us updates. - self._sender = MtProtoSender(self.session, Connection( - mode=connection_mode, proxy=proxy, timeout=timeout - )) + self._sender = MtProtoSender( + self.session, + Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop), + self._loop + ) - # Two threads may be calling reconnect() when the connection is lost, - # we only want one to actually perform the reconnection. - self._reconnect_lock = Lock() + # Two coroutines may be calling reconnect() when the connection + # is lost, we only want one to actually perform the reconnection. + self._reconnect_lock = Lock(loop=self._loop) # Cache "exported" sessions as 'dc_id: Session' not to recreate # them all the time since generating a new key is a relatively @@ -127,7 +127,7 @@ class TelegramBareClient: # This member will process updates if enabled. # One may change self.updates.enabled at any later point. - self.updates = UpdateState(workers=update_workers) + self.updates = UpdateState(self._loop) # Used on connection - the user may modify these and reconnect system = platform.uname() @@ -153,34 +153,25 @@ class TelegramBareClient: # See https://core.telegram.org/api/invoking#saving-client-info. self._first_request = True - # Constantly read for results and updates from within the main client, - # if the user has left enabled such option. - self._spawn_read_thread = spawn_read_thread - self._recv_thread = None - self._idling = threading.Event() + self._recv_loop = None + self._ping_loop = None + self._state_loop = None + self._idling = asyncio.Event() # Default PingRequest delay - self._last_ping = datetime.now() self._ping_delay = timedelta(minutes=1) - # Also have another delay for GetStateRequest. # # If the connection is kept alive for long without invoking any # high level request the server simply stops sending updates. # TODO maybe we can have ._last_request instead if any req works? - self._last_state = datetime.now() self._state_delay = timedelta(hours=1) - # Some errors are known but there's nothing we can do from the - # background thread. If any of these happens, call .disconnect(), - # and raise them next time .invoke() is tried to be called. - self._background_error = None - # endregion # region Connecting - def connect(self, _sync_updates=True): + async def connect(self, _sync_updates=True): """Connects to the Telegram servers, executing authentication if required. Note that authenticating to the Telegram servers is not the same as authenticating the desired user itself, which @@ -197,10 +188,8 @@ class TelegramBareClient: __log__.info('Connecting to %s:%d...', self.session.server_address, self.session.port) - self._background_error = None # Clear previous errors - try: - self._sender.connect() + await self._sender.connect() __log__.info('Connection success!') # Connection was successful! Try syncing the update state @@ -210,12 +199,12 @@ class TelegramBareClient: self._user_connected = True if self._authorized is None and _sync_updates: try: - self.sync_updates() - self._set_connected_and_authorized() + await self.sync_updates() + await self._set_connected_and_authorized() except UnauthorizedError: self._authorized = False elif self._authorized: - self._set_connected_and_authorized() + await self._set_connected_and_authorized() return True @@ -224,7 +213,7 @@ class TelegramBareClient: __log__.warning('Connection failed, got unexpected type with ID ' '%s. Migrating?', hex(e.invalid_constructor_id)) self.disconnect() - return self.connect(_sync_updates=_sync_updates) + return await self.connect(_sync_updates=_sync_updates) except (RPCError, ConnectionError) as e: # Probably errors from the previous session, ignore them @@ -249,24 +238,15 @@ class TelegramBareClient: )) def disconnect(self): - """Disconnects from the Telegram server - and stops all the spawned threads""" + """Disconnects from the Telegram server""" __log__.info('Disconnecting...') - self._user_connected = False # This will stop recv_thread's loop - - __log__.debug('Stopping all workers...') - self.updates.stop_workers() - - # This will trigger a "ConnectionResetError" on the recv_thread, - # which won't attempt reconnecting as ._user_connected is False. - __log__.debug('Disconnecting the socket...') + self._user_connected = False self._sender.disconnect() - # TODO Shall we clear the _exported_sessions, or may be reused? self._first_request = True # On reconnect it will be first again self.session.close() - def _reconnect(self, new_dc=None): + async def _reconnect(self, new_dc=None): """If 'new_dc' is not set, only a call to .connect() will be made since it's assumed that the connection has been lost and the library is reconnecting. @@ -276,13 +256,14 @@ class TelegramBareClient: connects to the new data center. """ if new_dc is None: - if self.is_connected(): - __log__.info('Reconnection aborted: already connected') - return True - + # Assume we are disconnected due to some error, so connect again try: + if self.is_connected(): + __log__.info('Reconnection aborted: already connected') + return True + __log__.info('Attempting reconnection...') - return self.connect() + return await self.connect() except ConnectionResetError as e: __log__.warning('Reconnection failed due to %s', e) return False @@ -290,7 +271,7 @@ class TelegramBareClient: # Since we're reconnecting possibly due to a UserMigrateError, # we need to first know the Data Centers we can connect to. Do # that before disconnecting. - dc = self._get_dc(new_dc) + dc = await self._get_dc(new_dc) __log__.info('Reconnecting to new data center %s', dc) self.session.set_dc(dc.id, dc.ip_address, dc.port) @@ -299,7 +280,7 @@ class TelegramBareClient: self.session.auth_key = None self.session.save() self.disconnect() - return self.connect() + return await self.connect() def set_proxy(self, proxy): """Change the proxy used by the connections. @@ -312,19 +293,15 @@ class TelegramBareClient: # region Working with different connections/Data Centers - def _on_read_thread(self): - return self._recv_thread is not None and \ - threading.get_ident() == self._recv_thread.ident - - def _get_dc(self, dc_id, cdn=False): + async def _get_dc(self, dc_id, cdn=False): """Gets the Data Center (DC) associated to 'dc_id'""" if not TelegramBareClient._config: - TelegramBareClient._config = self(GetConfigRequest()) + TelegramBareClient._config = await self(GetConfigRequest()) try: if cdn: # Ensure we have the latest keys for the CDNs - for pk in self(GetCdnConfigRequest()).public_keys: + for pk in await (self(GetCdnConfigRequest())).public_keys: rsa.add_key(pk.public_key) return next( @@ -336,10 +313,10 @@ class TelegramBareClient: raise # New configuration, perhaps a new CDN was added? - TelegramBareClient._config = self(GetConfigRequest()) - return self._get_dc(dc_id, cdn=cdn) + TelegramBareClient._config = await self(GetConfigRequest()) + return await self._get_dc(dc_id, cdn=cdn) - def _get_exported_client(self, dc_id): + async def _get_exported_client(self, dc_id): """Creates and connects a new TelegramBareClient for the desired DC. If it's the first time calling the method with a given dc_id, @@ -356,11 +333,11 @@ class TelegramBareClient: # TODO Add a lock, don't allow two threads to create an auth key # (when calling .connect() if there wasn't a previous session). # for the same data center. - dc = self._get_dc(dc_id) + dc = await self._get_dc(dc_id) # Export the current authorization to the new DC. __log__.info('Exporting authorization for data center %s', dc) - export_auth = self(ExportAuthorizationRequest(dc_id)) + export_auth = await self(ExportAuthorizationRequest(dc_id)) # Create a temporary session for this IP address, which needs # to be different because each auth_key is unique per DC. @@ -375,11 +352,12 @@ class TelegramBareClient: client = TelegramBareClient( session, self.api_id, self.api_hash, proxy=self._sender.connection.conn.proxy, - timeout=self._sender.connection.get_timeout() + timeout=self._sender.connection.get_timeout(), + loop=self._loop ) - client.connect(_sync_updates=False) + await client.connect(_sync_updates=False) if isinstance(export_auth, ExportedAuthorization): - client(ImportAuthorizationRequest( + await client(ImportAuthorizationRequest( id=export_auth.id, bytes=export_auth.bytes )) elif export_auth is not None: @@ -388,11 +366,11 @@ class TelegramBareClient: client._authorized = True # We exported the auth, so we got auth return client - def _get_cdn_client(self, cdn_redirect): + async def _get_cdn_client(self, cdn_redirect): """Similar to ._get_exported_client, but for CDNs""" session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: - dc = self._get_dc(cdn_redirect.dc_id, cdn=True) + dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) session = self.session.clone() session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session @@ -401,7 +379,8 @@ class TelegramBareClient: client = TelegramBareClient( session, self.api_id, self.api_hash, proxy=self._sender.connection.conn.proxy, - timeout=self._sender.connection.get_timeout() + timeout=self._sender.connection.get_timeout(), + loop=self._loop ) # This will make use of the new RSA keys for this specific CDN. @@ -409,7 +388,7 @@ class TelegramBareClient: # We won't be calling GetConfigRequest because it's only called # when needed by ._get_dc, and also it's static so it's likely # set already. Avoid invoking non-CDN methods by not syncing updates. - client.connect(_sync_updates=False) + await client.connect(_sync_updates=False) client._authorized = self._authorized return client @@ -417,7 +396,7 @@ class TelegramBareClient: # region Invoking Telegram requests - def __call__(self, *requests, retries=5): + async def __call__(self, *requests, retries=5): """Invokes (sends) a MTProtoRequest and returns (receives) its result. The invoke will be retried up to 'retries' times before raising @@ -427,11 +406,8 @@ class TelegramBareClient: x.content_related for x in requests): raise TypeError('You can only invoke requests, not types!') - if self._background_error: - raise self._background_error - for request in requests: - request.resolve(self, utils) + await request.resolve(self, utils) # For logging purposes if len(requests) == 1: @@ -440,26 +416,23 @@ class TelegramBareClient: which = '{} requests ({})'.format( len(requests), [type(x).__name__ for x in requests]) - # Determine the sender to be used (main or a new connection) __log__.debug('Invoking %s', which) call_receive = \ not self._idling.is_set() or self._reconnect_lock.locked() for retry in range(retries): - result = self._invoke(call_receive, *requests) + result = await self._invoke(call_receive, retry, *requests) if result is not None: return result __log__.warning('Invoking %s failed %d times, ' 'reconnecting and retrying', [str(x) for x in requests], retry + 1) - sleep(1) - # The ReadThread has priority when attempting reconnection, - # since this thread is constantly running while __call__ is - # only done sometimes. Here try connecting only once/retry. + + await asyncio.sleep(retry + 1, loop=self._loop) if not self._reconnect_lock.locked(): - with self._reconnect_lock: - self._reconnect() + with await self._reconnect_lock: + await self._reconnect() raise RuntimeError('Number of retries reached 0 for {}.'.format( [type(x).__name__ for x in requests] @@ -468,18 +441,17 @@ class TelegramBareClient: # Let people use client.invoke(SomeRequest()) instead client(...) invoke = __call__ - def _invoke(self, call_receive, *requests): + async def _invoke(self, call_receive, retry, *requests): try: # Ensure that we start with no previous errors (i.e. resending) for x in requests: - x.confirm_received.clear() x.rpc_error = None if not self.session.auth_key: __log__.info('Need to generate new auth key before invoking') self._first_request = True self.session.auth_key, self.session.time_offset = \ - authenticator.do_authentication(self._sender.connection) + await authenticator.do_authentication(self._sender.connection) if self._first_request: __log__.info('Initializing a new connection while invoking') @@ -489,24 +461,21 @@ class TelegramBareClient: # We need a SINGLE request (like GetConfig) to init conn. # Once that's done, the N original requests will be # invoked. - TelegramBareClient._config = self( + TelegramBareClient._config = await self( self._wrap_init_connection(GetConfigRequest()) ) - self._sender.send(*requests) + await self._sender.send(*requests) if not call_receive: - # TODO This will be slightly troublesome if we allow - # switching between constant read or not on the fly. - # Must also watch out for calling .read() from two places, - # in which case a Lock would be required for .receive(). - for x in requests: - x.confirm_received.wait( - self._sender.connection.get_timeout() - ) + await asyncio.wait( + list(map(lambda x: x.confirm_received.wait(), requests)), + timeout=self._sender.connection.get_timeout(), + loop=self._loop + ) else: while not all(x.confirm_received.is_set() for x in requests): - self._sender.receive(update_state=self.updates) + await self._sender.receive(update_state=self.updates) except BrokenAuthKeyError: __log__.error('Authorization key seems broken and was invalid!') @@ -552,12 +521,8 @@ class TelegramBareClient: except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: - # TODO What happens with the background thread here? - # For normal use cases, this won't happen, because this will only - # be on the very first connection (not authorized, not running), - # but may be an issue for people who actually travel? - self._reconnect(new_dc=e.new_dc) - return self._invoke(call_receive, *requests) + await self._reconnect(new_dc=e.new_dc) + return await self._invoke(call_receive, retry, *requests) except ServerError as e: # Telegram is having some issues, just retry @@ -568,7 +533,8 @@ class TelegramBareClient: if e.seconds > self.session.flood_sleep_threshold | 0: raise - sleep(e.seconds) + await asyncio.sleep(e.seconds, loop=self._loop) + return None # Some really basic functionality @@ -588,90 +554,69 @@ class TelegramBareClient: # region Updates handling - def sync_updates(self): + async def sync_updates(self): """Synchronizes self.updates to their initial state. Will be called automatically on connection if self.updates.enabled = True, otherwise it should be called manually after enabling updates. """ - self.updates.process(self(GetStateRequest())) - self._last_state = datetime.now() + self.updates.process(await self(GetStateRequest())) # endregion - # region Constant read + # Constant read - def _set_connected_and_authorized(self): + # This is async so that the overrided version in TelegramClient can be + # async without problems. + async def _set_connected_and_authorized(self): self._authorized = True - self.updates.setup_workers() - if self._spawn_read_thread and self._recv_thread is None: - self._recv_thread = threading.Thread( - name='ReadThread', daemon=True, - target=self._recv_thread_impl - ) - self._recv_thread.start() + if self._recv_loop is None: + self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop) + if self._ping_loop is None: + self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop) + if self._state_loop is None: + self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop) - def _signal_handler(self, signum, frame): - if self._user_connected: - self.disconnect() - else: - os._exit(1) + async def _ping_loop_impl(self): + while self._user_connected: + await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True))) + await asyncio.sleep(self._ping_delay.seconds, loop=self._loop) + self._ping_loop = None - def idle(self, stop_signals=(SIGINT, SIGTERM, SIGABRT)): - """ - Idles the program by looping forever and listening for updates - until one of the signals are received, which breaks the loop. - - :param stop_signals: - Iterable containing signals from the signal module that will - be subscribed to TelegramClient.disconnect() (effectively - stopping the idle loop), which will be called on receiving one - of those signals. - :return: - """ - if self._spawn_read_thread and not self._on_read_thread(): - raise RuntimeError('Can only idle if spawn_read_thread=False') + async def _state_loop_impl(self): + while self._user_connected: + await asyncio.sleep(self._state_delay.seconds, loop=self._loop) + await self._sender.send(GetStateRequest()) + async def _recv_loop_impl(self): + __log__.info('Starting to wait for items from the network') self._idling.set() - for sig in stop_signals: - signal(sig, self._signal_handler) - - if self._on_read_thread(): - __log__.info('Starting to wait for items from the network') - else: - __log__.info('Idling to receive items from the network') - + need_reconnect = False while self._user_connected: try: - if datetime.now() > self._last_ping + self._ping_delay: - self._sender.send(PingRequest( - int.from_bytes(os.urandom(8), 'big', signed=True) - )) - self._last_ping = datetime.now() + if need_reconnect: + __log__.info('Attempting reconnection from read loop') + need_reconnect = False + with await self._reconnect_lock: + while self._user_connected and not await self._reconnect(): + # Retry forever, this is instant messaging + await asyncio.sleep(0.1, loop=self._loop) - if datetime.now() > self._last_state + self._state_delay: - self._sender.send(GetStateRequest()) - self._last_state = datetime.now() - - __log__.debug('Receiving items from the network...') - self._sender.receive(update_state=self.updates) - except TimeoutError: - # No problem - __log__.debug('Receiving items from the network timed out') - except ConnectionResetError: - if self._user_connected: - __log__.error('Connection was reset while receiving ' - 'items. Reconnecting') - with self._reconnect_lock: - while self._user_connected and not self._reconnect(): - sleep(0.1) # Retry forever, this is instant messaging - - if self.is_connected(): # Telegram seems to kick us every 1024 items received # from the network not considering things like bad salt. # We must execute some *high level* request (that's not # a ping) if we want to receive updates again. # TODO Test if getDifference works too (better alternative) - self._sender.send(GetStateRequest()) + await self._sender.send(GetStateRequest()) + + __log__.debug('Receiving items from the network...') + await self._sender.receive(update_state=self.updates) + except TimeoutError: + # No problem. + __log__.debug('Receiving items from the network timed out') + except ConnectionError: + need_reconnect = True + __log__.error('Connection was reset while receiving items') + await asyncio.sleep(1, loop=self._loop) except: self._idling.clear() raise @@ -679,39 +624,4 @@ class TelegramBareClient: self._idling.clear() __log__.info('Connection closed by the user, not reading anymore') - # By using this approach, another thread will be - # created and started upon connection to constantly read - # from the other end. Otherwise, manual calls to .receive() - # must be performed. The MtProtoSender cannot be connected, - # or an error will be thrown. - # - # This way, sending and receiving will be completely independent. - def _recv_thread_impl(self): - # This thread is "idle" (only listening for updates), but also - # excepts everything unlike the manual idle because it should - # not crash. - while self._user_connected: - try: - self.idle(stop_signals=tuple()) - except Exception as error: - __log__.exception('Unknown exception in the read thread! ' - 'Disconnecting and leaving it to main thread') - # Unknown exception, pass it to the main thread - - try: - import socks - if isinstance(error, ( - socks.GeneralProxyError, socks.ProxyConnectionError - )): - # This is a known error, and it's not related to - # Telegram but rather to the proxy. Disconnect and - # hand it over to the main thread. - self._background_error = error - self.disconnect() - break - except ImportError: - "Not using PySocks, so it can't be a proxy error" - - self._recv_thread = None - # endregion diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index fbb57d63..c1f91dcd 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -1,3 +1,4 @@ +import asyncio import getpass import hashlib import io @@ -6,7 +7,6 @@ import logging import os import re import sys -import time import warnings from collections import OrderedDict, UserList from datetime import datetime, timedelta @@ -158,18 +158,16 @@ class TelegramClient(TelegramBareClient): connection_mode=ConnectionMode.TCP_FULL, use_ipv6=False, proxy=None, - update_workers=None, timeout=timedelta(seconds=5), - spawn_read_thread=True, + loop=None, **kwargs): super().__init__( session, api_id, api_hash, connection_mode=connection_mode, use_ipv6=use_ipv6, proxy=proxy, - update_workers=update_workers, - spawn_read_thread=spawn_read_thread, timeout=timeout, + loop=loop, **kwargs ) @@ -190,7 +188,7 @@ class TelegramClient(TelegramBareClient): # region Authorization requests - def send_code_request(self, phone, force_sms=False): + async def send_code_request(self, phone, force_sms=False): """ Sends a code request to the specified phone number. @@ -208,7 +206,7 @@ class TelegramClient(TelegramBareClient): phone_hash = self._phone_code_hash.get(phone) if not phone_hash: - result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) + result = await self(SendCodeRequest(phone, self.api_id, self.api_hash)) self._phone_code_hash[phone] = phone_hash = result.phone_code_hash else: force_sms = True @@ -216,22 +214,23 @@ class TelegramClient(TelegramBareClient): self._phone = phone if force_sms: - result = self(ResendCodeRequest(phone, phone_hash)) + result = await self(ResendCodeRequest(phone, phone_hash)) self._phone_code_hash[phone] = result.phone_code_hash return result - def start(self, - phone=lambda: input('Please enter your phone: '), - password=lambda: getpass.getpass('Please enter your password: '), - bot_token=None, force_sms=False, code_callback=None, - first_name='New User', last_name=''): + async def start(self, + phone=lambda: input('Please enter your phone: '), + password=lambda: getpass.getpass( + 'Please enter your password: '), + bot_token=None, force_sms=False, code_callback=None, + first_name='New User', last_name=''): """ Convenience method to interactively connect and sign in if required, also taking into consideration that 2FA may be enabled in the account. Example usage: - >>> client = TelegramClient(session, api_id, api_hash).start(phone) + >>> client = await TelegramClient(session, api_id, api_hash).start(phone) Please enter the code you received: 12345 Please enter your password: ******* (You are now logged in) @@ -286,14 +285,14 @@ class TelegramClient(TelegramBareClient): 'must only provide one of either') if not self.is_connected(): - self.connect() + await self.connect() if self.is_user_authorized(): self._check_events_pending_resolve() return self if bot_token: - self.sign_in(bot_token=bot_token) + await self.sign_in(bot_token=bot_token) return self # Turn the callable into a valid phone number @@ -305,15 +304,15 @@ class TelegramClient(TelegramBareClient): max_attempts = 3 two_step_detected = False - sent_code = self.send_code_request(phone, force_sms=force_sms) + sent_code = await self.send_code_request(phone, force_sms=force_sms) sign_up = not sent_code.phone_registered while attempts < max_attempts: try: if sign_up: - me = self.sign_up(code_callback(), first_name, last_name) + me = await self.sign_up(code_callback(), first_name, last_name) else: # Raises SessionPasswordNeededError if 2FA enabled - me = self.sign_in(phone, code_callback()) + me = await self.sign_in(phone, code_callback()) break except SessionPasswordNeededError: two_step_detected = True @@ -342,15 +341,15 @@ class TelegramClient(TelegramBareClient): # TODO If callable given make it retry on invalid if callable(password): password = password() - me = self.sign_in(phone=phone, password=password) + me = await self.sign_in(phone=phone, password=password) # We won't reach here if any step failed (exit by exception) print('Signed in successfully as', utils.get_display_name(me)) self._check_events_pending_resolve() return self - def sign_in(self, phone=None, code=None, - password=None, bot_token=None, phone_code_hash=None): + async def sign_in(self, phone=None, code=None, + password=None, bot_token=None, phone_code_hash=None): """ Starts or completes the sign in process with the given phone number or code that Telegram sent. @@ -385,7 +384,7 @@ class TelegramClient(TelegramBareClient): return self.get_me() if phone and not code and not password: - return self.send_code_request(phone) + return await self.send_code_request(phone) elif code: phone = utils.parse_phone(phone) or self._phone phone_code_hash = \ @@ -400,14 +399,14 @@ class TelegramClient(TelegramBareClient): # May raise PhoneCodeEmptyError, PhoneCodeExpiredError, # PhoneCodeHashEmptyError or PhoneCodeInvalidError. - result = self(SignInRequest(phone, phone_code_hash, str(code))) + result = await self(SignInRequest(phone, phone_code_hash, str(code))) elif password: - salt = self(GetPasswordRequest()).current_salt - result = self(CheckPasswordRequest( + salt = (await self(GetPasswordRequest())).current_salt + result = await self(CheckPasswordRequest( helpers.get_password_hash(password, salt) )) elif bot_token: - result = self(ImportBotAuthorizationRequest( + result = await self(ImportBotAuthorizationRequest( flags=0, bot_auth_token=bot_token, api_id=self.api_id, api_hash=self.api_hash )) @@ -420,10 +419,10 @@ class TelegramClient(TelegramBareClient): self._self_input_peer = utils.get_input_peer( result.user, allow_self=False ) - self._set_connected_and_authorized() + await self._set_connected_and_authorized() return result.user - def sign_up(self, code, first_name, last_name=''): + async def sign_up(self, code, first_name, last_name=''): """ Signs up to Telegram if you don't have an account yet. You must call .send_code_request(phone) first. @@ -442,10 +441,10 @@ class TelegramClient(TelegramBareClient): The new created user. """ if self.is_user_authorized(): - self._check_events_pending_resolve() - return self.get_me() + await self._check_events_pending_resolve() + return await self.get_me() - result = self(SignUpRequest( + result = await self(SignUpRequest( phone_number=self._phone, phone_code_hash=self._phone_code_hash.get(self._phone, ''), phone_code=str(code), @@ -456,10 +455,10 @@ class TelegramClient(TelegramBareClient): self._self_input_peer = utils.get_input_peer( result.user, allow_self=False ) - self._set_connected_and_authorized() + await self._set_connected_and_authorized() return result.user - def log_out(self): + async def log_out(self): """ Logs out Telegram and deletes the current ``*.session`` file. @@ -467,7 +466,7 @@ class TelegramClient(TelegramBareClient): True if the operation was successful. """ try: - self(LogOutRequest()) + await self(LogOutRequest()) except RPCError: return False @@ -475,7 +474,7 @@ class TelegramClient(TelegramBareClient): self.session.delete() return True - def get_me(self, input_peer=False): + async def get_me(self, input_peer=False): """ Gets "me" (the self user) which is currently authenticated, or None if the request fails (hence, not authenticated). @@ -491,9 +490,8 @@ class TelegramClient(TelegramBareClient): """ if input_peer and self._self_input_peer: return self._self_input_peer - try: - me = self(GetUsersRequest([InputUserSelf()]))[0] + me = (await self(GetUsersRequest([InputUserSelf()])))[0] if not self._self_input_peer: self._self_input_peer = utils.get_input_peer( me, allow_self=False @@ -507,8 +505,8 @@ class TelegramClient(TelegramBareClient): # region Dialogs ("chats") requests - def get_dialogs(self, limit=10, offset_date=None, offset_id=0, - offset_peer=InputPeerEmpty()): + async def get_dialogs(self, limit=10, offset_date=None, offset_id=0, + offset_peer=InputPeerEmpty()): """ Gets N "dialogs" (open "chats" or conversations with other people). @@ -535,7 +533,7 @@ class TelegramClient(TelegramBareClient): limit = float('inf') if limit is None else int(limit) if limit == 0: # Special case, get a single dialog and determine count - dialogs = self(GetDialogsRequest( + dialogs = await self(GetDialogsRequest( offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, @@ -549,7 +547,7 @@ class TelegramClient(TelegramBareClient): dialogs = OrderedDict() # Use peer id as identifier to avoid dupes while len(dialogs) < limit: real_limit = min(limit - len(dialogs), 100) - r = self(GetDialogsRequest( + r = await self(GetDialogsRequest( offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, @@ -580,7 +578,7 @@ class TelegramClient(TelegramBareClient): dialogs.total = total_count return dialogs - def get_drafts(self): # TODO: Ability to provide a `filter` + async def get_drafts(self): # TODO: Ability to provide a `filter` """ Gets all open draft messages. @@ -589,7 +587,7 @@ class TelegramClient(TelegramBareClient): You can call ``draft.set_message('text')`` to change the message, or delete it through :meth:`draft.delete()`. """ - response = self(GetAllDraftsRequest()) + response = await self(GetAllDraftsRequest()) self.session.process_entities(response) self.session.generate_sequence(response.seq) drafts = [Draft._from_update(self, u) for u in response.updates] @@ -636,7 +634,7 @@ class TelegramClient(TelegramBareClient): if request.id == update.message.id: return update.message - def _parse_message_text(self, message, parse_mode): + async def _parse_message_text(self, message, parse_mode): """ Returns a (parsed message, entities) tuple depending on parse_mode. """ @@ -657,7 +655,7 @@ class TelegramClient(TelegramBareClient): if m: try: msg_entities[i] = InputMessageEntityMentionName( - e.offset, e.length, self.get_input_entity( + e.offset, e.length, await self.get_input_entity( int(m.group(1)) if m.group(1) else e.url ) ) @@ -667,8 +665,8 @@ class TelegramClient(TelegramBareClient): return message, msg_entities - def send_message(self, entity, message, reply_to=None, parse_mode='md', - link_preview=True): + async def send_message(self, entity, message, reply_to=None, + parse_mode='md', link_preview=True): """ Sends the given message to the specified entity (user/chat/channel). @@ -695,11 +693,12 @@ class TelegramClient(TelegramBareClient): Returns: the sent message """ - entity = self.get_input_entity(entity) + + entity = await self.get_input_entity(entity) if isinstance(message, Message): if (message.media and not isinstance(message.media, MessageMediaWebPage)): - return self.send_file(entity, message.media) + return await self.send_file(entity, message.media) if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id): reply_id = message.reply_to_msg_id @@ -716,7 +715,7 @@ class TelegramClient(TelegramBareClient): ) message = message.message else: - message, msg_ent = self._parse_message_text(message, parse_mode) + message, msg_ent = await self._parse_message_text(message, parse_mode) request = SendMessageRequest( peer=entity, message=message, @@ -725,7 +724,8 @@ class TelegramClient(TelegramBareClient): reply_to_msg_id=self._get_message_id(reply_to) ) - result = self(request) + result = await self(request) + if isinstance(result, UpdateShortSentMessage): return Message( id=result.id, @@ -739,8 +739,8 @@ class TelegramClient(TelegramBareClient): return self._get_response_message(request, result) - def edit_message(self, entity, message_id, message=None, parse_mode='md', - link_preview=True): + async def edit_message(self, entity, message_id, message=None, + parse_mode='md', link_preview=True): """ Edits the given message ID (to change its contents or disable preview). @@ -773,18 +773,18 @@ class TelegramClient(TelegramBareClient): Returns: the edited message """ - message, msg_entities = self._parse_message_text(message, parse_mode) + message, msg_entities = await self._parse_message_text(message, parse_mode) request = EditMessageRequest( - peer=self.get_input_entity(entity), + peer=await self.get_input_entity(entity), id=self._get_message_id(message_id), message=message, no_webpage=not link_preview, entities=msg_entities ) - result = self(request) + result = await self(request) return self._get_response_message(request, result) - def delete_messages(self, entity, message_ids, revoke=True): + async def delete_messages(self, entity, message_ids, revoke=True): """ Deletes a message from a chat, optionally "for everyone". @@ -812,18 +812,18 @@ class TelegramClient(TelegramBareClient): message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids] if entity is None: - return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) + return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) if isinstance(entity, InputPeerChannel): - return self(channels.DeleteMessagesRequest(entity, message_ids)) + return await self(channels.DeleteMessagesRequest(entity, message_ids)) else: - return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) + return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) - def get_message_history(self, entity, limit=20, offset_date=None, - offset_id=0, max_id=0, min_id=0, add_offset=0, - batch_size=100, wait_time=None): + async def get_message_history(self, entity, limit=20, offset_date=None, + offset_id=0, max_id=0, min_id=0, add_offset=0, + batch_size=100, wait_time=None): """ Gets the message history for the specified entity @@ -884,13 +884,12 @@ class TelegramClient(TelegramBareClient): second is the default for this limit (or above). You may need an higher limit, so you're free to set the ``batch_size`` that you think may be good. - """ - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) limit = float('inf') if limit is None else int(limit) if limit == 0: # No messages, but we still need to know the total message count - result = self(GetHistoryRequest( + result = await self(GetHistoryRequest( peer=entity, limit=1, offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0 )) @@ -906,7 +905,7 @@ class TelegramClient(TelegramBareClient): while len(messages) < limit: # Telegram has a hard limit of 100 real_limit = min(limit - len(messages), batch_size) - result = self(GetHistoryRequest( + result = await self(GetHistoryRequest( peer=entity, limit=real_limit, offset_date=offset_date, @@ -931,7 +930,7 @@ class TelegramClient(TelegramBareClient): offset_id = result.messages[-1].id offset_date = result.messages[-1].date - time.sleep(wait_time) + await asyncio.sleep(wait_time) # Add a few extra attributes to the Message to make it friendlier. messages.total = total_messages @@ -959,8 +958,8 @@ class TelegramClient(TelegramBareClient): return messages - def send_read_acknowledge(self, entity, message=None, max_id=None, - clear_mentions=False): + async def send_read_acknowledge(self, entity, message=None, max_id=None, + clear_mentions=False): """ Sends a "read acknowledge" (i.e., notifying the given peer that we've read their messages, also known as the "double check"). @@ -993,17 +992,17 @@ class TelegramClient(TelegramBareClient): raise ValueError( 'Either a message list or a max_id must be provided.') - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) if clear_mentions: - self(ReadMentionsRequest(entity)) + await self(ReadMentionsRequest(entity)) if max_id is None: return True if max_id is not None: if isinstance(entity, InputPeerChannel): - return self(channels.ReadHistoryRequest(entity, max_id=max_id)) + return await self(channels.ReadHistoryRequest(entity, max_id=max_id)) else: - return self(messages.ReadHistoryRequest(entity, max_id=max_id)) + return await self(messages.ReadHistoryRequest(entity, max_id=max_id)) return False @@ -1025,8 +1024,8 @@ class TelegramClient(TelegramBareClient): raise TypeError('Invalid message type: {}'.format(type(message))) - def get_participants(self, entity, limit=None, search='', - aggressive=False): + async def get_participants(self, entity, limit=None, search='', + aggressive=False): """ Gets the list of participants from the specified entity. @@ -1054,12 +1053,12 @@ class TelegramClient(TelegramBareClient): A list of participants with an additional .total variable on the list indicating the total amount of members in this group/channel. """ - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) limit = float('inf') if limit is None else int(limit) if isinstance(entity, InputPeerChannel): - total = self(GetFullChannelRequest( + total = (await self(GetFullChannelRequest( entity - )).full_chat.participants_count + ))).full_chat.participants_count all_participants = {} if total > 10000 and aggressive: @@ -1091,9 +1090,9 @@ class TelegramClient(TelegramBareClient): break if len(requests) == 1: - results = (self(requests[0]),) + results = (await self(requests[0]),) else: - results = self(*requests) + results = await self(*requests) for i in reversed(range(len(requests))): participants = results[i] if not participants.users: @@ -1111,7 +1110,7 @@ class TelegramClient(TelegramBareClient): users = UserList(values) users.total = total elif isinstance(entity, InputPeerChat): - users = self(GetFullChatRequest(entity.chat_id)).users + users = (await self(GetFullChatRequest(entity.chat_id))).users if len(users) > limit: users = users[:limit] users = UserList(users) @@ -1125,14 +1124,14 @@ class TelegramClient(TelegramBareClient): # region Uploading files - def send_file(self, entity, file, caption=None, - force_document=False, progress_callback=None, - reply_to=None, - attributes=None, - thumb=None, - allow_cache=True, - parse_mode='md', - **kwargs): + async def send_file(self, entity, file, caption=None, + force_document=False, progress_callback=None, + reply_to=None, + attributes=None, + thumb=None, + allow_cache=True, + parse_mode='md', + **kwargs): """ Sends a file to the specified entity. @@ -1201,14 +1200,14 @@ class TelegramClient(TelegramBareClient): # Convert to tuple so we can iterate several times file = tuple(x for x in file) if all(utils.is_image(x) for x in file): - return self._send_album( + return await self._send_album( entity, file, caption=caption, progress_callback=progress_callback, reply_to=reply_to, parse_mode=parse_mode ) # Not all are images, so send all the files one by one return [ - self.send_file( + await self.send_file( entity, x, allow_cache=False, caption=caption, force_document=force_document, progress_callback=progress_callback, reply_to=reply_to, @@ -1216,7 +1215,7 @@ class TelegramClient(TelegramBareClient): ) for x in file ] - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) reply_to = self._get_message_id(reply_to) caption, msg_entities = self._parse_message_text(caption, parse_mode) @@ -1233,11 +1232,11 @@ class TelegramClient(TelegramBareClient): reply_to_msg_id=reply_to, message=caption, entities=msg_entities) - return self._get_response_message(request, self(request)) + return self._get_response_message(request, await self(request)) as_image = utils.is_image(file) and not force_document use_cache = InputPhoto if as_image else InputDocument - file_handle = self.upload_file( + file_handle = await self.upload_file( file, progress_callback=progress_callback, use_cache=use_cache if allow_cache else None ) @@ -1314,7 +1313,7 @@ class TelegramClient(TelegramBareClient): input_kw = {} if thumb: - input_kw['thumb'] = self.upload_file(thumb) + input_kw['thumb'] = await self.upload_file(thumb) media = InputMediaUploadedDocument( file=file_handle, @@ -1327,7 +1326,7 @@ class TelegramClient(TelegramBareClient): # send the media message to the desired entity. request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to, message=caption, entities=msg_entities) - msg = self._get_response_message(request, self(request)) + msg = self._get_response_message(request, await self(request)) if msg and isinstance(file_handle, InputSizedFile): # There was a response message and we didn't use cached # version, so cache whatever we just sent to the database. @@ -1345,7 +1344,7 @@ class TelegramClient(TelegramBareClient): kwargs['is_voice_note'] = True return self.send_file(*args, **kwargs) - def _send_album(self, entity, files, caption=None, + async def _send_album(self, entity, files, caption=None, progress_callback=None, reply_to=None, parse_mode='md'): """Specialized version of .send_file for albums""" @@ -1354,7 +1353,7 @@ class TelegramClient(TelegramBareClient): # we need to produce right now to send albums (uploadMedia), and # cache only makes a difference for documents where the user may # want the attributes used on them to change. - entity = self.get_input_entity(entity) + entity = await self.get_input_entity(entity) if not utils.is_list_like(caption): caption = (caption,) captions = [ @@ -1367,11 +1366,11 @@ class TelegramClient(TelegramBareClient): media = [] for file in files: # fh will either be InputPhoto or a modified InputFile - fh = self.upload_file(file, use_cache=InputPhoto) + fh = await self.upload_file(file, use_cache=InputPhoto) if not isinstance(fh, InputPhoto): - input_photo = utils.get_input_photo(self(UploadMediaRequest( + input_photo = utils.get_input_photo((await self(UploadMediaRequest( entity, media=InputMediaUploadedPhoto(fh) - )).photo) + ))).photo) self.session.cache_file(fh.md5, fh.size, input_photo) fh = input_photo @@ -1383,7 +1382,7 @@ class TelegramClient(TelegramBareClient): entities=msg_entities)) # Now we can construct the multi-media request - result = self(SendMultiMediaRequest( + result = await self(SendMultiMediaRequest( entity, reply_to_msg_id=reply_to, multi_media=media )) return [ @@ -1392,12 +1391,8 @@ class TelegramClient(TelegramBareClient): if isinstance(update, UpdateMessageID) ] - def upload_file(self, - file, - part_size_kb=None, - file_name=None, - use_cache=None, - progress_callback=None): + async def upload_file(self, file, part_size_kb=None, file_name=None, + use_cache=None, progress_callback=None): """ Uploads the specified file and returns a handle (an instance of InputFile or InputFileBig, as required) which can be later used @@ -1510,7 +1505,7 @@ class TelegramClient(TelegramBareClient): else: request = SaveFilePartRequest(file_id, part_index, part) - result = self(request) + result = await self(request) if result: __log__.debug('Uploaded %d/%d', part_index + 1, part_count) @@ -1531,7 +1526,7 @@ class TelegramClient(TelegramBareClient): # region Downloading media requests - def download_profile_photo(self, entity, file=None, download_big=True): + async def download_profile_photo(self, entity, file=None, download_big=True): """ Downloads the profile photo of the given entity (user/chat/channel). @@ -1565,12 +1560,12 @@ class TelegramClient(TelegramBareClient): # The hexadecimal numbers above are simply: # hex(crc32(x.encode('ascii'))) for x in # ('User', 'Chat', 'UserFull', 'ChatFull') - entity = self.get_entity(entity) + entity = await self.get_entity(entity) if not hasattr(entity, 'photo'): # Special case: may be a ChatFull with photo:Photo # This is different from a normal UserProfilePhoto and Chat if hasattr(entity, 'chat_photo'): - return self._download_photo( + return await self._download_photo( entity.chat_photo, file, date=None, progress_callback=None ) @@ -1595,7 +1590,7 @@ class TelegramClient(TelegramBareClient): # Download the media with the largest size input file location try: - self.download_file( + await self.download_file( InputFileLocation( volume_id=photo_location.volume_id, local_id=photo_location.local_id, @@ -1606,10 +1601,10 @@ class TelegramClient(TelegramBareClient): except LocationInvalidError: # See issue #500, Android app fails as of v4.6.0 (1155). # The fix seems to be using the full channel chat photo. - ie = self.get_input_entity(entity) + ie = await self.get_input_entity(entity) if isinstance(ie, InputPeerChannel): - full = self(GetFullChannelRequest(ie)) - return self._download_photo( + full = await self(GetFullChannelRequest(ie)) + return await self._download_photo( full.full_chat.chat_photo, file, date=None, progress_callback=None ) @@ -1618,7 +1613,7 @@ class TelegramClient(TelegramBareClient): return None return file - def download_media(self, message, file=None, progress_callback=None): + async def download_media(self, message, file=None, progress_callback=None): """ Downloads the given media, or the media from a specified Message. @@ -1646,19 +1641,19 @@ class TelegramClient(TelegramBareClient): media = message if isinstance(media, (MessageMediaPhoto, Photo)): - return self._download_photo( + return await self._download_photo( media, file, date, progress_callback ) elif isinstance(media, (MessageMediaDocument, Document)): - return self._download_document( + return await self._download_document( media, file, date, progress_callback ) elif isinstance(media, MessageMediaContact): - return self._download_contact( + return await self._download_contact( media, file ) - def _download_photo(self, photo, file, date, progress_callback): + async def _download_photo(self, photo, file, date, progress_callback): """Specialized version of .download_media() for photos""" # Determine the photo and its largest size if isinstance(photo, MessageMediaPhoto): @@ -1673,7 +1668,7 @@ class TelegramClient(TelegramBareClient): file = self._get_proper_filename(file, 'photo', '.jpg', date=date) # Download the media with the largest size input file location - self.download_file( + await self.download_file( InputFileLocation( volume_id=largest_size.volume_id, local_id=largest_size.local_id, @@ -1685,7 +1680,7 @@ class TelegramClient(TelegramBareClient): ) return file - def _download_document(self, document, file, date, progress_callback): + async def _download_document(self, document, file, date, progress_callback): """Specialized version of .download_media() for documents""" if isinstance(document, MessageMediaDocument): document = document.document @@ -1718,7 +1713,7 @@ class TelegramClient(TelegramBareClient): date=date, possible_names=possible_names ) - self.download_file( + await self.download_file( InputDocumentFileLocation( id=document.id, access_hash=document.access_hash, @@ -1825,12 +1820,8 @@ class TelegramClient(TelegramBareClient): return result i += 1 - def download_file(self, - input_location, - file, - part_size_kb=None, - file_size=None, - progress_callback=None): + async def download_file(self, input_location, file, part_size_kb=None, + file_size=None, progress_callback=None): """ Downloads the given input location to a file. @@ -1889,23 +1880,24 @@ class TelegramClient(TelegramBareClient): while True: try: if cdn_decrypter: - result = cdn_decrypter.get_file() + result = await cdn_decrypter.get_file() else: - result = client(GetFileRequest( + result = await client(GetFileRequest( input_location, offset, part_size )) if isinstance(result, FileCdnRedirect): __log__.info('File lives in a CDN') cdn_decrypter, result = \ - CdnDecrypter.prepare_decrypter( - client, self._get_cdn_client(result), + await CdnDecrypter.prepare_decrypter( + client, + await self._get_cdn_client(result), result ) except FileMigrateError as e: __log__.info('File lives in another DC') - client = self._get_exported_client(e.new_dc) + client = await self._get_exported_client(e.new_dc) continue offset += part_size @@ -1947,25 +1939,25 @@ class TelegramClient(TelegramBareClient): The event builder class or instance to be used, for instance ``events.NewMessage``. """ - def decorator(f): - self.add_event_handler(f, event) + async def decorator(f): + await self.add_event_handler(f, event) return f return decorator - def _check_events_pending_resolve(self): + async def _check_events_pending_resolve(self): if self._events_pending_resolve: for event in self._events_pending_resolve: - event.resolve(self) + await event.resolve(self) self._events_pending_resolve.clear() - def _on_handler(self, update): + async def _on_handler(self, update): for builder, callback in self._event_builders: event = builder.build(update) if event: event._client = self try: - callback(event) + await callback(event) except events.StopPropagation: __log__.debug( "Event handler '{}' stopped chain of " @@ -1974,7 +1966,7 @@ class TelegramClient(TelegramBareClient): ) break - def add_event_handler(self, callback, event=None): + async def add_event_handler(self, callback, event=None): """ Registers the given callback to be called on the specified event. @@ -1989,12 +1981,6 @@ class TelegramClient(TelegramBareClient): If left unspecified, ``events.Raw`` (the ``Update`` objects with no further processing) will be passed instead. """ - if self.updates.workers is None: - warnings.warn( - "You have not setup any workers, so you won't receive updates." - " Pass update_workers=1 when creating the TelegramClient," - " or set client.self.updates.workers = 1" - ) self.updates.handler = self._on_handler if isinstance(event, type): @@ -2003,8 +1989,8 @@ class TelegramClient(TelegramBareClient): event = events.Raw() if self.is_user_authorized(): - event.resolve(self) - self._check_events_pending_resolve() + await event.resolve(self) + await self._check_events_pending_resolve() else: self._events_pending_resolve.append(event) @@ -2031,11 +2017,11 @@ class TelegramClient(TelegramBareClient): # region Small utilities to make users' life easier - def _set_connected_and_authorized(self): - super()._set_connected_and_authorized() - self._check_events_pending_resolve() + async def _set_connected_and_authorized(self): + await super()._set_connected_and_authorized() + await self._check_events_pending_resolve() - def get_entity(self, entity): + async def get_entity(self, entity): """ Turns the given entity into a valid Telegram user or chat. @@ -2069,7 +2055,7 @@ class TelegramClient(TelegramBareClient): # input channels (get channels) to get the most entities # in the less amount of calls possible. inputs = [ - x if isinstance(x, str) else self.get_input_entity(x) + x if isinstance(x, str) else await self.get_input_entity(x) for x in entity ] users = [x for x in inputs if isinstance(x, InputPeerUser)] @@ -2080,12 +2066,12 @@ class TelegramClient(TelegramBareClient): tmp = [] while users: curr, users = users[:200], users[200:] - tmp.extend(self(GetUsersRequest(curr))) + tmp.extend(await self(GetUsersRequest(curr))) users = tmp if chats: # TODO Handle chats slice? - chats = self(GetChatsRequest(chats)).chats + chats = (await self(GetChatsRequest(chats))).chats if channels: - channels = self(GetChannelsRequest(channels)).chats + channels = (await self(GetChannelsRequest(channels))).chats # Merge users, chats and channels into a single dictionary id_entity = { @@ -2098,33 +2084,31 @@ class TelegramClient(TelegramBareClient): # the amount of ResolveUsername calls, it would fail to catch # username changes. result = [ - self._get_entity_from_string(x) if isinstance(x, str) + await self._get_entity_from_string(x) if isinstance(x, str) else id_entity[utils.get_peer_id(x)] for x in inputs ] return result[0] if single else result - def _get_entity_from_string(self, string): + async def _get_entity_from_string(self, string): """ Gets a full entity from the given string, which may be a phone or an username, and processes all the found entities on the session. The string may also be a user link, or a channel/chat invite link. - This method has the side effect of adding the found users to the session database, so it can be queried later without API calls, if this option is enabled on the session. - Returns the found entity, or raises TypeError if not found. """ phone = utils.parse_phone(string) if phone: - for user in self(GetContactsRequest(0)).users: + for user in (await self(GetContactsRequest(0))).users: if user.phone == phone: return user else: username, is_join_chat = utils.parse_username(string) if is_join_chat: - invite = self(CheckChatInviteRequest(username)) + invite = await self(CheckChatInviteRequest(username)) if isinstance(invite, ChatInvite): raise ValueError( 'Cannot get entity from a channel ' @@ -2134,14 +2118,15 @@ class TelegramClient(TelegramBareClient): return invite.chat elif username: if username in ('me', 'self'): - return self.get_me() - result = self(ResolveUsernameRequest(username)) + return await self.get_me() + result = await self(ResolveUsernameRequest(username)) for entity in itertools.chain(result.users, result.chats): if entity.username.lower() == username: return entity try: # Nobody with this username, maybe it's an exact name/title - return self.get_entity(self.session.get_input_entity(string)) + return await self.get_entity( + self.session.get_input_entity(string)) except ValueError: pass @@ -2149,24 +2134,20 @@ class TelegramClient(TelegramBareClient): 'Cannot turn "{}" into any entity (user or chat)'.format(string) ) - def get_input_entity(self, peer): + async def get_input_entity(self, peer): """ Turns the given peer into its input entity version. Most requests use this kind of InputUser, InputChat and so on, so this is the most suitable call to make for those cases. - entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`): The integer ID of an user or otherwise either of a ``PeerUser``, ``PeerChat`` or ``PeerChannel``, for which to get its ``Input*`` version. - If this ``Peer`` hasn't been seen before by the library, the top dialogs will be loaded and their entities saved to the session file (unless this feature was disabled explicitly). - If in the end the access hash required for the peer was not found, a ValueError will be raised. - Returns: ``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``. """ @@ -2179,7 +2160,7 @@ class TelegramClient(TelegramBareClient): if isinstance(peer, str): if peer in ('me', 'self'): return InputPeerSelf() - return utils.get_input_peer(self._get_entity_from_string(peer)) + return utils.get_input_peer(await self._get_entity_from_string(peer)) if isinstance(peer, int): peer, kind = utils.resolve_id(peer) @@ -2206,7 +2187,7 @@ class TelegramClient(TelegramBareClient): limit=100 ) while True: - result = self(req) + result = await self(req) entities = {} for x in itertools.chain(result.users, result.chats): x_id = utils.get_peer_id(x) @@ -2222,7 +2203,7 @@ class TelegramClient(TelegramBareClient): req.offset_peer = entities[utils.get_peer_id( result.dialogs[-1].peer )] - time.sleep(1) + asyncio.sleep(1) raise TypeError( 'Could not find the input entity corresponding to "{}". ' diff --git a/telethon/tl/custom/dialog.py b/telethon/tl/custom/dialog.py index 366a19bf..61065bdf 100644 --- a/telethon/tl/custom/dialog.py +++ b/telethon/tl/custom/dialog.py @@ -26,9 +26,9 @@ class Dialog: self.draft = Draft(client, dialog.peer, dialog.draft) - def send_message(self, *args, **kwargs): + async def send_message(self, *args, **kwargs): """ Sends a message to this dialog. This is just a wrapper around client.send_message(dialog.input_entity, *args, **kwargs). """ - return self._client.send_message(self.input_entity, *args, **kwargs) + return await self._client.send_message(self.input_entity, *args, **kwargs) diff --git a/telethon/tl/custom/draft.py b/telethon/tl/custom/draft.py index 9b800d4c..e1ff7c91 100644 --- a/telethon/tl/custom/draft.py +++ b/telethon/tl/custom/draft.py @@ -31,14 +31,14 @@ class Draft: return cls(client=client, peer=update.peer, draft=update.draft) @property - def entity(self): - return self._client.get_entity(self._peer) + async def entity(self): + return await self._client.get_entity(self._peer) @property - def input_entity(self): - return self._client.get_input_entity(self._peer) + async def input_entity(self): + return await self._client.get_input_entity(self._peer) - def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None): + async def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None): """ Changes the draft message on the Telegram servers. The changes are reflected in this object. Changing only individual attributes like for @@ -58,7 +58,7 @@ class Draft: :param list entities: A list of formatting entities :return bool: ``True`` on success """ - result = self._client(SaveDraftRequest( + result = await self._client(SaveDraftRequest( peer=self._peer, message=text, no_webpage=no_webpage, @@ -74,9 +74,9 @@ class Draft: return result - def delete(self): + async def delete(self): """ Deletes this draft :return bool: ``True`` on success """ - return self.set_message(text='') + return await self.set_message(text='') diff --git a/telethon/tl/tlobject.py b/telethon/tl/tlobject.py index b048158c..1940580f 100644 --- a/telethon/tl/tlobject.py +++ b/telethon/tl/tlobject.py @@ -1,11 +1,10 @@ import struct from datetime import datetime, date -from threading import Event class TLObject: def __init__(self): - self.confirm_received = Event() + self.confirm_received = None self.rpc_error = None self.result = None @@ -157,7 +156,7 @@ class TLObject: return TLObject.pretty_format(self, indent=0) # These should be overrode - def resolve(self, client, utils): + async def resolve(self, client, utils): pass def to_dict(self): diff --git a/telethon/update_state.py b/telethon/update_state.py index 6a496603..f52b0d42 100644 --- a/telethon/update_state.py +++ b/telethon/update_state.py @@ -1,9 +1,8 @@ import logging import pickle +import asyncio from collections import deque -from queue import Queue, Empty from datetime import datetime -from threading import RLock, Thread from .tl import types as tl @@ -16,125 +15,40 @@ class UpdateState: """ WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers - def __init__(self, workers=None): - """ - :param workers: This integer parameter has three possible cases: - workers is None: Updates will *not* be stored on self. - workers = 0: Another thread is responsible for calling self.poll() - workers > 0: 'workers' background threads will be spawned, any - any of them will invoke the self.handler. - """ - self._workers = workers - self._worker_threads = [] - + def __init__(self, loop=None): self.handler = None - self._updates_lock = RLock() - self._updates = Queue() + self._loop = loop if loop else asyncio.get_event_loop() # https://core.telegram.org/api/updates self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) - def can_poll(self): - """Returns True if a call to .poll() won't lock""" - return not self._updates.empty() - - def poll(self, timeout=None): - """ - Polls an update or blocks until an update object is available. - If 'timeout is not None', it should be a floating point value, - and the method will 'return None' if waiting times out. - """ - try: - return self._updates.get(timeout=timeout) - except Empty: - return None - - def get_workers(self): - return self._workers - - def set_workers(self, n): - """Changes the number of workers running. - If 'n is None', clears all pending updates from memory. - """ - if n is None: - self.stop_workers() - else: - self._workers = n - self.setup_workers() - - workers = property(fget=get_workers, fset=set_workers) - - def stop_workers(self): - """ - Waits for all the worker threads to stop. - """ - # Put dummy ``None`` objects so that they don't need to timeout. - n = self._workers - self._workers = None - if n: - with self._updates_lock: - for _ in range(n): - self._updates.put(None) - - for t in self._worker_threads: - t.join() - - self._worker_threads.clear() - - def setup_workers(self): - if self._worker_threads or not self._workers: - # There already are workers, or workers is None or 0. Do nothing. - return - - for i in range(self._workers): - thread = Thread( - target=UpdateState._worker_loop, - name='UpdateWorker{}'.format(i), - daemon=True, - args=(self, i) - ) - self._worker_threads.append(thread) - thread.start() - - def _worker_loop(self, wid): - while self._workers is not None: - try: - update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT) - if update and self.handler: - self.handler(update) - except StopIteration: - break - except: - # We don't want to crash a worker thread due to any reason - __log__.exception('Unhandled exception on worker %d', wid) + def handle_update(self, update): + if self.handler: + asyncio.ensure_future(self.handler(update), loop=self._loop) def process(self, update): """Processes an update object. This method is normally called by the library itself. """ - if self._workers is None: - return # No processing needs to be done if nobody's working + if isinstance(update, tl.updates.State): + __log__.debug('Saved new updates state') + self._state = update + return # Nothing else to be done - with self._updates_lock: - if isinstance(update, tl.updates.State): - __log__.debug('Saved new updates state') - self._state = update - return # Nothing else to be done + if hasattr(update, 'pts'): + self._state.pts = update.pts - if hasattr(update, 'pts'): - self._state.pts = update.pts - - # After running the script for over an hour and receiving over - # 1000 updates, the only duplicates received were users going - # online or offline. We can trust the server until new reports. - if isinstance(update, tl.UpdateShort): - self._updates.put(update.update) - # Expand "Updates" into "Update", and pass these to callbacks. - # Since .users and .chats have already been processed, we - # don't need to care about those either. - elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): - for u in update.updates: - self._updates.put(u) - # TODO Handle "tl.UpdatesTooLong" - else: - self._updates.put(update) + # After running the script for over an hour and receiving over + # 1000 updates, the only duplicates received were users going + # online or offline. We can trust the server until new reports. + if isinstance(update, tl.UpdateShort): + self.handle_update(update.update) + # Expand "Updates" into "Update", and pass these to callbacks. + # Since .users and .chats have already been processed, we + # don't need to care about those either. + elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): + for u in update.updates: + self.handle_update(u) + # TODO Handle "tl.UpdatesTooLong" + else: + self.handle_update(update) diff --git a/telethon_generator/tl_generator.py b/telethon_generator/tl_generator.py index ff12acfe..b2706487 100644 --- a/telethon_generator/tl_generator.py +++ b/telethon_generator/tl_generator.py @@ -11,9 +11,9 @@ AUTO_GEN_NOTICE = \ AUTO_CASTS = { - 'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))', - 'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))', - 'InputUser': 'utils.get_input_user(client.get_input_entity({}))', + 'InputPeer': 'utils.get_input_peer(await client.get_input_entity({}))', + 'InputChannel': 'utils.get_input_channel(await client.get_input_entity({}))', + 'InputUser': 'utils.get_input_user(await client.get_input_entity({}))', 'InputMedia': 'utils.get_input_media({})', 'InputPhoto': 'utils.get_input_photo({})' } @@ -289,7 +289,7 @@ class TLGenerator: # Write the resolve(self, client, utils) method if any(arg.type in AUTO_CASTS for arg in args): - builder.writeln('def resolve(self, client, utils):') + builder.writeln('async def resolve(self, client, utils):') for arg in args: ac = AUTO_CASTS.get(arg.type, None) if ac: