Merge remote-tracking branch 'tulir/asyncio' into asyncio

This commit is contained in:
Lonami Exo 2018-03-03 17:03:27 +01:00
commit 563d731c95
17 changed files with 644 additions and 785 deletions

View File

@ -30,6 +30,7 @@ Creating a client
.. code:: python .. code:: python
import asyncio
from telethon import TelegramClient from telethon import TelegramClient
# These example values won't work. You must get your own api_id and # These example values won't work. You must get your own api_id and
@ -38,22 +39,28 @@ Creating a client
api_hash = '0123456789abcdef0123456789abcdef' api_hash = '0123456789abcdef0123456789abcdef'
client = TelegramClient('session_name', api_id, api_hash) client = TelegramClient('session_name', api_id, api_hash)
client.start() async def main():
await client.start()
asyncio.get_event_loop().run_until_complete(main())
Doing stuff Doing stuff
----------- -----------
Note that this assumes you're inside an "async def" method. Check out the
`Python documentation <https://docs.python.org/3/library/asyncio-dev.html>`_
if you're new with ``asyncio``.
.. code:: python .. code:: python
print(client.get_me().stringify()) print(client.get_me().stringify())
client.send_message('username', 'Hello! Talking to you from Telethon') await client.send_message('username', 'Hello! Talking to you from Telethon')
client.send_file('username', '/home/myself/Pictures/holidays.jpg') await client.send_file('username', '/home/myself/Pictures/holidays.jpg')
client.download_profile_photo('me') await client.download_profile_photo('me')
messages = client.get_message_history('username') messages = await client.get_message_history('username')
client.download_media(messages[0]) await client.download_media(messages[0])
Next steps Next steps
@ -61,5 +68,7 @@ Next steps
Do you like how Telethon looks? Check out Do you like how Telethon looks? Check out
`Read The Docs <http://telethon.rtfd.io/>`_ `Read The Docs <http://telethon.rtfd.io/>`_
for a more in-depth explanation, with examples, for a more in-depth explanation, with examples, troubleshooting issues,
troubleshooting issues, and more useful information. and more useful information. Note that the examples there are written for
the threaded version, not the one using asyncio. However, you just need to
await every remote call.

View File

@ -30,7 +30,7 @@ class CdnDecrypter:
self.cdn_file_hashes = cdn_file_hashes self.cdn_file_hashes = cdn_file_hashes
@staticmethod @staticmethod
def prepare_decrypter(client, cdn_client, cdn_redirect): async def prepare_decrypter(client, cdn_client, cdn_redirect):
""" """
Prepares a new CDN decrypter. Prepares a new CDN decrypter.
@ -52,14 +52,14 @@ class CdnDecrypter:
cdn_aes, cdn_redirect.cdn_file_hashes cdn_aes, cdn_redirect.cdn_file_hashes
) )
cdn_file = cdn_client(GetCdnFileRequest( cdn_file = await cdn_client(GetCdnFileRequest(
file_token=cdn_redirect.file_token, file_token=cdn_redirect.file_token,
offset=cdn_redirect.cdn_file_hashes[0].offset, offset=cdn_redirect.cdn_file_hashes[0].offset,
limit=cdn_redirect.cdn_file_hashes[0].limit limit=cdn_redirect.cdn_file_hashes[0].limit
)) ))
if isinstance(cdn_file, CdnFileReuploadNeeded): if isinstance(cdn_file, CdnFileReuploadNeeded):
# We need to use the original client here # We need to use the original client here
client(ReuploadCdnFileRequest( await client(ReuploadCdnFileRequest(
file_token=cdn_redirect.file_token, file_token=cdn_redirect.file_token,
request_token=cdn_file.request_token request_token=cdn_file.request_token
)) ))
@ -73,7 +73,7 @@ class CdnDecrypter:
return decrypter, cdn_file return decrypter, cdn_file
def get_file(self): async def get_file(self):
""" """
Calls GetCdnFileRequest and decrypts its bytes. Calls GetCdnFileRequest and decrypts its bytes.
Also ensures that the file hasn't been tampered. Also ensures that the file hasn't been tampered.
@ -82,7 +82,7 @@ class CdnDecrypter:
""" """
if self.cdn_file_hashes: if self.cdn_file_hashes:
cdn_hash = self.cdn_file_hashes.pop(0) cdn_hash = self.cdn_file_hashes.pop(0)
cdn_file = self.client(GetCdnFileRequest( cdn_file = await self.client(GetCdnFileRequest(
self.file_token, cdn_hash.offset, cdn_hash.limit self.file_token, cdn_hash.offset, cdn_hash.limit
)) ))
cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes) cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes)

View File

@ -9,7 +9,7 @@ from ..extensions import markdown
from ..tl import types, functions from ..tl import types, functions
def _into_id_set(client, chats): async def _into_id_set(client, chats):
"""Helper util to turn the input chat or chats into a set of IDs.""" """Helper util to turn the input chat or chats into a set of IDs."""
if chats is None: if chats is None:
return None return None
@ -19,9 +19,9 @@ def _into_id_set(client, chats):
result = set() result = set()
for chat in chats: for chat in chats:
chat = client.get_input_entity(chat) chat = await client.get_input_entity(chat)
if isinstance(chat, types.InputPeerSelf): if isinstance(chat, types.InputPeerSelf):
chat = client.get_me(input_peer=True) chat = await client.get_me(input_peer=True)
result.add(utils.get_peer_id(chat)) result.add(utils.get_peer_id(chat))
return result return result
@ -48,10 +48,10 @@ class _EventBuilder(abc.ABC):
def build(self, update): def build(self, update):
"""Builds an event for the given update if possible, or returns None""" """Builds an event for the given update if possible, or returns None"""
def resolve(self, client): async def resolve(self, client):
"""Helper method to allow event builders to be resolved before usage""" """Helper method to allow event builders to be resolved before usage"""
self.chats = _into_id_set(client, self.chats) self.chats = await _into_id_set(client, self.chats)
self._self_id = client.get_me(input_peer=True).user_id self._self_id = (await client.get_me(input_peer=True)).user_id
def _filter_event(self, event): def _filter_event(self, event):
""" """
@ -86,7 +86,7 @@ class _EventCommon(abc.ABC):
) )
self.is_channel = isinstance(chat_peer, types.PeerChannel) self.is_channel = isinstance(chat_peer, types.PeerChannel)
def _get_input_entity(self, msg_id, entity_id, chat=None): async def _get_input_entity(self, msg_id, entity_id, chat=None):
""" """
Helper function to call GetMessages on the give msg_id and Helper function to call GetMessages on the give msg_id and
return the input entity whose ID is the given entity ID. return the input entity whose ID is the given entity ID.
@ -95,11 +95,11 @@ class _EventCommon(abc.ABC):
""" """
try: try:
if isinstance(chat, types.InputPeerChannel): if isinstance(chat, types.InputPeerChannel):
result = self._client( result = await self._client(
functions.channels.GetMessagesRequest(chat, [msg_id]) functions.channels.GetMessagesRequest(chat, [msg_id])
) )
else: else:
result = self._client( result = await self._client(
functions.messages.GetMessagesRequest([msg_id]) functions.messages.GetMessagesRequest([msg_id])
) )
except RPCError: except RPCError:
@ -113,7 +113,7 @@ class _EventCommon(abc.ABC):
return utils.get_input_peer(entity) return utils.get_input_peer(entity)
@property @property
def input_chat(self): async def input_chat(self):
""" """
The (:obj:`InputPeer`) (group, megagroup or channel) on which The (:obj:`InputPeer`) (group, megagroup or channel) on which
the event occurred. This doesn't have the title or anything, the event occurred. This doesn't have the title or anything,
@ -125,7 +125,7 @@ class _EventCommon(abc.ABC):
if self._input_chat is None and self._chat_peer is not None: if self._input_chat is None and self._chat_peer is not None:
try: try:
self._input_chat = self._client.get_input_entity( self._input_chat = await self._client.get_input_entity(
self._chat_peer self._chat_peer
) )
except (ValueError, TypeError): except (ValueError, TypeError):
@ -134,22 +134,22 @@ class _EventCommon(abc.ABC):
# TODO For channels, getDifference? Maybe looking # TODO For channels, getDifference? Maybe looking
# in the dialogs (which is already done) is enough. # in the dialogs (which is already done) is enough.
if self._message_id is not None: if self._message_id is not None:
self._input_chat = self._get_input_entity( self._input_chat = await self._get_input_entity(
self._message_id, self._message_id,
utils.get_peer_id(self._chat_peer) utils.get_peer_id(self._chat_peer)
) )
return self._input_chat return self._input_chat
@property @property
def chat(self): async def chat(self):
""" """
The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which
the event occurred. This property will make an API call the first time the event occurred. This property will make an API call the first time
to get the most up to date version of the chat, so use with care as to get the most up to date version of the chat, so use with care as
there is no caching besides local caching yet. there is no caching besides local caching yet.
""" """
if self._chat is None and self.input_chat: if self._chat is None and await self.input_chat:
self._chat = self._client.get_entity(self._input_chat) self._chat = await self._client.get_entity(await self._input_chat)
return self._chat return self._chat
@ -157,7 +157,7 @@ class Raw(_EventBuilder):
""" """
Represents a raw event. The event is the update itself. Represents a raw event. The event is the update itself.
""" """
def resolve(self, client): async def resolve(self, client):
pass pass
def build(self, update): def build(self, update):
@ -304,23 +304,23 @@ class NewMessage(_EventBuilder):
self.is_reply = bool(message.reply_to_msg_id) self.is_reply = bool(message.reply_to_msg_id)
self._reply_message = None self._reply_message = None
def respond(self, *args, **kwargs): async def respond(self, *args, **kwargs):
""" """
Responds to the message (not as a reply). This is a shorthand for Responds to the message (not as a reply). This is a shorthand for
``client.send_message(event.chat, ...)``. ``client.send_message(event.chat, ...)``.
""" """
return self._client.send_message(self.input_chat, *args, **kwargs) return await self._client.send_message(await self.input_chat, *args, **kwargs)
def reply(self, *args, **kwargs): async def reply(self, *args, **kwargs):
""" """
Replies to the message (as a reply). This is a shorthand for Replies to the message (as a reply). This is a shorthand for
``client.send_message(event.chat, ..., reply_to=event.message.id)``. ``client.send_message(event.chat, ..., reply_to=event.message.id)``.
""" """
return self._client.send_message(self.input_chat, return await self._client.send_message(await self.input_chat,
reply_to=self.message.id, reply_to=self.message.id,
*args, **kwargs) *args, **kwargs)
def edit(self, *args, **kwargs): async def edit(self, *args, **kwargs):
""" """
Edits the message iff it's outgoing. This is a shorthand for Edits the message iff it's outgoing. This is a shorthand for
``client.edit_message(event.chat, event.message, ...)``. ``client.edit_message(event.chat, event.message, ...)``.
@ -331,27 +331,27 @@ class NewMessage(_EventBuilder):
if not self.message.out: if not self.message.out:
if not isinstance(self.message.to_id, types.PeerUser): if not isinstance(self.message.to_id, types.PeerUser):
return None return None
me = self._client.get_me(input_peer=True) me = await self._client.get_me(input_peer=True)
if self.message.to_id.user_id != me.user_id: if self.message.to_id.user_id != me.user_id:
return None return None
return self._client.edit_message(self.input_chat, return await self._client.edit_message(await self.input_chat,
self.message, self.message,
*args, **kwargs) *args, **kwargs)
def delete(self, *args, **kwargs): async def delete(self, *args, **kwargs):
""" """
Deletes the message. You're responsible for checking whether you Deletes the message. You're responsible for checking whether you
have the permission to do so, or to except the error otherwise. have the permission to do so, or to except the error otherwise.
This is a shorthand for This is a shorthand for
``client.delete_messages(event.chat, event.message, ...)``. ``client.delete_messages(event.chat, event.message, ...)``.
""" """
return self._client.delete_messages(self.input_chat, return await self._client.delete_messages(await self.input_chat,
[self.message], [self.message],
*args, **kwargs) *args, **kwargs)
@property @property
def input_sender(self): async def input_sender(self):
""" """
This (:obj:`InputPeer`) is the input version of the user who This (:obj:`InputPeer`) is the input version of the user who
sent the message. Similarly to ``input_chat``, this doesn't have sent the message. Similarly to ``input_chat``, this doesn't have
@ -365,21 +365,21 @@ class NewMessage(_EventBuilder):
return None return None
try: try:
self._input_sender = self._client.get_input_entity( self._input_sender = await self._client.get_input_entity(
self.message.from_id self.message.from_id
) )
except (ValueError, TypeError): except (ValueError, TypeError):
# We can rely on self.input_chat for this # We can rely on self.input_chat for this
self._input_sender = self._get_input_entity( self._input_sender = await self._get_input_entity(
self.message.id, self.message.id,
self.message.from_id, self.message.from_id,
chat=self.input_chat chat=await self.input_chat
) )
return self._input_sender return self._input_sender
@property @property
def sender(self): async def sender(self):
""" """
This (:obj:`User`) will make an API call the first time to get This (:obj:`User`) will make an API call the first time to get
the most up to date version of the sender, so use with care as the most up to date version of the sender, so use with care as
@ -387,8 +387,8 @@ class NewMessage(_EventBuilder):
``input_sender`` needs to be available (often the case). ``input_sender`` needs to be available (often the case).
""" """
if self._sender is None and self.input_sender: if self._sender is None and await self.input_sender:
self._sender = self._client.get_entity(self._input_sender) self._sender = await self._client.get_entity(self._input_sender)
return self._sender return self._sender
@property @property
@ -411,7 +411,7 @@ class NewMessage(_EventBuilder):
return self.message.message return self.message.message
@property @property
def reply_message(self): async def reply_message(self):
""" """
This (:obj:`Message`, optional) will make an API call the first This (:obj:`Message`, optional) will make an API call the first
time to get the full ``Message`` object that one was replying to, time to get the full ``Message`` object that one was replying to,
@ -421,12 +421,12 @@ class NewMessage(_EventBuilder):
return None return None
if self._reply_message is None: if self._reply_message is None:
if isinstance(self.input_chat, types.InputPeerChannel): if isinstance(await self.input_chat, types.InputPeerChannel):
r = self._client(functions.channels.GetMessagesRequest( r = await self._client(functions.channels.GetMessagesRequest(
self.input_chat, [self.message.reply_to_msg_id] await self.input_chat, [self.message.reply_to_msg_id]
)) ))
else: else:
r = self._client(functions.messages.GetMessagesRequest( r = await self._client(functions.messages.GetMessagesRequest(
[self.message.reply_to_msg_id] [self.message.reply_to_msg_id]
)) ))
if not isinstance(r, types.messages.MessagesNotModified): if not isinstance(r, types.messages.MessagesNotModified):
@ -610,7 +610,7 @@ class ChatAction(_EventBuilder):
self.new_title = new_title self.new_title = new_title
@property @property
def pinned_message(self): async def pinned_message(self):
""" """
If ``new_pin`` is ``True``, this returns the (:obj:`Message`) If ``new_pin`` is ``True``, this returns the (:obj:`Message`)
object that was pinned. object that was pinned.
@ -618,8 +618,8 @@ class ChatAction(_EventBuilder):
if self._pinned_message == 0: if self._pinned_message == 0:
return None return None
if isinstance(self._pinned_message, int) and self.input_chat: if isinstance(self._pinned_message, int) and await self.input_chat:
r = self._client(functions.channels.GetMessagesRequest( r = await self._client(functions.channels.GetMessagesRequest(
self._input_chat, [self._pinned_message] self._input_chat, [self._pinned_message]
)) ))
try: try:
@ -635,25 +635,25 @@ class ChatAction(_EventBuilder):
return self._pinned_message return self._pinned_message
@property @property
def added_by(self): async def added_by(self):
""" """
The user who added ``users``, if applicable (``None`` otherwise). The user who added ``users``, if applicable (``None`` otherwise).
""" """
if self._added_by and not isinstance(self._added_by, types.User): if self._added_by and not isinstance(self._added_by, types.User):
self._added_by = self._client.get_entity(self._added_by) self._added_by = await self._client.get_entity(self._added_by)
return self._added_by return self._added_by
@property @property
def kicked_by(self): async def kicked_by(self):
""" """
The user who kicked ``users``, if applicable (``None`` otherwise). The user who kicked ``users``, if applicable (``None`` otherwise).
""" """
if self._kicked_by and not isinstance(self._kicked_by, types.User): if self._kicked_by and not isinstance(self._kicked_by, types.User):
self._kicked_by = self._client.get_entity(self._kicked_by) self._kicked_by = await self._client.get_entity(self._kicked_by)
return self._kicked_by return self._kicked_by
@property @property
def user(self): async def user(self):
""" """
The single user that takes part in this action (e.g. joined). The single user that takes part in this action (e.g. joined).
@ -661,12 +661,12 @@ class ChatAction(_EventBuilder):
there is no user taking part. there is no user taking part.
""" """
try: try:
return next(self.users) return next(await self.users)
except (StopIteration, TypeError): except (StopIteration, TypeError):
return None return None
@property @property
def users(self): async def users(self):
""" """
A list of users that take part in this action (e.g. joined). A list of users that take part in this action (e.g. joined).
@ -675,7 +675,7 @@ class ChatAction(_EventBuilder):
""" """
if self._users is None and self._user_peers: if self._users is None and self._user_peers:
try: try:
self._users = self._client.get_entity(self._user_peers) self._users = await self._client.get_entity(self._user_peers)
except (TypeError, ValueError): except (TypeError, ValueError):
self._users = [] self._users = []

View File

@ -194,9 +194,9 @@ def get_inner_text(text, entity):
""" """
if isinstance(entity, TLObject): if isinstance(entity, TLObject):
entity = (entity,) entity = (entity,)
multiple = True
else:
multiple = False multiple = False
else:
multiple = True
text = _add_surrogate(text) text = _add_surrogate(text)
result = [] result = []

View File

@ -1,13 +1,20 @@
""" """
This module holds a rough implementation of the C# TCP client. This module holds a rough implementation of the C# TCP client.
""" """
# Python rough implementation of a C# TCP client
import asyncio
import errno import errno
import logging import logging
import socket import socket
import time import time
from datetime import timedelta from datetime import timedelta
from io import BytesIO, BufferedWriter from io import BytesIO, BufferedWriter
from threading import Lock
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
try: try:
import socks import socks
@ -25,7 +32,7 @@ __log__ = logging.getLogger(__name__)
class TcpClient: class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies.""" """A simple TCP client to ease the work with sockets and proxies."""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
""" """
Initializes the TCP client. Initializes the TCP client.
@ -34,7 +41,7 @@ class TcpClient:
""" """
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._closing_lock = Lock() self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self.timeout = timeout.seconds self.timeout = timeout.seconds
@ -54,9 +61,9 @@ class TcpClient:
else: # tuple, list, etc. else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy) self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout) self._socket.setblocking(False)
def connect(self, ip, port): async def connect(self, ip, port):
""" """
Tries connecting forever to IP:port unless an OSError is raised. Tries connecting forever to IP:port unless an OSError is raised.
@ -72,11 +79,15 @@ class TcpClient:
timeout = 1 timeout = 1
while True: while True:
try: try:
while not self._socket: if not self._socket:
self._recreate_socket(mode) self._recreate_socket(mode)
self._socket.connect(address) await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
except OSError as e: except OSError as e:
__log__.info('OSError "%s" raised while connecting', e) __log__.info('OSError "%s" raised while connecting', e)
# Stop retrying to connect if proxy connection error occurred # Stop retrying to connect if proxy connection error occurred
@ -90,7 +101,7 @@ class TcpClient:
# Bad file descriptor, i.e. socket was closed, set it # Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration # to none to recreate it on the next iteration
self._socket = None self._socket = None
time.sleep(timeout) await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT) timeout = min(timeout * 2, MAX_TIMEOUT)
else: else:
raise raise
@ -103,11 +114,6 @@ class TcpClient:
def close(self): def close(self):
"""Closes the connection.""" """Closes the connection."""
if self._closing_lock.locked():
# Already closing, no need to close again (avoid None.close())
return
with self._closing_lock:
try: try:
if self._socket is not None: if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR) self._socket.shutdown(socket.SHUT_RDWR)
@ -117,7 +123,7 @@ class TcpClient:
finally: finally:
self._socket = None self._socket = None
def write(self, data): async def write(self, data):
""" """
Writes (sends) the specified bytes to the connected peer. Writes (sends) the specified bytes to the connected peer.
@ -126,11 +132,13 @@ class TcpClient:
if self._socket is None: if self._socket is None:
self._raise_connection_reset(None) self._raise_connection_reset(None)
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try: try:
self._socket.sendall(data) await asyncio.wait_for(
except socket.timeout as e: self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
__log__.debug('socket.timeout "%s" while writing data', e) __log__.debug('socket.timeout "%s" while writing data', e)
raise TimeoutError() from e raise TimeoutError() from e
except ConnectionError as e: except ConnectionError as e:
@ -143,7 +151,7 @@ class TcpClient:
else: else:
raise raise
def read(self, size): async def read(self, size):
""" """
Reads (receives) a whole block of size bytes from the connected peer. Reads (receives) a whole block of size bytes from the connected peer.
@ -153,13 +161,18 @@ class TcpClient:
if self._socket is None: if self._socket is None:
self._raise_connection_reset(None) self._raise_connection_reset(None)
# TODO Remove the timeout from this method, always use previous one
with BufferedWriter(BytesIO(), buffer_size=size) as buffer: with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
try: try:
partial = self._socket.recv(bytes_left) if self._socket is None:
except socket.timeout as e: self._raise_connection_reset()
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
# These are somewhat common if the server has nothing # These are somewhat common if the server has nothing
# to send to us, so use a lower logging priority. # to send to us, so use a lower logging priority.
__log__.debug('socket.timeout "%s" while reading data', e) __log__.debug('socket.timeout "%s" while reading data', e)
@ -168,7 +181,7 @@ class TcpClient:
__log__.info('ConnectionError "%s" while reading data', e) __log__.info('ConnectionError "%s" while reading data', e)
self._raise_connection_reset(e) self._raise_connection_reset(e)
except OSError as e: except OSError as e:
if e.errno != errno.EBADF and self._closing_lock.locked(): if e.errno != errno.EBADF:
# Ignore bad file descriptor while closing # Ignore bad file descriptor while closing
__log__.info('OSError "%s" while reading data', e) __log__.info('OSError "%s" while reading data', e)
@ -190,5 +203,56 @@ class TcpClient:
def _raise_connection_reset(self, original): def _raise_connection_reset(self, original):
"""Disconnects the client and raises ConnectionResetError.""" """Disconnects the client and raises ConnectionResetError."""
self.close() # Connection reset -> flag as socket closed self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.')\ raise ConnectionResetError('The server has closed the connection.') from original
from original
# due to new https://github.com/python/cpython/pull/4386
def sock_recv(self, n):
fut = self._loop.create_future()
self._sock_recv(fut, None, n)
return fut
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled():
return
try:
data = self._socket.recv(n)
except (BlockingIOError, InterruptedError):
fd = self._socket.fileno()
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(data)
def sock_sendall(self, data):
fut = self._loop.create_future()
if data:
self._sock_sendall(fut, None, data)
else:
fut.set_result(None)
return fut
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled():
return
try:
n = self._socket.send(data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = self._socket.fileno()
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)

View File

@ -21,7 +21,7 @@ from ..tl.functions import (
) )
def do_authentication(connection, retries=5): async def do_authentication(connection, retries=5):
""" """
Performs the authentication steps on the given connection. Performs the authentication steps on the given connection.
Raises an error if all attempts fail. Raises an error if all attempts fail.
@ -36,14 +36,14 @@ def do_authentication(connection, retries=5):
last_error = None last_error = None
while retries: while retries:
try: try:
return _do_authentication(connection) return await _do_authentication(connection)
except (SecurityError, AssertionError, NotImplementedError) as e: except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e last_error = e
retries -= 1 retries -= 1
raise last_error raise last_error
def _do_authentication(connection): async def _do_authentication(connection):
""" """
Executes the authentication process with the Telegram servers. Executes the authentication process with the Telegram servers.
@ -56,8 +56,8 @@ def _do_authentication(connection):
req_pq_request = ReqPqMultiRequest( req_pq_request = ReqPqMultiRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True) nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
) )
sender.send(bytes(req_pq_request)) await sender.send(bytes(req_pq_request))
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
req_pq_request.on_response(reader) req_pq_request.on_response(reader)
res_pq = req_pq_request.result res_pq = req_pq_request.result
@ -104,10 +104,10 @@ def _do_authentication(connection):
public_key_fingerprint=target_fingerprint, public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text encrypted_data=cipher_text
) )
sender.send(bytes(req_dh_params)) await sender.send(bytes(req_dh_params))
# Step 2 response: DH Exchange # Step 2 response: DH Exchange
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
req_dh_params.on_response(reader) req_dh_params.on_response(reader)
server_dh_params = req_dh_params.result server_dh_params = req_dh_params.result
@ -174,10 +174,10 @@ def _do_authentication(connection):
server_nonce=res_pq.server_nonce, server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted, encrypted_data=client_dh_encrypted,
) )
sender.send(bytes(set_client_dh)) await sender.send(bytes(set_client_dh))
# Step 3 response: Complete DH Exchange # Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
set_client_dh.on_response(reader) set_client_dh.on_response(reader)
dh_gen = set_client_dh.result dh_gen = set_client_dh.result

View File

@ -2,18 +2,17 @@
This module holds both the Connection class and the ConnectionMode enum, This module holds both the Connection class and the ConnectionMode enum,
which specifies the protocol to be used by the Connection. which specifies the protocol to be used by the Connection.
""" """
import errno
import logging import logging
import os import os
import struct import struct
from datetime import timedelta from datetime import timedelta
from zlib import crc32
from enum import Enum from enum import Enum
from zlib import crc32
import errno
from ..crypto import AESModeCTR from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError from ..errors import InvalidChecksumError
from ..extensions import TcpClient
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -52,7 +51,7 @@ class Connection:
""" """
def __init__(self, mode=ConnectionMode.TCP_FULL, def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)): proxy=None, timeout=timedelta(seconds=5), loop=None):
""" """
Initializes a new connection. Initializes a new connection.
@ -65,7 +64,7 @@ class Connection:
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket? # TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout) self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages # Sending messages
if mode == ConnectionMode.TCP_FULL: if mode == ConnectionMode.TCP_FULL:
@ -89,7 +88,7 @@ class Connection:
setattr(self, 'write', self._write_plain) setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain) setattr(self, 'read', self._read_plain)
def connect(self, ip, port): async def connect(self, ip, port):
""" """
Estabilishes a connection to IP:port. Estabilishes a connection to IP:port.
@ -97,7 +96,7 @@ class Connection:
:param port: the port to connect to. :param port: the port to connect to.
""" """
try: try:
self.conn.connect(ip, port) await self.conn.connect(ip, port)
except OSError as e: except OSError as e:
if e.errno == errno.EISCONN: if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up return # Already connected, no need to re-set everything up
@ -106,17 +105,17 @@ class Connection:
self._send_counter = 0 self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED: if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef') await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE: elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee') await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED: elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation() await self._setup_obfuscation()
def get_timeout(self): def get_timeout(self):
"""Returns the timeout used by the connection.""" """Returns the timeout used by the connection."""
return self.conn.timeout return self.conn.timeout
def _setup_obfuscation(self): async def _setup_obfuscation(self):
""" """
Sets up the obfuscated protocol. Sets up the obfuscated protocol.
""" """
@ -144,7 +143,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random)) await self.conn.write(bytes(random))
def is_connected(self): def is_connected(self):
""" """
@ -166,12 +165,12 @@ class Connection:
# region Receive message implementations # region Receive message implementations
def recv(self): async def recv(self):
"""Receives and unpacks a message""" """Receives and unpacks a message"""
# Default implementation is just an error # Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self): async def _recv_tcp_full(self):
""" """
Receives a message from the network, Receives a message from the network,
internally encoded using the TCP full protocol. internally encoded using the TCP full protocol.
@ -181,7 +180,10 @@ class Connection:
:return: the read message payload. :return: the read message payload.
""" """
packet_len_seq = self.read(8) # 4 and 4 # TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq) packet_len, seq = struct.unpack('<ii', packet_len_seq)
# Sometimes Telegram seems to send a packet length of 0 (12) # Sometimes Telegram seems to send a packet length of 0 (12)
@ -192,15 +194,15 @@ class Connection:
'reading data left:', packet_len) 'reading data left:', packet_len)
while True: while True:
try: try:
__log__.error(repr(self.read(1))) __log__.error(repr(await self.read(1)))
except TimeoutError: except TimeoutError:
break break
# Connection reset and hope it's fixed after # Connection reset and hope it's fixed after
self.conn.close() self.conn.close()
raise ConnectionResetError() raise ConnectionResetError()
body = self.read(packet_len - 12) body = await self.read(packet_len - 12)
checksum = struct.unpack('<I', self.read(4))[0] checksum = struct.unpack('<I', await self.read(4))[0]
valid_checksum = crc32(packet_len_seq + body) valid_checksum = crc32(packet_len_seq + body)
if checksum != valid_checksum: if checksum != valid_checksum:
@ -208,38 +210,38 @@ class Connection:
return body return body
def _recv_intermediate(self): async def _recv_intermediate(self):
""" """
Receives a message from the network, Receives a message from the network,
internally encoded using the TCP intermediate protocol. internally encoded using the TCP intermediate protocol.
:return: the read message payload. :return: the read message payload.
""" """
return self.read(struct.unpack('<i', self.read(4))[0]) return await self.read(struct.unpack('<i', await self.read(4))[0])
def _recv_abridged(self): async def _recv_abridged(self):
""" """
Receives a message from the network, Receives a message from the network,
internally encoded using the TCP abridged protocol. internally encoded using the TCP abridged protocol.
:return: the read message payload. :return: the read message payload.
""" """
length = struct.unpack('<B', self.read(1))[0] length = struct.unpack('<B', await self.read(1))[0]
if length >= 127: if length >= 127:
length = struct.unpack('<i', self.read(3) + b'\0')[0] length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return self.read(length << 2) return await self.read(length << 2)
# endregion # endregion
# region Send message implementations # region Send message implementations
def send(self, message): async def send(self, message):
"""Encapsulates and sends the given message""" """Encapsulates and sends the given message"""
# Default implementation is just an error # Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message): async def _send_tcp_full(self, message):
""" """
Encapsulates and sends the given message payload Encapsulates and sends the given message payload
using the TCP full mode (length, sequence, message, crc32). using the TCP full mode (length, sequence, message, crc32).
@ -252,18 +254,18 @@ class Connection:
data = struct.pack('<ii', length, self._send_counter) + message data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data)) crc = struct.pack('<I', crc32(data))
self._send_counter += 1 self._send_counter += 1
self.write(data + crc) await self.write(data + crc)
def _send_intermediate(self, message): async def _send_intermediate(self, message):
""" """
Encapsulates and sends the given message payload Encapsulates and sends the given message payload
using the TCP intermediate mode (length, message). using the TCP intermediate mode (length, message).
:param message: the message to be sent. :param message: the message to be sent.
""" """
self.write(struct.pack('<i', len(message)) + message) await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message): async def _send_abridged(self, message):
""" """
Encapsulates and sends the given message payload Encapsulates and sends the given message payload
using the TCP abridged mode (short length, message). using the TCP abridged mode (short length, message).
@ -276,57 +278,55 @@ class Connection:
else: else:
length = b'\x7f' + int.to_bytes(length, 3, 'little') length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message) await self.write(length + message)
# endregion # endregion
# region Read implementations # region Read implementations
def read(self, length): async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length): async def _read_plain(self, length):
""" """
Reads data from the socket connection. Reads data from the socket connection.
:param length: how many bytes should be read. :param length: how many bytes should be read.
:return: a byte sequence with len(data) == length :return: a byte sequence with len(data) == length
""" """
return self.conn.read(length) return await self.conn.read(length)
def _read_obfuscated(self, length): async def _read_obfuscated(self, length):
""" """
Reads data and decrypts from the socket connection. Reads data and decrypts from the socket connection.
:param length: how many bytes should be read. :param length: how many bytes should be read.
:return: the decrypted byte sequence with len(data) == length :return: the decrypted byte sequence with len(data) == length
""" """
return self._aes_decrypt.encrypt( return self._aes_decrypt.encrypt(await self.conn.read(length))
self.conn.read(length)
)
# endregion # endregion
# region Write implementations # region Write implementations
def write(self, data): async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data): async def _write_plain(self, data):
""" """
Writes the given data through the socket connection. Writes the given data through the socket connection.
:param data: the data in bytes to be written. :param data: the data in bytes to be written.
""" """
self.conn.write(data) await self.conn.write(data)
def _write_obfuscated(self, data): async def _write_obfuscated(self, data):
""" """
Writes the given data through the socket connection, Writes the given data through the socket connection,
using the obfuscated mode (AES encryption is applied on top). using the obfuscated mode (AES encryption is applied on top).
:param data: the data in bytes to be written. :param data: the data in bytes to be written.
""" """
self.conn.write(self._aes_encrypt.encrypt(data)) await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion # endregion

View File

@ -26,32 +26,32 @@ class MtProtoPlainSender:
self._last_msg_id = 0 self._last_msg_id = 0
self._connection = connection self._connection = connection
def connect(self): async def connect(self):
"""Connects to Telegram's servers.""" """Connects to Telegram's servers."""
self._connection.connect() await self._connection.connect()
def disconnect(self): def disconnect(self):
"""Disconnects from Telegram's servers.""" """Disconnects from Telegram's servers."""
self._connection.close() self._connection.close()
def send(self, data): async def send(self, data):
""" """
Sends a plain packet (auth_key_id = 0) containing the Sends a plain packet (auth_key_id = 0) containing the
given message body (data). given message body (data).
:param data: the data to be sent. :param data: the data to be sent.
""" """
self._connection.send( await self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
) )
def receive(self): async def receive(self):
""" """
Receives a plain packet from the network. Receives a plain packet from the network.
:return: the response body. :return: the response body.
""" """
body = self._connection.recv() body = await self._connection.recv()
if body == b'l\xfe\xff\xff': # -404 little endian signed if body == b'l\xfe\xff\xff': # -404 little endian signed
# Broken authorization, must reset the auth key # Broken authorization, must reset the auth key
raise BrokenAuthKeyError() raise BrokenAuthKeyError()

View File

@ -2,9 +2,10 @@
This module contains the class used to communicate with Telegram's servers This module contains the class used to communicate with Telegram's servers
encrypting every packet, and relies on a valid AuthKey in the used Session. encrypting every packet, and relies on a valid AuthKey in the used Session.
""" """
import asyncio
import gzip import gzip
import logging import logging
from threading import Lock from asyncio import Event
from .. import helpers as utils from .. import helpers as utils
from ..errors import ( from ..errors import (
@ -33,7 +34,7 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed. in parallel, so thread-safety (hence locking) isn't needed.
""" """
def __init__(self, session, connection): def __init__(self, session, connection, loop=None):
""" """
Initializes a new MTProto sender. Initializes a new MTProto sender.
@ -42,22 +43,20 @@ class MtProtoSender:
port of the server, salt, ID, and AuthKey, port of the server, salt, ID, and AuthKey,
:param connection: :param connection:
the Connection to be used. the Connection to be used.
:param loop:
the asyncio loop to be used, or the default one.
""" """
self.session = session self.session = session
self.connection = connection self.connection = connection
self._loop = loop if loop else asyncio.get_event_loop()
# Message IDs that need confirmation self._recv_lock = asyncio.Lock()
self._need_confirmation = set()
# Requests (as msg_id: Message) sent waiting to be received # Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {} self._pending_receive = {}
# Multithreading async def connect(self):
self._send_lock = Lock()
def connect(self):
"""Connects to the server.""" """Connects to the server."""
self.connection.connect(self.session.server_address, self.session.port) await self.connection.connect(self.session.server_address, self.session.port)
def is_connected(self): def is_connected(self):
""" """
@ -70,18 +69,25 @@ class MtProtoSender:
def disconnect(self): def disconnect(self):
"""Disconnects from the server.""" """Disconnects from the server."""
self.connection.close() self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending() self._clear_all_pending()
# region Send and receive # region Send and receive
def send(self, *requests): async def send(self, *requests):
""" """
Sends the specified TLObject(s) (which must be requests), Sends the specified TLObject(s) (which must be requests),
and acknowledging any message which needed confirmation. and acknowledging any message which needed confirmation.
:param requests: the requests to be sent. :param requests: the requests to be sent.
""" """
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s) # Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests] messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages}) self._pending_receive.update({m.msg_id: m for m in messages})
@ -91,13 +97,6 @@ class MtProtoSender:
for m in messages for m in messages
)) ))
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
)
self._need_confirmation.clear()
if len(messages) == 1: if len(messages) == 1:
message = messages[0] message = messages[0]
else: else:
@ -108,13 +107,13 @@ class MtProtoSender:
for m in messages: for m in messages:
m.container_msg_id = message.msg_id m.container_msg_id = message.msg_id
self._send_message(message) await self._send_message(message)
def _send_acknowledge(self, msg_id): async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id.""" """Sends a message acknowledge for the given msg_id."""
self._send_message(TLMessage(self.session, MsgsAck([msg_id]))) await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
def receive(self, update_state): async def receive(self, update_state):
""" """
Receives a single message from the connected endpoint. Receives a single message from the connected endpoint.
@ -130,7 +129,10 @@ class MtProtoSender:
Update and Updates objects. Update and Updates objects.
""" """
try: try:
body = self.connection.recv() with await self._recv_lock:
# Receiving items is not an "atomic" operation since we
# need to read the length and then upcoming parts separated.
body = await self.connection.recv()
except (BufferError, InvalidChecksumError): except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause... # TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear # "No more bytes left"; something wrong happened, clear
@ -147,20 +149,20 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_seq, reader, update_state) await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion # endregion
# region Low level processing # region Low level processing
def _send_message(self, message): async def _send_message(self, message):
""" """
Sends the given encrypted through the network. Sends the given encrypted through the network.
:param message: the TLMessage to be sent. :param message: the TLMessage to be sent.
""" """
with self._send_lock: await self.connection.send(utils.pack_message(self.session, message))
self.connection.send(utils.pack_message(self.session, message))
def _decode_msg(self, body): def _decode_msg(self, body):
""" """
@ -178,7 +180,7 @@ class MtProtoSender:
with BinaryReader(body) as reader: with BinaryReader(body) as reader:
return utils.unpack_message(self.session, reader) return utils.unpack_message(self.session, reader)
def _process_msg(self, msg_id, sequence, reader, state): async def _process_msg(self, msg_id, sequence, reader, state):
""" """
Processes the message read from the network inside reader. Processes the message read from the network inside reader.
@ -189,7 +191,6 @@ class MtProtoSender:
:return: true if the message was handled correctly, false otherwise. :return: true if the message was handled correctly, false otherwise.
""" """
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
self._need_confirmation.add(msg_id)
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
reader.seek(-4) reader.seek(-4)
@ -197,15 +198,15 @@ class MtProtoSender:
# These are a bit of special case, not yet generated by the code gen # These are a bit of special case, not yet generated by the code gen
if code == 0xf35c6d01: # rpc_result, (response of an RPC call) if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
__log__.debug('Processing Remote Procedure Call result') __log__.debug('Processing Remote Procedure Call result')
return self._handle_rpc_result(msg_id, sequence, reader) return await self._handle_rpc_result(msg_id, sequence, reader)
if code == MessageContainer.CONSTRUCTOR_ID: if code == MessageContainer.CONSTRUCTOR_ID:
__log__.debug('Processing container result') __log__.debug('Processing container result')
return self._handle_container(msg_id, sequence, reader, state) return await self._handle_container(msg_id, sequence, reader, state)
if code == GzipPacked.CONSTRUCTOR_ID: if code == GzipPacked.CONSTRUCTOR_ID:
__log__.debug('Processing gzipped result') __log__.debug('Processing gzipped result')
return self._handle_gzip_packed(msg_id, sequence, reader, state) return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code not in tlobjects: if code not in tlobjects:
__log__.warning( __log__.warning(
@ -218,22 +219,22 @@ class MtProtoSender:
__log__.debug('Processing %s result', type(obj).__name__) __log__.debug('Processing %s result', type(obj).__name__)
if isinstance(obj, Pong): if isinstance(obj, Pong):
return self._handle_pong(msg_id, sequence, obj) return await self._handle_pong(msg_id, sequence, obj)
if isinstance(obj, BadServerSalt): if isinstance(obj, BadServerSalt):
return self._handle_bad_server_salt(msg_id, sequence, obj) return await self._handle_bad_server_salt(msg_id, sequence, obj)
if isinstance(obj, BadMsgNotification): if isinstance(obj, BadMsgNotification):
return self._handle_bad_msg_notification(msg_id, sequence, obj) return await self._handle_bad_msg_notification(msg_id, sequence, obj)
if isinstance(obj, MsgDetailedInfo): if isinstance(obj, MsgDetailedInfo):
return self._handle_msg_detailed_info(msg_id, sequence, obj) return await self._handle_msg_detailed_info(msg_id, sequence, obj)
if isinstance(obj, MsgNewDetailedInfo): if isinstance(obj, MsgNewDetailedInfo):
return self._handle_msg_new_detailed_info(msg_id, sequence, obj) return await self._handle_msg_new_detailed_info(msg_id, sequence, obj)
if isinstance(obj, NewSessionCreated): if isinstance(obj, NewSessionCreated):
return self._handle_new_session_created(msg_id, sequence, obj) return await self._handle_new_session_created(msg_id, sequence, obj)
if isinstance(obj, MsgsAck): # may handle the request we wanted if isinstance(obj, MsgsAck): # may handle the request we wanted
# Ignore every ack request *unless* when logging out, when it's # Ignore every ack request *unless* when logging out, when it's
@ -310,7 +311,7 @@ class MtProtoSender:
r.request.confirm_received.set() r.request.confirm_received.set()
self._pending_receive.clear() self._pending_receive.clear()
def _resend_request(self, msg_id): async def _resend_request(self, msg_id):
""" """
Re-sends the request that belongs to a certain msg_id. This may Re-sends the request that belongs to a certain msg_id. This may
also be the msg_id of a container if they were sent in one. also be the msg_id of a container if they were sent in one.
@ -319,12 +320,13 @@ class MtProtoSender:
""" """
request = self._pop_request(msg_id) request = self._pop_request(msg_id)
if request: if request:
return self.send(request) await self.send(request)
return
requests = self._pop_requests_of_container(msg_id) requests = self._pop_requests_of_container(msg_id)
if requests: if requests:
return self.send(*requests) await self.send(*requests)
def _handle_pong(self, msg_id, sequence, pong): async def _handle_pong(self, msg_id, sequence, pong):
""" """
Handles a Pong response. Handles a Pong response.
@ -340,7 +342,7 @@ class MtProtoSender:
return True return True
def _handle_container(self, msg_id, sequence, reader, state): async def _handle_container(self, msg_id, sequence, reader, state):
""" """
Handles a MessageContainer response. Handles a MessageContainer response.
@ -355,7 +357,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
try: try:
if not self._process_msg(inner_msg_id, sequence, reader, state): if not await self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_len) reader.set_position(begin_position + inner_len)
except: except:
# If any error is raised, something went wrong; skip the packet # If any error is raised, something went wrong; skip the packet
@ -364,7 +366,7 @@ class MtProtoSender:
return True return True
def _handle_bad_server_salt(self, msg_id, sequence, bad_salt): async def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
""" """
Handles a BadServerSalt response. Handles a BadServerSalt response.
@ -378,10 +380,11 @@ class MtProtoSender:
# "the bad_server_salt response is received with the # "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it" # correct salt, and the message is to be re-sent with it"
self._resend_request(bad_salt.bad_msg_id) await self._resend_request(bad_salt.bad_msg_id)
return True return True
def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg): async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
""" """
Handles a BadMessageError response. Handles a BadMessageError response.
@ -397,25 +400,25 @@ class MtProtoSender:
# Use the current msg_id to determine the right time offset. # Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id) self.session.update_time_offset(correct_msg_id=msg_id)
__log__.info('Attempting to use the correct time offset') __log__.info('Attempting to use the correct time offset')
self._resend_request(bad_msg.bad_msg_id) await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.session.sequence += 64 self.session.sequence += 64
__log__.info('Attempting to set the right higher sequence') __log__.info('Attempting to set the right higher sequence')
self._resend_request(bad_msg.bad_msg_id) await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.session.sequence -= 16 self.session.sequence -= 16
__log__.info('Attempting to set the right lower sequence') __log__.info('Attempting to set the right lower sequence')
self._resend_request(bad_msg.bad_msg_id) await self._resend_request(bad_msg.bad_msg_id)
return True return True
else: else:
raise error raise error
def _handle_msg_detailed_info(self, msg_id, sequence, msg_new): async def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
""" """
Handles a MsgDetailedInfo response. Handles a MsgDetailedInfo response.
@ -426,10 +429,10 @@ class MtProtoSender:
""" """
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6 # Relevant tdesktop source code: https://goo.gl/VvpCC6
self._send_acknowledge(msg_new.answer_msg_id) await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new): async def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
""" """
Handles a MsgNewDetailedInfo response. Handles a MsgNewDetailedInfo response.
@ -440,10 +443,10 @@ class MtProtoSender:
""" """
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR # Relevant tdesktop source code: https://goo.gl/G7DPsR
self._send_acknowledge(msg_new.answer_msg_id) await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
def _handle_new_session_created(self, msg_id, sequence, new_session): async def _handle_new_session_created(self, msg_id, sequence, new_session):
""" """
Handles a NewSessionCreated response. Handles a NewSessionCreated response.
@ -456,7 +459,7 @@ class MtProtoSender:
# TODO https://goo.gl/LMyN7A # TODO https://goo.gl/LMyN7A
return True return True
def _handle_rpc_result(self, msg_id, sequence, reader): async def _handle_rpc_result(self, msg_id, sequence, reader):
""" """
Handles a RPCResult response. Handles a RPCResult response.
@ -484,9 +487,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string() reader.read_int(), reader.tgread_string()
) )
# Acknowledge that we received the error
self._send_acknowledge(request_id)
if request: if request:
request.rpc_error = error request.rpc_error = error
request.confirm_received.set() request.confirm_received.set()
@ -522,7 +522,7 @@ class MtProtoSender:
) )
return False return False
def _handle_gzip_packed(self, msg_id, sequence, reader, state): async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
""" """
Handles a GzipPacked response. Handles a GzipPacked response.
@ -532,11 +532,6 @@ class MtProtoSender:
:return: the result of processing the packed message. :return: the result of processing the packed message.
""" """
with BinaryReader(GzipPacked.read(reader)) as compressed_reader: with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
# We are reentering process_msg, which seemingly the same msg_id return await self._process_msg(msg_id, sequence, compressed_reader, state)
# to the self._need_confirmation set. Remove it from there first
# to avoid any future conflicts (i.e. if we "ignore" messages
# that we are already aware of, see 1a91c02 and old 63dfb1e)
self._need_confirmation -= {msg_id}
return self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion # endregion

View File

@ -3,7 +3,6 @@ import os
import sqlite3 import sqlite3
from base64 import b64decode from base64 import b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock, RLock
from .memory import MemorySession, _SentFileType from .memory import MemorySession, _SentFileType
from ..crypto import AuthKey from ..crypto import AuthKey
@ -39,11 +38,6 @@ class SQLiteSession(MemorySession):
if not self.filename.endswith(EXTENSION): if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION self.filename += EXTENSION
# Cross-thread safety
self._seq_no_lock = Lock()
self._msg_id_lock = Lock()
self._db_lock = RLock()
# Migrating from .json -> SQL # Migrating from .json -> SQL
entities = self._check_migrate_json() entities = self._check_migrate_json()
@ -189,7 +183,6 @@ class SQLiteSession(MemorySession):
self._update_session_table() self._update_session_table()
def _update_session_table(self): def _update_session_table(self):
with self._db_lock:
c = self._cursor() c = self._cursor()
# While we can save multiple rows into the sessions table # While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't # currently we only want to keep ONE as the tables don't
@ -207,21 +200,17 @@ class SQLiteSession(MemorySession):
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
with self._db_lock:
self._conn.commit() self._conn.commit()
def _cursor(self): def _cursor(self):
"""Asserts that the connection is open and returns a cursor""" """Asserts that the connection is open and returns a cursor"""
with self._db_lock:
if self._conn is None: if self._conn is None:
self._conn = sqlite3.connect(self.filename, self._conn = sqlite3.connect(self.filename)
check_same_thread=False)
return self._conn.cursor() return self._conn.cursor()
def close(self): def close(self):
"""Closes the connection unless we're working in-memory""" """Closes the connection unless we're working in-memory"""
if self.filename != ':memory:': if self.filename != ':memory:':
with self._db_lock:
if self._conn is not None: if self._conn is not None:
self._conn.close() self._conn.close()
self._conn = None self._conn = None
@ -259,7 +248,6 @@ class SQLiteSession(MemorySession):
if not rows: if not rows:
return return
with self._db_lock:
self._cursor().executemany( self._cursor().executemany(
'insert or replace into entities values (?,?,?,?,?)', rows 'insert or replace into entities values (?,?,?,?,?)', rows
) )
@ -302,7 +290,6 @@ class SQLiteSession(MemorySession):
if not isinstance(instance, (InputDocument, InputPhoto)): if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
with self._db_lock:
self._cursor().execute( self._cursor().execute(
'insert or replace into sent_files values (?,?,?,?,?)', ( 'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size, md5_digest, file_size,

View File

@ -1,11 +1,9 @@
import asyncio
import logging import logging
import os import os
from asyncio import Lock
from datetime import timedelta
import platform import platform
import threading
from datetime import timedelta, datetime
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Lock
from time import sleep
from . import version, utils from . import version, utils
from .crypto import rsa from .crypto import rsa
from .errors import ( from .errors import (
@ -70,8 +68,6 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
use_ipv6=False, use_ipv6=False,
proxy=None, proxy=None,
update_workers=None,
spawn_read_thread=False,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
loop=None, loop=None,
device_model=None, device_model=None,
@ -95,6 +91,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
self._loop = loop if loop else asyncio.get_event_loop()
# ':' in session.server_address is True if it's an IPv6 address # ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or if (not session.server_address or
(':' in session.server_address) != use_ipv6): (':' in session.server_address) != use_ipv6):
@ -112,13 +110,15 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new # that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always # temporary connection. The connection on this one is always
# kept open so Telegram can send us updates. # kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection( self._sender = MtProtoSender(
mode=connection_mode, proxy=proxy, timeout=timeout self.session,
)) Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
# Two threads may be calling reconnect() when the connection is lost, # Two coroutines may be calling reconnect() when the connection
# we only want one to actually perform the reconnection. # is lost, we only want one to actually perform the reconnection.
self._reconnect_lock = Lock() self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate # Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively # them all the time since generating a new key is a relatively
@ -127,7 +127,7 @@ class TelegramBareClient:
# This member will process updates if enabled. # This member will process updates if enabled.
# One may change self.updates.enabled at any later point. # One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=update_workers) self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect # Used on connection - the user may modify these and reconnect
system = platform.uname() system = platform.uname()
@ -153,34 +153,25 @@ class TelegramBareClient:
# See https://core.telegram.org/api/invoking#saving-client-info. # See https://core.telegram.org/api/invoking#saving-client-info.
self._first_request = True self._first_request = True
# Constantly read for results and updates from within the main client, self._recv_loop = None
# if the user has left enabled such option. self._ping_loop = None
self._spawn_read_thread = spawn_read_thread self._state_loop = None
self._recv_thread = None self._idling = asyncio.Event()
self._idling = threading.Event()
# Default PingRequest delay # Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1) self._ping_delay = timedelta(minutes=1)
# Also have another delay for GetStateRequest. # Also have another delay for GetStateRequest.
# #
# If the connection is kept alive for long without invoking any # If the connection is kept alive for long without invoking any
# high level request the server simply stops sending updates. # high level request the server simply stops sending updates.
# TODO maybe we can have ._last_request instead if any req works? # TODO maybe we can have ._last_request instead if any req works?
self._last_state = datetime.now()
self._state_delay = timedelta(hours=1) self._state_delay = timedelta(hours=1)
# Some errors are known but there's nothing we can do from the
# background thread. If any of these happens, call .disconnect(),
# and raise them next time .invoke() is tried to be called.
self._background_error = None
# endregion # endregion
# region Connecting # region Connecting
def connect(self, _sync_updates=True): async def connect(self, _sync_updates=True):
"""Connects to the Telegram servers, executing authentication if """Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which not the same as authenticating the desired user itself, which
@ -197,10 +188,8 @@ class TelegramBareClient:
__log__.info('Connecting to %s:%d...', __log__.info('Connecting to %s:%d...',
self.session.server_address, self.session.port) self.session.server_address, self.session.port)
self._background_error = None # Clear previous errors
try: try:
self._sender.connect() await self._sender.connect()
__log__.info('Connection success!') __log__.info('Connection success!')
# Connection was successful! Try syncing the update state # Connection was successful! Try syncing the update state
@ -210,12 +199,12 @@ class TelegramBareClient:
self._user_connected = True self._user_connected = True
if self._authorized is None and _sync_updates: if self._authorized is None and _sync_updates:
try: try:
self.sync_updates() await self.sync_updates()
self._set_connected_and_authorized() await self._set_connected_and_authorized()
except UnauthorizedError: except UnauthorizedError:
self._authorized = False self._authorized = False
elif self._authorized: elif self._authorized:
self._set_connected_and_authorized() await self._set_connected_and_authorized()
return True return True
@ -224,7 +213,7 @@ class TelegramBareClient:
__log__.warning('Connection failed, got unexpected type with ID ' __log__.warning('Connection failed, got unexpected type with ID '
'%s. Migrating?', hex(e.invalid_constructor_id)) '%s. Migrating?', hex(e.invalid_constructor_id))
self.disconnect() self.disconnect()
return self.connect(_sync_updates=_sync_updates) return await self.connect(_sync_updates=_sync_updates)
except (RPCError, ConnectionError) as e: except (RPCError, ConnectionError) as e:
# Probably errors from the previous session, ignore them # Probably errors from the previous session, ignore them
@ -249,24 +238,15 @@ class TelegramBareClient:
)) ))
def disconnect(self): def disconnect(self):
"""Disconnects from the Telegram server """Disconnects from the Telegram server"""
and stops all the spawned threads"""
__log__.info('Disconnecting...') __log__.info('Disconnecting...')
self._user_connected = False # This will stop recv_thread's loop self._user_connected = False
__log__.debug('Stopping all workers...')
self.updates.stop_workers()
# This will trigger a "ConnectionResetError" on the recv_thread,
# which won't attempt reconnecting as ._user_connected is False.
__log__.debug('Disconnecting the socket...')
self._sender.disconnect() self._sender.disconnect()
# TODO Shall we clear the _exported_sessions, or may be reused? # TODO Shall we clear the _exported_sessions, or may be reused?
self._first_request = True # On reconnect it will be first again self._first_request = True # On reconnect it will be first again
self.session.close() self.session.close()
def _reconnect(self, new_dc=None): async def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made """If 'new_dc' is not set, only a call to .connect() will be made
since it's assumed that the connection has been lost and the since it's assumed that the connection has been lost and the
library is reconnecting. library is reconnecting.
@ -276,13 +256,14 @@ class TelegramBareClient:
connects to the new data center. connects to the new data center.
""" """
if new_dc is None: if new_dc is None:
# Assume we are disconnected due to some error, so connect again
try:
if self.is_connected(): if self.is_connected():
__log__.info('Reconnection aborted: already connected') __log__.info('Reconnection aborted: already connected')
return True return True
try:
__log__.info('Attempting reconnection...') __log__.info('Attempting reconnection...')
return self.connect() return await self.connect()
except ConnectionResetError as e: except ConnectionResetError as e:
__log__.warning('Reconnection failed due to %s', e) __log__.warning('Reconnection failed due to %s', e)
return False return False
@ -290,7 +271,7 @@ class TelegramBareClient:
# Since we're reconnecting possibly due to a UserMigrateError, # Since we're reconnecting possibly due to a UserMigrateError,
# we need to first know the Data Centers we can connect to. Do # we need to first know the Data Centers we can connect to. Do
# that before disconnecting. # that before disconnecting.
dc = self._get_dc(new_dc) dc = await self._get_dc(new_dc)
__log__.info('Reconnecting to new data center %s', dc) __log__.info('Reconnecting to new data center %s', dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) self.session.set_dc(dc.id, dc.ip_address, dc.port)
@ -299,7 +280,7 @@ class TelegramBareClient:
self.session.auth_key = None self.session.auth_key = None
self.session.save() self.session.save()
self.disconnect() self.disconnect()
return self.connect() return await self.connect()
def set_proxy(self, proxy): def set_proxy(self, proxy):
"""Change the proxy used by the connections. """Change the proxy used by the connections.
@ -312,19 +293,15 @@ class TelegramBareClient:
# region Working with different connections/Data Centers # region Working with different connections/Data Centers
def _on_read_thread(self): async def _get_dc(self, dc_id, cdn=False):
return self._recv_thread is not None and \
threading.get_ident() == self._recv_thread.ident
def _get_dc(self, dc_id, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'""" """Gets the Data Center (DC) associated to 'dc_id'"""
if not TelegramBareClient._config: if not TelegramBareClient._config:
TelegramBareClient._config = self(GetConfigRequest()) TelegramBareClient._config = await self(GetConfigRequest())
try: try:
if cdn: if cdn:
# Ensure we have the latest keys for the CDNs # Ensure we have the latest keys for the CDNs
for pk in self(GetCdnConfigRequest()).public_keys: for pk in await (self(GetCdnConfigRequest())).public_keys:
rsa.add_key(pk.public_key) rsa.add_key(pk.public_key)
return next( return next(
@ -336,10 +313,10 @@ class TelegramBareClient:
raise raise
# New configuration, perhaps a new CDN was added? # New configuration, perhaps a new CDN was added?
TelegramBareClient._config = self(GetConfigRequest()) TelegramBareClient._config = await self(GetConfigRequest())
return self._get_dc(dc_id, cdn=cdn) return await self._get_dc(dc_id, cdn=cdn)
def _get_exported_client(self, dc_id): async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC. """Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time calling the method with a given dc_id, If it's the first time calling the method with a given dc_id,
@ -356,11 +333,11 @@ class TelegramBareClient:
# TODO Add a lock, don't allow two threads to create an auth key # TODO Add a lock, don't allow two threads to create an auth key
# (when calling .connect() if there wasn't a previous session). # (when calling .connect() if there wasn't a previous session).
# for the same data center. # for the same data center.
dc = self._get_dc(dc_id) dc = await self._get_dc(dc_id)
# Export the current authorization to the new DC. # Export the current authorization to the new DC.
__log__.info('Exporting authorization for data center %s', dc) __log__.info('Exporting authorization for data center %s', dc)
export_auth = self(ExportAuthorizationRequest(dc_id)) export_auth = await self(ExportAuthorizationRequest(dc_id))
# Create a temporary session for this IP address, which needs # Create a temporary session for this IP address, which needs
# to be different because each auth_key is unique per DC. # to be different because each auth_key is unique per DC.
@ -375,11 +352,12 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
client.connect(_sync_updates=False) await client.connect(_sync_updates=False)
if isinstance(export_auth, ExportedAuthorization): if isinstance(export_auth, ExportedAuthorization):
client(ImportAuthorizationRequest( await client(ImportAuthorizationRequest(
id=export_auth.id, bytes=export_auth.bytes id=export_auth.id, bytes=export_auth.bytes
)) ))
elif export_auth is not None: elif export_auth is not None:
@ -388,11 +366,11 @@ class TelegramBareClient:
client._authorized = True # We exported the auth, so we got auth client._authorized = True # We exported the auth, so we got auth
return client return client
def _get_cdn_client(self, cdn_redirect): async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs""" """Similar to ._get_exported_client, but for CDNs"""
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port) session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
@ -401,7 +379,8 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
# This will make use of the new RSA keys for this specific CDN. # This will make use of the new RSA keys for this specific CDN.
@ -409,7 +388,7 @@ class TelegramBareClient:
# We won't be calling GetConfigRequest because it's only called # We won't be calling GetConfigRequest because it's only called
# when needed by ._get_dc, and also it's static so it's likely # when needed by ._get_dc, and also it's static so it's likely
# set already. Avoid invoking non-CDN methods by not syncing updates. # set already. Avoid invoking non-CDN methods by not syncing updates.
client.connect(_sync_updates=False) await client.connect(_sync_updates=False)
client._authorized = self._authorized client._authorized = self._authorized
return client return client
@ -417,7 +396,7 @@ class TelegramBareClient:
# region Invoking Telegram requests # region Invoking Telegram requests
def __call__(self, *requests, retries=5): async def __call__(self, *requests, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) a MTProtoRequest and returns (receives) its result.
The invoke will be retried up to 'retries' times before raising The invoke will be retried up to 'retries' times before raising
@ -427,11 +406,8 @@ class TelegramBareClient:
x.content_related for x in requests): x.content_related for x in requests):
raise TypeError('You can only invoke requests, not types!') raise TypeError('You can only invoke requests, not types!')
if self._background_error:
raise self._background_error
for request in requests: for request in requests:
request.resolve(self, utils) await request.resolve(self, utils)
# For logging purposes # For logging purposes
if len(requests) == 1: if len(requests) == 1:
@ -440,26 +416,23 @@ class TelegramBareClient:
which = '{} requests ({})'.format( which = '{} requests ({})'.format(
len(requests), [type(x).__name__ for x in requests]) len(requests), [type(x).__name__ for x in requests])
# Determine the sender to be used (main or a new connection)
__log__.debug('Invoking %s', which) __log__.debug('Invoking %s', which)
call_receive = \ call_receive = \
not self._idling.is_set() or self._reconnect_lock.locked() not self._idling.is_set() or self._reconnect_lock.locked()
for retry in range(retries): for retry in range(retries):
result = self._invoke(call_receive, *requests) result = await self._invoke(call_receive, retry, *requests)
if result is not None: if result is not None:
return result return result
__log__.warning('Invoking %s failed %d times, ' __log__.warning('Invoking %s failed %d times, '
'reconnecting and retrying', 'reconnecting and retrying',
[str(x) for x in requests], retry + 1) [str(x) for x in requests], retry + 1)
sleep(1)
# The ReadThread has priority when attempting reconnection, await asyncio.sleep(retry + 1, loop=self._loop)
# since this thread is constantly running while __call__ is
# only done sometimes. Here try connecting only once/retry.
if not self._reconnect_lock.locked(): if not self._reconnect_lock.locked():
with self._reconnect_lock: with await self._reconnect_lock:
self._reconnect() await self._reconnect()
raise RuntimeError('Number of retries reached 0 for {}.'.format( raise RuntimeError('Number of retries reached 0 for {}.'.format(
[type(x).__name__ for x in requests] [type(x).__name__ for x in requests]
@ -468,18 +441,17 @@ class TelegramBareClient:
# Let people use client.invoke(SomeRequest()) instead client(...) # Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__ invoke = __call__
def _invoke(self, call_receive, *requests): async def _invoke(self, call_receive, retry, *requests):
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
for x in requests: for x in requests:
x.confirm_received.clear()
x.rpc_error = None x.rpc_error = None
if not self.session.auth_key: if not self.session.auth_key:
__log__.info('Need to generate new auth key before invoking') __log__.info('Need to generate new auth key before invoking')
self._first_request = True self._first_request = True
self.session.auth_key, self.session.time_offset = \ self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(self._sender.connection) await authenticator.do_authentication(self._sender.connection)
if self._first_request: if self._first_request:
__log__.info('Initializing a new connection while invoking') __log__.info('Initializing a new connection while invoking')
@ -489,24 +461,21 @@ class TelegramBareClient:
# We need a SINGLE request (like GetConfig) to init conn. # We need a SINGLE request (like GetConfig) to init conn.
# Once that's done, the N original requests will be # Once that's done, the N original requests will be
# invoked. # invoked.
TelegramBareClient._config = self( TelegramBareClient._config = await self(
self._wrap_init_connection(GetConfigRequest()) self._wrap_init_connection(GetConfigRequest())
) )
self._sender.send(*requests) await self._sender.send(*requests)
if not call_receive: if not call_receive:
# TODO This will be slightly troublesome if we allow await asyncio.wait(
# switching between constant read or not on the fly. list(map(lambda x: x.confirm_received.wait(), requests)),
# Must also watch out for calling .read() from two places, timeout=self._sender.connection.get_timeout(),
# in which case a Lock would be required for .receive(). loop=self._loop
for x in requests:
x.confirm_received.wait(
self._sender.connection.get_timeout()
) )
else: else:
while not all(x.confirm_received.is_set() for x in requests): while not all(x.confirm_received.is_set() for x in requests):
self._sender.receive(update_state=self.updates) await self._sender.receive(update_state=self.updates)
except BrokenAuthKeyError: except BrokenAuthKeyError:
__log__.error('Authorization key seems broken and was invalid!') __log__.error('Authorization key seems broken and was invalid!')
@ -552,12 +521,8 @@ class TelegramBareClient:
except (PhoneMigrateError, NetworkMigrateError, except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e: UserMigrateError) as e:
# TODO What happens with the background thread here? await self._reconnect(new_dc=e.new_dc)
# For normal use cases, this won't happen, because this will only return await self._invoke(call_receive, retry, *requests)
# be on the very first connection (not authorized, not running),
# but may be an issue for people who actually travel?
self._reconnect(new_dc=e.new_dc)
return self._invoke(call_receive, *requests)
except ServerError as e: except ServerError as e:
# Telegram is having some issues, just retry # Telegram is having some issues, just retry
@ -568,7 +533,8 @@ class TelegramBareClient:
if e.seconds > self.session.flood_sleep_threshold | 0: if e.seconds > self.session.flood_sleep_threshold | 0:
raise raise
sleep(e.seconds) await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality # Some really basic functionality
@ -588,90 +554,69 @@ class TelegramBareClient:
# region Updates handling # region Updates handling
def sync_updates(self): async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be """Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True, called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates. otherwise it should be called manually after enabling updates.
""" """
self.updates.process(self(GetStateRequest())) self.updates.process(await self(GetStateRequest()))
self._last_state = datetime.now()
# endregion # endregion
# region Constant read # Constant read
def _set_connected_and_authorized(self): # This is async so that the overrided version in TelegramClient can be
# async without problems.
async def _set_connected_and_authorized(self):
self._authorized = True self._authorized = True
self.updates.setup_workers() if self._recv_loop is None:
if self._spawn_read_thread and self._recv_thread is None: self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
self._recv_thread = threading.Thread( if self._ping_loop is None:
name='ReadThread', daemon=True, self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
target=self._recv_thread_impl if self._state_loop is None:
) self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop)
self._recv_thread.start()
def _signal_handler(self, signum, frame): async def _ping_loop_impl(self):
if self._user_connected: while self._user_connected:
self.disconnect() await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
else: await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
os._exit(1) self._ping_loop = None
def idle(self, stop_signals=(SIGINT, SIGTERM, SIGABRT)): async def _state_loop_impl(self):
""" while self._user_connected:
Idles the program by looping forever and listening for updates await asyncio.sleep(self._state_delay.seconds, loop=self._loop)
until one of the signals are received, which breaks the loop. await self._sender.send(GetStateRequest())
:param stop_signals: async def _recv_loop_impl(self):
Iterable containing signals from the signal module that will
be subscribed to TelegramClient.disconnect() (effectively
stopping the idle loop), which will be called on receiving one
of those signals.
:return:
"""
if self._spawn_read_thread and not self._on_read_thread():
raise RuntimeError('Can only idle if spawn_read_thread=False')
self._idling.set()
for sig in stop_signals:
signal(sig, self._signal_handler)
if self._on_read_thread():
__log__.info('Starting to wait for items from the network') __log__.info('Starting to wait for items from the network')
else: self._idling.set()
__log__.info('Idling to receive items from the network') need_reconnect = False
while self._user_connected: while self._user_connected:
try: try:
if datetime.now() > self._last_ping + self._ping_delay: if need_reconnect:
self._sender.send(PingRequest( __log__.info('Attempting reconnection from read loop')
int.from_bytes(os.urandom(8), 'big', signed=True) need_reconnect = False
)) with await self._reconnect_lock:
self._last_ping = datetime.now() while self._user_connected and not await self._reconnect():
# Retry forever, this is instant messaging
await asyncio.sleep(0.1, loop=self._loop)
if datetime.now() > self._last_state + self._state_delay:
self._sender.send(GetStateRequest())
self._last_state = datetime.now()
__log__.debug('Receiving items from the network...')
self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem
__log__.debug('Receiving items from the network timed out')
except ConnectionResetError:
if self._user_connected:
__log__.error('Connection was reset while receiving '
'items. Reconnecting')
with self._reconnect_lock:
while self._user_connected and not self._reconnect():
sleep(0.1) # Retry forever, this is instant messaging
if self.is_connected():
# Telegram seems to kick us every 1024 items received # Telegram seems to kick us every 1024 items received
# from the network not considering things like bad salt. # from the network not considering things like bad salt.
# We must execute some *high level* request (that's not # We must execute some *high level* request (that's not
# a ping) if we want to receive updates again. # a ping) if we want to receive updates again.
# TODO Test if getDifference works too (better alternative) # TODO Test if getDifference works too (better alternative)
self._sender.send(GetStateRequest()) await self._sender.send(GetStateRequest())
__log__.debug('Receiving items from the network...')
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
__log__.debug('Receiving items from the network timed out')
except ConnectionError:
need_reconnect = True
__log__.error('Connection was reset while receiving items')
await asyncio.sleep(1, loop=self._loop)
except: except:
self._idling.clear() self._idling.clear()
raise raise
@ -679,39 +624,4 @@ class TelegramBareClient:
self._idling.clear() self._idling.clear()
__log__.info('Connection closed by the user, not reading anymore') __log__.info('Connection closed by the user, not reading anymore')
# By using this approach, another thread will be
# created and started upon connection to constantly read
# from the other end. Otherwise, manual calls to .receive()
# must be performed. The MtProtoSender cannot be connected,
# or an error will be thrown.
#
# This way, sending and receiving will be completely independent.
def _recv_thread_impl(self):
# This thread is "idle" (only listening for updates), but also
# excepts everything unlike the manual idle because it should
# not crash.
while self._user_connected:
try:
self.idle(stop_signals=tuple())
except Exception as error:
__log__.exception('Unknown exception in the read thread! '
'Disconnecting and leaving it to main thread')
# Unknown exception, pass it to the main thread
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a proxy error"
self._recv_thread = None
# endregion # endregion

View File

@ -1,3 +1,4 @@
import asyncio
import getpass import getpass
import hashlib import hashlib
import io import io
@ -6,7 +7,6 @@ import logging
import os import os
import re import re
import sys import sys
import time
import warnings import warnings
from collections import OrderedDict, UserList from collections import OrderedDict, UserList
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -158,18 +158,16 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
use_ipv6=False, use_ipv6=False,
proxy=None, proxy=None,
update_workers=None,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
spawn_read_thread=True, loop=None,
**kwargs): **kwargs):
super().__init__( super().__init__(
session, api_id, api_hash, session, api_id, api_hash,
connection_mode=connection_mode, connection_mode=connection_mode,
use_ipv6=use_ipv6, use_ipv6=use_ipv6,
proxy=proxy, proxy=proxy,
update_workers=update_workers,
spawn_read_thread=spawn_read_thread,
timeout=timeout, timeout=timeout,
loop=loop,
**kwargs **kwargs
) )
@ -190,7 +188,7 @@ class TelegramClient(TelegramBareClient):
# region Authorization requests # region Authorization requests
def send_code_request(self, phone, force_sms=False): async def send_code_request(self, phone, force_sms=False):
""" """
Sends a code request to the specified phone number. Sends a code request to the specified phone number.
@ -208,7 +206,7 @@ class TelegramClient(TelegramBareClient):
phone_hash = self._phone_code_hash.get(phone) phone_hash = self._phone_code_hash.get(phone)
if not phone_hash: if not phone_hash:
result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone_code_hash[phone] = phone_hash = result.phone_code_hash self._phone_code_hash[phone] = phone_hash = result.phone_code_hash
else: else:
force_sms = True force_sms = True
@ -216,14 +214,15 @@ class TelegramClient(TelegramBareClient):
self._phone = phone self._phone = phone
if force_sms: if force_sms:
result = self(ResendCodeRequest(phone, phone_hash)) result = await self(ResendCodeRequest(phone, phone_hash))
self._phone_code_hash[phone] = result.phone_code_hash self._phone_code_hash[phone] = result.phone_code_hash
return result return result
def start(self, async def start(self,
phone=lambda: input('Please enter your phone: '), phone=lambda: input('Please enter your phone: '),
password=lambda: getpass.getpass('Please enter your password: '), password=lambda: getpass.getpass(
'Please enter your password: '),
bot_token=None, force_sms=False, code_callback=None, bot_token=None, force_sms=False, code_callback=None,
first_name='New User', last_name=''): first_name='New User', last_name=''):
""" """
@ -231,7 +230,7 @@ class TelegramClient(TelegramBareClient):
also taking into consideration that 2FA may be enabled in the account. also taking into consideration that 2FA may be enabled in the account.
Example usage: Example usage:
>>> client = TelegramClient(session, api_id, api_hash).start(phone) >>> client = await TelegramClient(session, api_id, api_hash).start(phone)
Please enter the code you received: 12345 Please enter the code you received: 12345
Please enter your password: ******* Please enter your password: *******
(You are now logged in) (You are now logged in)
@ -286,14 +285,14 @@ class TelegramClient(TelegramBareClient):
'must only provide one of either') 'must only provide one of either')
if not self.is_connected(): if not self.is_connected():
self.connect() await self.connect()
if self.is_user_authorized(): if self.is_user_authorized():
self._check_events_pending_resolve() self._check_events_pending_resolve()
return self return self
if bot_token: if bot_token:
self.sign_in(bot_token=bot_token) await self.sign_in(bot_token=bot_token)
return self return self
# Turn the callable into a valid phone number # Turn the callable into a valid phone number
@ -305,15 +304,15 @@ class TelegramClient(TelegramBareClient):
max_attempts = 3 max_attempts = 3
two_step_detected = False two_step_detected = False
sent_code = self.send_code_request(phone, force_sms=force_sms) sent_code = await self.send_code_request(phone, force_sms=force_sms)
sign_up = not sent_code.phone_registered sign_up = not sent_code.phone_registered
while attempts < max_attempts: while attempts < max_attempts:
try: try:
if sign_up: if sign_up:
me = self.sign_up(code_callback(), first_name, last_name) me = await self.sign_up(code_callback(), first_name, last_name)
else: else:
# Raises SessionPasswordNeededError if 2FA enabled # Raises SessionPasswordNeededError if 2FA enabled
me = self.sign_in(phone, code_callback()) me = await self.sign_in(phone, code_callback())
break break
except SessionPasswordNeededError: except SessionPasswordNeededError:
two_step_detected = True two_step_detected = True
@ -342,14 +341,14 @@ class TelegramClient(TelegramBareClient):
# TODO If callable given make it retry on invalid # TODO If callable given make it retry on invalid
if callable(password): if callable(password):
password = password() password = password()
me = self.sign_in(phone=phone, password=password) me = await self.sign_in(phone=phone, password=password)
# We won't reach here if any step failed (exit by exception) # We won't reach here if any step failed (exit by exception)
print('Signed in successfully as', utils.get_display_name(me)) print('Signed in successfully as', utils.get_display_name(me))
self._check_events_pending_resolve() self._check_events_pending_resolve()
return self return self
def sign_in(self, phone=None, code=None, async def sign_in(self, phone=None, code=None,
password=None, bot_token=None, phone_code_hash=None): password=None, bot_token=None, phone_code_hash=None):
""" """
Starts or completes the sign in process with the given phone number Starts or completes the sign in process with the given phone number
@ -385,7 +384,7 @@ class TelegramClient(TelegramBareClient):
return self.get_me() return self.get_me()
if phone and not code and not password: if phone and not code and not password:
return self.send_code_request(phone) return await self.send_code_request(phone)
elif code: elif code:
phone = utils.parse_phone(phone) or self._phone phone = utils.parse_phone(phone) or self._phone
phone_code_hash = \ phone_code_hash = \
@ -400,14 +399,14 @@ class TelegramClient(TelegramBareClient):
# May raise PhoneCodeEmptyError, PhoneCodeExpiredError, # May raise PhoneCodeEmptyError, PhoneCodeExpiredError,
# PhoneCodeHashEmptyError or PhoneCodeInvalidError. # PhoneCodeHashEmptyError or PhoneCodeInvalidError.
result = self(SignInRequest(phone, phone_code_hash, str(code))) result = await self(SignInRequest(phone, phone_code_hash, str(code)))
elif password: elif password:
salt = self(GetPasswordRequest()).current_salt salt = (await self(GetPasswordRequest())).current_salt
result = self(CheckPasswordRequest( result = await self(CheckPasswordRequest(
helpers.get_password_hash(password, salt) helpers.get_password_hash(password, salt)
)) ))
elif bot_token: elif bot_token:
result = self(ImportBotAuthorizationRequest( result = await self(ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token, flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash api_id=self.api_id, api_hash=self.api_hash
)) ))
@ -420,10 +419,10 @@ class TelegramClient(TelegramBareClient):
self._self_input_peer = utils.get_input_peer( self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False result.user, allow_self=False
) )
self._set_connected_and_authorized() await self._set_connected_and_authorized()
return result.user return result.user
def sign_up(self, code, first_name, last_name=''): async def sign_up(self, code, first_name, last_name=''):
""" """
Signs up to Telegram if you don't have an account yet. Signs up to Telegram if you don't have an account yet.
You must call .send_code_request(phone) first. You must call .send_code_request(phone) first.
@ -442,10 +441,10 @@ class TelegramClient(TelegramBareClient):
The new created user. The new created user.
""" """
if self.is_user_authorized(): if self.is_user_authorized():
self._check_events_pending_resolve() await self._check_events_pending_resolve()
return self.get_me() return await self.get_me()
result = self(SignUpRequest( result = await self(SignUpRequest(
phone_number=self._phone, phone_number=self._phone,
phone_code_hash=self._phone_code_hash.get(self._phone, ''), phone_code_hash=self._phone_code_hash.get(self._phone, ''),
phone_code=str(code), phone_code=str(code),
@ -456,10 +455,10 @@ class TelegramClient(TelegramBareClient):
self._self_input_peer = utils.get_input_peer( self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False result.user, allow_self=False
) )
self._set_connected_and_authorized() await self._set_connected_and_authorized()
return result.user return result.user
def log_out(self): async def log_out(self):
""" """
Logs out Telegram and deletes the current ``*.session`` file. Logs out Telegram and deletes the current ``*.session`` file.
@ -467,7 +466,7 @@ class TelegramClient(TelegramBareClient):
True if the operation was successful. True if the operation was successful.
""" """
try: try:
self(LogOutRequest()) await self(LogOutRequest())
except RPCError: except RPCError:
return False return False
@ -475,7 +474,7 @@ class TelegramClient(TelegramBareClient):
self.session.delete() self.session.delete()
return True return True
def get_me(self, input_peer=False): async def get_me(self, input_peer=False):
""" """
Gets "me" (the self user) which is currently authenticated, Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated). or None if the request fails (hence, not authenticated).
@ -491,9 +490,8 @@ class TelegramClient(TelegramBareClient):
""" """
if input_peer and self._self_input_peer: if input_peer and self._self_input_peer:
return self._self_input_peer return self._self_input_peer
try: try:
me = self(GetUsersRequest([InputUserSelf()]))[0] me = (await self(GetUsersRequest([InputUserSelf()])))[0]
if not self._self_input_peer: if not self._self_input_peer:
self._self_input_peer = utils.get_input_peer( self._self_input_peer = utils.get_input_peer(
me, allow_self=False me, allow_self=False
@ -507,7 +505,7 @@ class TelegramClient(TelegramBareClient):
# region Dialogs ("chats") requests # region Dialogs ("chats") requests
def get_dialogs(self, limit=10, offset_date=None, offset_id=0, async def get_dialogs(self, limit=10, offset_date=None, offset_id=0,
offset_peer=InputPeerEmpty()): offset_peer=InputPeerEmpty()):
""" """
Gets N "dialogs" (open "chats" or conversations with other people). Gets N "dialogs" (open "chats" or conversations with other people).
@ -535,7 +533,7 @@ class TelegramClient(TelegramBareClient):
limit = float('inf') if limit is None else int(limit) limit = float('inf') if limit is None else int(limit)
if limit == 0: if limit == 0:
# Special case, get a single dialog and determine count # Special case, get a single dialog and determine count
dialogs = self(GetDialogsRequest( dialogs = await self(GetDialogsRequest(
offset_date=offset_date, offset_date=offset_date,
offset_id=offset_id, offset_id=offset_id,
offset_peer=offset_peer, offset_peer=offset_peer,
@ -549,7 +547,7 @@ class TelegramClient(TelegramBareClient):
dialogs = OrderedDict() # Use peer id as identifier to avoid dupes dialogs = OrderedDict() # Use peer id as identifier to avoid dupes
while len(dialogs) < limit: while len(dialogs) < limit:
real_limit = min(limit - len(dialogs), 100) real_limit = min(limit - len(dialogs), 100)
r = self(GetDialogsRequest( r = await self(GetDialogsRequest(
offset_date=offset_date, offset_date=offset_date,
offset_id=offset_id, offset_id=offset_id,
offset_peer=offset_peer, offset_peer=offset_peer,
@ -580,7 +578,7 @@ class TelegramClient(TelegramBareClient):
dialogs.total = total_count dialogs.total = total_count
return dialogs return dialogs
def get_drafts(self): # TODO: Ability to provide a `filter` async def get_drafts(self): # TODO: Ability to provide a `filter`
""" """
Gets all open draft messages. Gets all open draft messages.
@ -589,7 +587,7 @@ class TelegramClient(TelegramBareClient):
You can call ``draft.set_message('text')`` to change the message, You can call ``draft.set_message('text')`` to change the message,
or delete it through :meth:`draft.delete()`. or delete it through :meth:`draft.delete()`.
""" """
response = self(GetAllDraftsRequest()) response = await self(GetAllDraftsRequest())
self.session.process_entities(response) self.session.process_entities(response)
self.session.generate_sequence(response.seq) self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates] drafts = [Draft._from_update(self, u) for u in response.updates]
@ -636,7 +634,7 @@ class TelegramClient(TelegramBareClient):
if request.id == update.message.id: if request.id == update.message.id:
return update.message return update.message
def _parse_message_text(self, message, parse_mode): async def _parse_message_text(self, message, parse_mode):
""" """
Returns a (parsed message, entities) tuple depending on parse_mode. Returns a (parsed message, entities) tuple depending on parse_mode.
""" """
@ -657,7 +655,7 @@ class TelegramClient(TelegramBareClient):
if m: if m:
try: try:
msg_entities[i] = InputMessageEntityMentionName( msg_entities[i] = InputMessageEntityMentionName(
e.offset, e.length, self.get_input_entity( e.offset, e.length, await self.get_input_entity(
int(m.group(1)) if m.group(1) else e.url int(m.group(1)) if m.group(1) else e.url
) )
) )
@ -667,8 +665,8 @@ class TelegramClient(TelegramBareClient):
return message, msg_entities return message, msg_entities
def send_message(self, entity, message, reply_to=None, parse_mode='md', async def send_message(self, entity, message, reply_to=None,
link_preview=True): parse_mode='md', link_preview=True):
""" """
Sends the given message to the specified entity (user/chat/channel). Sends the given message to the specified entity (user/chat/channel).
@ -695,11 +693,12 @@ class TelegramClient(TelegramBareClient):
Returns: Returns:
the sent message the sent message
""" """
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if isinstance(message, Message): if isinstance(message, Message):
if (message.media if (message.media
and not isinstance(message.media, MessageMediaWebPage)): and not isinstance(message.media, MessageMediaWebPage)):
return self.send_file(entity, message.media) return await self.send_file(entity, message.media)
if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id): if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id):
reply_id = message.reply_to_msg_id reply_id = message.reply_to_msg_id
@ -716,7 +715,7 @@ class TelegramClient(TelegramBareClient):
) )
message = message.message message = message.message
else: else:
message, msg_ent = self._parse_message_text(message, parse_mode) message, msg_ent = await self._parse_message_text(message, parse_mode)
request = SendMessageRequest( request = SendMessageRequest(
peer=entity, peer=entity,
message=message, message=message,
@ -725,7 +724,8 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_message_id(reply_to) reply_to_msg_id=self._get_message_id(reply_to)
) )
result = self(request) result = await self(request)
if isinstance(result, UpdateShortSentMessage): if isinstance(result, UpdateShortSentMessage):
return Message( return Message(
id=result.id, id=result.id,
@ -739,8 +739,8 @@ class TelegramClient(TelegramBareClient):
return self._get_response_message(request, result) return self._get_response_message(request, result)
def edit_message(self, entity, message_id, message=None, parse_mode='md', async def edit_message(self, entity, message_id, message=None,
link_preview=True): parse_mode='md', link_preview=True):
""" """
Edits the given message ID (to change its contents or disable preview). Edits the given message ID (to change its contents or disable preview).
@ -773,18 +773,18 @@ class TelegramClient(TelegramBareClient):
Returns: Returns:
the edited message the edited message
""" """
message, msg_entities = self._parse_message_text(message, parse_mode) message, msg_entities = await self._parse_message_text(message, parse_mode)
request = EditMessageRequest( request = EditMessageRequest(
peer=self.get_input_entity(entity), peer=await self.get_input_entity(entity),
id=self._get_message_id(message_id), id=self._get_message_id(message_id),
message=message, message=message,
no_webpage=not link_preview, no_webpage=not link_preview,
entities=msg_entities entities=msg_entities
) )
result = self(request) result = await self(request)
return self._get_response_message(request, result) return self._get_response_message(request, result)
def delete_messages(self, entity, message_ids, revoke=True): async def delete_messages(self, entity, message_ids, revoke=True):
""" """
Deletes a message from a chat, optionally "for everyone". Deletes a message from a chat, optionally "for everyone".
@ -812,16 +812,16 @@ class TelegramClient(TelegramBareClient):
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids] message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
if entity is None: if entity is None:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
if isinstance(entity, InputPeerChannel): if isinstance(entity, InputPeerChannel):
return self(channels.DeleteMessagesRequest(entity, message_ids)) return await self(channels.DeleteMessagesRequest(entity, message_ids))
else: else:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
def get_message_history(self, entity, limit=20, offset_date=None, async def get_message_history(self, entity, limit=20, offset_date=None,
offset_id=0, max_id=0, min_id=0, add_offset=0, offset_id=0, max_id=0, min_id=0, add_offset=0,
batch_size=100, wait_time=None): batch_size=100, wait_time=None):
""" """
@ -884,13 +884,12 @@ class TelegramClient(TelegramBareClient):
second is the default for this limit (or above). You may need second is the default for this limit (or above). You may need
an higher limit, so you're free to set the ``batch_size`` that an higher limit, so you're free to set the ``batch_size`` that
you think may be good. you think may be good.
""" """
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
limit = float('inf') if limit is None else int(limit) limit = float('inf') if limit is None else int(limit)
if limit == 0: if limit == 0:
# No messages, but we still need to know the total message count # No messages, but we still need to know the total message count
result = self(GetHistoryRequest( result = await self(GetHistoryRequest(
peer=entity, limit=1, peer=entity, limit=1,
offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0 offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0
)) ))
@ -906,7 +905,7 @@ class TelegramClient(TelegramBareClient):
while len(messages) < limit: while len(messages) < limit:
# Telegram has a hard limit of 100 # Telegram has a hard limit of 100
real_limit = min(limit - len(messages), batch_size) real_limit = min(limit - len(messages), batch_size)
result = self(GetHistoryRequest( result = await self(GetHistoryRequest(
peer=entity, peer=entity,
limit=real_limit, limit=real_limit,
offset_date=offset_date, offset_date=offset_date,
@ -931,7 +930,7 @@ class TelegramClient(TelegramBareClient):
offset_id = result.messages[-1].id offset_id = result.messages[-1].id
offset_date = result.messages[-1].date offset_date = result.messages[-1].date
time.sleep(wait_time) await asyncio.sleep(wait_time)
# Add a few extra attributes to the Message to make it friendlier. # Add a few extra attributes to the Message to make it friendlier.
messages.total = total_messages messages.total = total_messages
@ -959,7 +958,7 @@ class TelegramClient(TelegramBareClient):
return messages return messages
def send_read_acknowledge(self, entity, message=None, max_id=None, async def send_read_acknowledge(self, entity, message=None, max_id=None,
clear_mentions=False): clear_mentions=False):
""" """
Sends a "read acknowledge" (i.e., notifying the given peer that we've Sends a "read acknowledge" (i.e., notifying the given peer that we've
@ -993,17 +992,17 @@ class TelegramClient(TelegramBareClient):
raise ValueError( raise ValueError(
'Either a message list or a max_id must be provided.') 'Either a message list or a max_id must be provided.')
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
if clear_mentions: if clear_mentions:
self(ReadMentionsRequest(entity)) await self(ReadMentionsRequest(entity))
if max_id is None: if max_id is None:
return True return True
if max_id is not None: if max_id is not None:
if isinstance(entity, InputPeerChannel): if isinstance(entity, InputPeerChannel):
return self(channels.ReadHistoryRequest(entity, max_id=max_id)) return await self(channels.ReadHistoryRequest(entity, max_id=max_id))
else: else:
return self(messages.ReadHistoryRequest(entity, max_id=max_id)) return await self(messages.ReadHistoryRequest(entity, max_id=max_id))
return False return False
@ -1025,7 +1024,7 @@ class TelegramClient(TelegramBareClient):
raise TypeError('Invalid message type: {}'.format(type(message))) raise TypeError('Invalid message type: {}'.format(type(message)))
def get_participants(self, entity, limit=None, search='', async def get_participants(self, entity, limit=None, search='',
aggressive=False): aggressive=False):
""" """
Gets the list of participants from the specified entity. Gets the list of participants from the specified entity.
@ -1054,12 +1053,12 @@ class TelegramClient(TelegramBareClient):
A list of participants with an additional .total variable on the A list of participants with an additional .total variable on the
list indicating the total amount of members in this group/channel. list indicating the total amount of members in this group/channel.
""" """
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
limit = float('inf') if limit is None else int(limit) limit = float('inf') if limit is None else int(limit)
if isinstance(entity, InputPeerChannel): if isinstance(entity, InputPeerChannel):
total = self(GetFullChannelRequest( total = (await self(GetFullChannelRequest(
entity entity
)).full_chat.participants_count ))).full_chat.participants_count
all_participants = {} all_participants = {}
if total > 10000 and aggressive: if total > 10000 and aggressive:
@ -1091,9 +1090,9 @@ class TelegramClient(TelegramBareClient):
break break
if len(requests) == 1: if len(requests) == 1:
results = (self(requests[0]),) results = (await self(requests[0]),)
else: else:
results = self(*requests) results = await self(*requests)
for i in reversed(range(len(requests))): for i in reversed(range(len(requests))):
participants = results[i] participants = results[i]
if not participants.users: if not participants.users:
@ -1111,7 +1110,7 @@ class TelegramClient(TelegramBareClient):
users = UserList(values) users = UserList(values)
users.total = total users.total = total
elif isinstance(entity, InputPeerChat): elif isinstance(entity, InputPeerChat):
users = self(GetFullChatRequest(entity.chat_id)).users users = (await self(GetFullChatRequest(entity.chat_id))).users
if len(users) > limit: if len(users) > limit:
users = users[:limit] users = users[:limit]
users = UserList(users) users = UserList(users)
@ -1125,7 +1124,7 @@ class TelegramClient(TelegramBareClient):
# region Uploading files # region Uploading files
def send_file(self, entity, file, caption=None, async def send_file(self, entity, file, caption=None,
force_document=False, progress_callback=None, force_document=False, progress_callback=None,
reply_to=None, reply_to=None,
attributes=None, attributes=None,
@ -1201,14 +1200,14 @@ class TelegramClient(TelegramBareClient):
# Convert to tuple so we can iterate several times # Convert to tuple so we can iterate several times
file = tuple(x for x in file) file = tuple(x for x in file)
if all(utils.is_image(x) for x in file): if all(utils.is_image(x) for x in file):
return self._send_album( return await self._send_album(
entity, file, caption=caption, entity, file, caption=caption,
progress_callback=progress_callback, reply_to=reply_to, progress_callback=progress_callback, reply_to=reply_to,
parse_mode=parse_mode parse_mode=parse_mode
) )
# Not all are images, so send all the files one by one # Not all are images, so send all the files one by one
return [ return [
self.send_file( await self.send_file(
entity, x, allow_cache=False, entity, x, allow_cache=False,
caption=caption, force_document=force_document, caption=caption, force_document=force_document,
progress_callback=progress_callback, reply_to=reply_to, progress_callback=progress_callback, reply_to=reply_to,
@ -1216,7 +1215,7 @@ class TelegramClient(TelegramBareClient):
) for x in file ) for x in file
] ]
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
reply_to = self._get_message_id(reply_to) reply_to = self._get_message_id(reply_to)
caption, msg_entities = self._parse_message_text(caption, parse_mode) caption, msg_entities = self._parse_message_text(caption, parse_mode)
@ -1233,11 +1232,11 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=reply_to, reply_to_msg_id=reply_to,
message=caption, message=caption,
entities=msg_entities) entities=msg_entities)
return self._get_response_message(request, self(request)) return self._get_response_message(request, await self(request))
as_image = utils.is_image(file) and not force_document as_image = utils.is_image(file) and not force_document
use_cache = InputPhoto if as_image else InputDocument use_cache = InputPhoto if as_image else InputDocument
file_handle = self.upload_file( file_handle = await self.upload_file(
file, progress_callback=progress_callback, file, progress_callback=progress_callback,
use_cache=use_cache if allow_cache else None use_cache=use_cache if allow_cache else None
) )
@ -1314,7 +1313,7 @@ class TelegramClient(TelegramBareClient):
input_kw = {} input_kw = {}
if thumb: if thumb:
input_kw['thumb'] = self.upload_file(thumb) input_kw['thumb'] = await self.upload_file(thumb)
media = InputMediaUploadedDocument( media = InputMediaUploadedDocument(
file=file_handle, file=file_handle,
@ -1327,7 +1326,7 @@ class TelegramClient(TelegramBareClient):
# send the media message to the desired entity. # send the media message to the desired entity.
request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to, request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to,
message=caption, entities=msg_entities) message=caption, entities=msg_entities)
msg = self._get_response_message(request, self(request)) msg = self._get_response_message(request, await self(request))
if msg and isinstance(file_handle, InputSizedFile): if msg and isinstance(file_handle, InputSizedFile):
# There was a response message and we didn't use cached # There was a response message and we didn't use cached
# version, so cache whatever we just sent to the database. # version, so cache whatever we just sent to the database.
@ -1345,7 +1344,7 @@ class TelegramClient(TelegramBareClient):
kwargs['is_voice_note'] = True kwargs['is_voice_note'] = True
return self.send_file(*args, **kwargs) return self.send_file(*args, **kwargs)
def _send_album(self, entity, files, caption=None, async def _send_album(self, entity, files, caption=None,
progress_callback=None, reply_to=None, progress_callback=None, reply_to=None,
parse_mode='md'): parse_mode='md'):
"""Specialized version of .send_file for albums""" """Specialized version of .send_file for albums"""
@ -1354,7 +1353,7 @@ class TelegramClient(TelegramBareClient):
# we need to produce right now to send albums (uploadMedia), and # we need to produce right now to send albums (uploadMedia), and
# cache only makes a difference for documents where the user may # cache only makes a difference for documents where the user may
# want the attributes used on them to change. # want the attributes used on them to change.
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
if not utils.is_list_like(caption): if not utils.is_list_like(caption):
caption = (caption,) caption = (caption,)
captions = [ captions = [
@ -1367,11 +1366,11 @@ class TelegramClient(TelegramBareClient):
media = [] media = []
for file in files: for file in files:
# fh will either be InputPhoto or a modified InputFile # fh will either be InputPhoto or a modified InputFile
fh = self.upload_file(file, use_cache=InputPhoto) fh = await self.upload_file(file, use_cache=InputPhoto)
if not isinstance(fh, InputPhoto): if not isinstance(fh, InputPhoto):
input_photo = utils.get_input_photo(self(UploadMediaRequest( input_photo = utils.get_input_photo((await self(UploadMediaRequest(
entity, media=InputMediaUploadedPhoto(fh) entity, media=InputMediaUploadedPhoto(fh)
)).photo) ))).photo)
self.session.cache_file(fh.md5, fh.size, input_photo) self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo fh = input_photo
@ -1383,7 +1382,7 @@ class TelegramClient(TelegramBareClient):
entities=msg_entities)) entities=msg_entities))
# Now we can construct the multi-media request # Now we can construct the multi-media request
result = self(SendMultiMediaRequest( result = await self(SendMultiMediaRequest(
entity, reply_to_msg_id=reply_to, multi_media=media entity, reply_to_msg_id=reply_to, multi_media=media
)) ))
return [ return [
@ -1392,12 +1391,8 @@ class TelegramClient(TelegramBareClient):
if isinstance(update, UpdateMessageID) if isinstance(update, UpdateMessageID)
] ]
def upload_file(self, async def upload_file(self, file, part_size_kb=None, file_name=None,
file, use_cache=None, progress_callback=None):
part_size_kb=None,
file_name=None,
use_cache=None,
progress_callback=None):
""" """
Uploads the specified file and returns a handle (an instance of Uploads the specified file and returns a handle (an instance of
InputFile or InputFileBig, as required) which can be later used InputFile or InputFileBig, as required) which can be later used
@ -1510,7 +1505,7 @@ class TelegramClient(TelegramBareClient):
else: else:
request = SaveFilePartRequest(file_id, part_index, part) request = SaveFilePartRequest(file_id, part_index, part)
result = self(request) result = await self(request)
if result: if result:
__log__.debug('Uploaded %d/%d', part_index + 1, __log__.debug('Uploaded %d/%d', part_index + 1,
part_count) part_count)
@ -1531,7 +1526,7 @@ class TelegramClient(TelegramBareClient):
# region Downloading media requests # region Downloading media requests
def download_profile_photo(self, entity, file=None, download_big=True): async def download_profile_photo(self, entity, file=None, download_big=True):
""" """
Downloads the profile photo of the given entity (user/chat/channel). Downloads the profile photo of the given entity (user/chat/channel).
@ -1565,12 +1560,12 @@ class TelegramClient(TelegramBareClient):
# The hexadecimal numbers above are simply: # The hexadecimal numbers above are simply:
# hex(crc32(x.encode('ascii'))) for x in # hex(crc32(x.encode('ascii'))) for x in
# ('User', 'Chat', 'UserFull', 'ChatFull') # ('User', 'Chat', 'UserFull', 'ChatFull')
entity = self.get_entity(entity) entity = await self.get_entity(entity)
if not hasattr(entity, 'photo'): if not hasattr(entity, 'photo'):
# Special case: may be a ChatFull with photo:Photo # Special case: may be a ChatFull with photo:Photo
# This is different from a normal UserProfilePhoto and Chat # This is different from a normal UserProfilePhoto and Chat
if hasattr(entity, 'chat_photo'): if hasattr(entity, 'chat_photo'):
return self._download_photo( return await self._download_photo(
entity.chat_photo, file, entity.chat_photo, file,
date=None, progress_callback=None date=None, progress_callback=None
) )
@ -1595,7 +1590,7 @@ class TelegramClient(TelegramBareClient):
# Download the media with the largest size input file location # Download the media with the largest size input file location
try: try:
self.download_file( await self.download_file(
InputFileLocation( InputFileLocation(
volume_id=photo_location.volume_id, volume_id=photo_location.volume_id,
local_id=photo_location.local_id, local_id=photo_location.local_id,
@ -1606,10 +1601,10 @@ class TelegramClient(TelegramBareClient):
except LocationInvalidError: except LocationInvalidError:
# See issue #500, Android app fails as of v4.6.0 (1155). # See issue #500, Android app fails as of v4.6.0 (1155).
# The fix seems to be using the full channel chat photo. # The fix seems to be using the full channel chat photo.
ie = self.get_input_entity(entity) ie = await self.get_input_entity(entity)
if isinstance(ie, InputPeerChannel): if isinstance(ie, InputPeerChannel):
full = self(GetFullChannelRequest(ie)) full = await self(GetFullChannelRequest(ie))
return self._download_photo( return await self._download_photo(
full.full_chat.chat_photo, file, full.full_chat.chat_photo, file,
date=None, progress_callback=None date=None, progress_callback=None
) )
@ -1618,7 +1613,7 @@ class TelegramClient(TelegramBareClient):
return None return None
return file return file
def download_media(self, message, file=None, progress_callback=None): async def download_media(self, message, file=None, progress_callback=None):
""" """
Downloads the given media, or the media from a specified Message. Downloads the given media, or the media from a specified Message.
@ -1646,19 +1641,19 @@ class TelegramClient(TelegramBareClient):
media = message media = message
if isinstance(media, (MessageMediaPhoto, Photo)): if isinstance(media, (MessageMediaPhoto, Photo)):
return self._download_photo( return await self._download_photo(
media, file, date, progress_callback media, file, date, progress_callback
) )
elif isinstance(media, (MessageMediaDocument, Document)): elif isinstance(media, (MessageMediaDocument, Document)):
return self._download_document( return await self._download_document(
media, file, date, progress_callback media, file, date, progress_callback
) )
elif isinstance(media, MessageMediaContact): elif isinstance(media, MessageMediaContact):
return self._download_contact( return await self._download_contact(
media, file media, file
) )
def _download_photo(self, photo, file, date, progress_callback): async def _download_photo(self, photo, file, date, progress_callback):
"""Specialized version of .download_media() for photos""" """Specialized version of .download_media() for photos"""
# Determine the photo and its largest size # Determine the photo and its largest size
if isinstance(photo, MessageMediaPhoto): if isinstance(photo, MessageMediaPhoto):
@ -1673,7 +1668,7 @@ class TelegramClient(TelegramBareClient):
file = self._get_proper_filename(file, 'photo', '.jpg', date=date) file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
# Download the media with the largest size input file location # Download the media with the largest size input file location
self.download_file( await self.download_file(
InputFileLocation( InputFileLocation(
volume_id=largest_size.volume_id, volume_id=largest_size.volume_id,
local_id=largest_size.local_id, local_id=largest_size.local_id,
@ -1685,7 +1680,7 @@ class TelegramClient(TelegramBareClient):
) )
return file return file
def _download_document(self, document, file, date, progress_callback): async def _download_document(self, document, file, date, progress_callback):
"""Specialized version of .download_media() for documents""" """Specialized version of .download_media() for documents"""
if isinstance(document, MessageMediaDocument): if isinstance(document, MessageMediaDocument):
document = document.document document = document.document
@ -1718,7 +1713,7 @@ class TelegramClient(TelegramBareClient):
date=date, possible_names=possible_names date=date, possible_names=possible_names
) )
self.download_file( await self.download_file(
InputDocumentFileLocation( InputDocumentFileLocation(
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
@ -1825,12 +1820,8 @@ class TelegramClient(TelegramBareClient):
return result return result
i += 1 i += 1
def download_file(self, async def download_file(self, input_location, file, part_size_kb=None,
input_location, file_size=None, progress_callback=None):
file,
part_size_kb=None,
file_size=None,
progress_callback=None):
""" """
Downloads the given input location to a file. Downloads the given input location to a file.
@ -1889,23 +1880,24 @@ class TelegramClient(TelegramBareClient):
while True: while True:
try: try:
if cdn_decrypter: if cdn_decrypter:
result = cdn_decrypter.get_file() result = await cdn_decrypter.get_file()
else: else:
result = client(GetFileRequest( result = await client(GetFileRequest(
input_location, offset, part_size input_location, offset, part_size
)) ))
if isinstance(result, FileCdnRedirect): if isinstance(result, FileCdnRedirect):
__log__.info('File lives in a CDN') __log__.info('File lives in a CDN')
cdn_decrypter, result = \ cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter( await CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result), client,
await self._get_cdn_client(result),
result result
) )
except FileMigrateError as e: except FileMigrateError as e:
__log__.info('File lives in another DC') __log__.info('File lives in another DC')
client = self._get_exported_client(e.new_dc) client = await self._get_exported_client(e.new_dc)
continue continue
offset += part_size offset += part_size
@ -1947,25 +1939,25 @@ class TelegramClient(TelegramBareClient):
The event builder class or instance to be used, The event builder class or instance to be used,
for instance ``events.NewMessage``. for instance ``events.NewMessage``.
""" """
def decorator(f): async def decorator(f):
self.add_event_handler(f, event) await self.add_event_handler(f, event)
return f return f
return decorator return decorator
def _check_events_pending_resolve(self): async def _check_events_pending_resolve(self):
if self._events_pending_resolve: if self._events_pending_resolve:
for event in self._events_pending_resolve: for event in self._events_pending_resolve:
event.resolve(self) await event.resolve(self)
self._events_pending_resolve.clear() self._events_pending_resolve.clear()
def _on_handler(self, update): async def _on_handler(self, update):
for builder, callback in self._event_builders: for builder, callback in self._event_builders:
event = builder.build(update) event = builder.build(update)
if event: if event:
event._client = self event._client = self
try: try:
callback(event) await callback(event)
except events.StopPropagation: except events.StopPropagation:
__log__.debug( __log__.debug(
"Event handler '{}' stopped chain of " "Event handler '{}' stopped chain of "
@ -1974,7 +1966,7 @@ class TelegramClient(TelegramBareClient):
) )
break break
def add_event_handler(self, callback, event=None): async def add_event_handler(self, callback, event=None):
""" """
Registers the given callback to be called on the specified event. Registers the given callback to be called on the specified event.
@ -1989,12 +1981,6 @@ class TelegramClient(TelegramBareClient):
If left unspecified, ``events.Raw`` (the ``Update`` objects If left unspecified, ``events.Raw`` (the ``Update`` objects
with no further processing) will be passed instead. with no further processing) will be passed instead.
""" """
if self.updates.workers is None:
warnings.warn(
"You have not setup any workers, so you won't receive updates."
" Pass update_workers=1 when creating the TelegramClient,"
" or set client.self.updates.workers = 1"
)
self.updates.handler = self._on_handler self.updates.handler = self._on_handler
if isinstance(event, type): if isinstance(event, type):
@ -2003,8 +1989,8 @@ class TelegramClient(TelegramBareClient):
event = events.Raw() event = events.Raw()
if self.is_user_authorized(): if self.is_user_authorized():
event.resolve(self) await event.resolve(self)
self._check_events_pending_resolve() await self._check_events_pending_resolve()
else: else:
self._events_pending_resolve.append(event) self._events_pending_resolve.append(event)
@ -2031,11 +2017,11 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier # region Small utilities to make users' life easier
def _set_connected_and_authorized(self): async def _set_connected_and_authorized(self):
super()._set_connected_and_authorized() await super()._set_connected_and_authorized()
self._check_events_pending_resolve() await self._check_events_pending_resolve()
def get_entity(self, entity): async def get_entity(self, entity):
""" """
Turns the given entity into a valid Telegram user or chat. Turns the given entity into a valid Telegram user or chat.
@ -2069,7 +2055,7 @@ class TelegramClient(TelegramBareClient):
# input channels (get channels) to get the most entities # input channels (get channels) to get the most entities
# in the less amount of calls possible. # in the less amount of calls possible.
inputs = [ inputs = [
x if isinstance(x, str) else self.get_input_entity(x) x if isinstance(x, str) else await self.get_input_entity(x)
for x in entity for x in entity
] ]
users = [x for x in inputs if isinstance(x, InputPeerUser)] users = [x for x in inputs if isinstance(x, InputPeerUser)]
@ -2080,12 +2066,12 @@ class TelegramClient(TelegramBareClient):
tmp = [] tmp = []
while users: while users:
curr, users = users[:200], users[200:] curr, users = users[:200], users[200:]
tmp.extend(self(GetUsersRequest(curr))) tmp.extend(await self(GetUsersRequest(curr)))
users = tmp users = tmp
if chats: # TODO Handle chats slice? if chats: # TODO Handle chats slice?
chats = self(GetChatsRequest(chats)).chats chats = (await self(GetChatsRequest(chats))).chats
if channels: if channels:
channels = self(GetChannelsRequest(channels)).chats channels = (await self(GetChannelsRequest(channels))).chats
# Merge users, chats and channels into a single dictionary # Merge users, chats and channels into a single dictionary
id_entity = { id_entity = {
@ -2098,33 +2084,31 @@ class TelegramClient(TelegramBareClient):
# the amount of ResolveUsername calls, it would fail to catch # the amount of ResolveUsername calls, it would fail to catch
# username changes. # username changes.
result = [ result = [
self._get_entity_from_string(x) if isinstance(x, str) await self._get_entity_from_string(x) if isinstance(x, str)
else id_entity[utils.get_peer_id(x)] else id_entity[utils.get_peer_id(x)]
for x in inputs for x in inputs
] ]
return result[0] if single else result return result[0] if single else result
def _get_entity_from_string(self, string): async def _get_entity_from_string(self, string):
""" """
Gets a full entity from the given string, which may be a phone or Gets a full entity from the given string, which may be a phone or
an username, and processes all the found entities on the session. an username, and processes all the found entities on the session.
The string may also be a user link, or a channel/chat invite link. The string may also be a user link, or a channel/chat invite link.
This method has the side effect of adding the found users to the This method has the side effect of adding the found users to the
session database, so it can be queried later without API calls, session database, so it can be queried later without API calls,
if this option is enabled on the session. if this option is enabled on the session.
Returns the found entity, or raises TypeError if not found. Returns the found entity, or raises TypeError if not found.
""" """
phone = utils.parse_phone(string) phone = utils.parse_phone(string)
if phone: if phone:
for user in self(GetContactsRequest(0)).users: for user in (await self(GetContactsRequest(0))).users:
if user.phone == phone: if user.phone == phone:
return user return user
else: else:
username, is_join_chat = utils.parse_username(string) username, is_join_chat = utils.parse_username(string)
if is_join_chat: if is_join_chat:
invite = self(CheckChatInviteRequest(username)) invite = await self(CheckChatInviteRequest(username))
if isinstance(invite, ChatInvite): if isinstance(invite, ChatInvite):
raise ValueError( raise ValueError(
'Cannot get entity from a channel ' 'Cannot get entity from a channel '
@ -2134,14 +2118,15 @@ class TelegramClient(TelegramBareClient):
return invite.chat return invite.chat
elif username: elif username:
if username in ('me', 'self'): if username in ('me', 'self'):
return self.get_me() return await self.get_me()
result = self(ResolveUsernameRequest(username)) result = await self(ResolveUsernameRequest(username))
for entity in itertools.chain(result.users, result.chats): for entity in itertools.chain(result.users, result.chats):
if entity.username.lower() == username: if entity.username.lower() == username:
return entity return entity
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return self.get_entity(self.session.get_input_entity(string)) return await self.get_entity(
self.session.get_input_entity(string))
except ValueError: except ValueError:
pass pass
@ -2149,24 +2134,20 @@ class TelegramClient(TelegramBareClient):
'Cannot turn "{}" into any entity (user or chat)'.format(string) 'Cannot turn "{}" into any entity (user or chat)'.format(string)
) )
def get_input_entity(self, peer): async def get_input_entity(self, peer):
""" """
Turns the given peer into its input entity version. Most requests Turns the given peer into its input entity version. Most requests
use this kind of InputUser, InputChat and so on, so this is the use this kind of InputUser, InputChat and so on, so this is the
most suitable call to make for those cases. most suitable call to make for those cases.
entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`): entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`):
The integer ID of an user or otherwise either of a The integer ID of an user or otherwise either of a
``PeerUser``, ``PeerChat`` or ``PeerChannel``, for ``PeerUser``, ``PeerChat`` or ``PeerChannel``, for
which to get its ``Input*`` version. which to get its ``Input*`` version.
If this ``Peer`` hasn't been seen before by the library, the top If this ``Peer`` hasn't been seen before by the library, the top
dialogs will be loaded and their entities saved to the session dialogs will be loaded and their entities saved to the session
file (unless this feature was disabled explicitly). file (unless this feature was disabled explicitly).
If in the end the access hash required for the peer was not found, If in the end the access hash required for the peer was not found,
a ValueError will be raised. a ValueError will be raised.
Returns: Returns:
``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``. ``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``.
""" """
@ -2179,7 +2160,7 @@ class TelegramClient(TelegramBareClient):
if isinstance(peer, str): if isinstance(peer, str):
if peer in ('me', 'self'): if peer in ('me', 'self'):
return InputPeerSelf() return InputPeerSelf()
return utils.get_input_peer(self._get_entity_from_string(peer)) return utils.get_input_peer(await self._get_entity_from_string(peer))
if isinstance(peer, int): if isinstance(peer, int):
peer, kind = utils.resolve_id(peer) peer, kind = utils.resolve_id(peer)
@ -2206,7 +2187,7 @@ class TelegramClient(TelegramBareClient):
limit=100 limit=100
) )
while True: while True:
result = self(req) result = await self(req)
entities = {} entities = {}
for x in itertools.chain(result.users, result.chats): for x in itertools.chain(result.users, result.chats):
x_id = utils.get_peer_id(x) x_id = utils.get_peer_id(x)
@ -2222,7 +2203,7 @@ class TelegramClient(TelegramBareClient):
req.offset_peer = entities[utils.get_peer_id( req.offset_peer = entities[utils.get_peer_id(
result.dialogs[-1].peer result.dialogs[-1].peer
)] )]
time.sleep(1) asyncio.sleep(1)
raise TypeError( raise TypeError(
'Could not find the input entity corresponding to "{}". ' 'Could not find the input entity corresponding to "{}". '

View File

@ -26,9 +26,9 @@ class Dialog:
self.draft = Draft(client, dialog.peer, dialog.draft) self.draft = Draft(client, dialog.peer, dialog.draft)
def send_message(self, *args, **kwargs): async def send_message(self, *args, **kwargs):
""" """
Sends a message to this dialog. This is just a wrapper around Sends a message to this dialog. This is just a wrapper around
client.send_message(dialog.input_entity, *args, **kwargs). client.send_message(dialog.input_entity, *args, **kwargs).
""" """
return self._client.send_message(self.input_entity, *args, **kwargs) return await self._client.send_message(self.input_entity, *args, **kwargs)

View File

@ -31,14 +31,14 @@ class Draft:
return cls(client=client, peer=update.peer, draft=update.draft) return cls(client=client, peer=update.peer, draft=update.draft)
@property @property
def entity(self): async def entity(self):
return self._client.get_entity(self._peer) return await self._client.get_entity(self._peer)
@property @property
def input_entity(self): async def input_entity(self):
return self._client.get_input_entity(self._peer) return await self._client.get_input_entity(self._peer)
def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None): async def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
""" """
Changes the draft message on the Telegram servers. The changes are Changes the draft message on the Telegram servers. The changes are
reflected in this object. Changing only individual attributes like for reflected in this object. Changing only individual attributes like for
@ -58,7 +58,7 @@ class Draft:
:param list entities: A list of formatting entities :param list entities: A list of formatting entities
:return bool: ``True`` on success :return bool: ``True`` on success
""" """
result = self._client(SaveDraftRequest( result = await self._client(SaveDraftRequest(
peer=self._peer, peer=self._peer,
message=text, message=text,
no_webpage=no_webpage, no_webpage=no_webpage,
@ -74,9 +74,9 @@ class Draft:
return result return result
def delete(self): async def delete(self):
""" """
Deletes this draft Deletes this draft
:return bool: ``True`` on success :return bool: ``True`` on success
""" """
return self.set_message(text='') return await self.set_message(text='')

View File

@ -1,11 +1,10 @@
import struct import struct
from datetime import datetime, date from datetime import datetime, date
from threading import Event
class TLObject: class TLObject:
def __init__(self): def __init__(self):
self.confirm_received = Event() self.confirm_received = None
self.rpc_error = None self.rpc_error = None
self.result = None self.result = None
@ -157,7 +156,7 @@ class TLObject:
return TLObject.pretty_format(self, indent=0) return TLObject.pretty_format(self, indent=0)
# These should be overrode # These should be overrode
def resolve(self, client, utils): async def resolve(self, client, utils):
pass pass
def to_dict(self): def to_dict(self):

View File

@ -1,9 +1,8 @@
import logging import logging
import pickle import pickle
import asyncio
from collections import deque from collections import deque
from queue import Queue, Empty
from datetime import datetime from datetime import datetime
from threading import RLock, Thread
from .tl import types as tl from .tl import types as tl
@ -16,106 +15,21 @@ class UpdateState:
""" """
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None): def __init__(self, loop=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke the self.handler.
"""
self._workers = workers
self._worker_threads = []
self.handler = None self.handler = None
self._updates_lock = RLock() self._loop = loop if loop else asyncio.get_event_loop()
self._updates = Queue()
# https://core.telegram.org/api/updates # https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self): def handle_update(self, update):
"""Returns True if a call to .poll() won't lock""" if self.handler:
return not self._updates.empty() asyncio.ensure_future(self.handler(update), loop=self._loop)
def poll(self, timeout=None):
"""
Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
try:
return self._updates.get(timeout=timeout)
except Empty:
return None
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
if n is None:
self.stop_workers()
else:
self._workers = n
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""
Waits for all the worker threads to stop.
"""
# Put dummy ``None`` objects so that they don't need to timeout.
n = self._workers
self._workers = None
if n:
with self._updates_lock:
for _ in range(n):
self._updates.put(None)
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while self._workers is not None:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
if update and self.handler:
self.handler(update)
except StopIteration:
break
except:
# We don't want to crash a worker thread due to any reason
__log__.exception('Unhandled exception on worker %d', wid)
def process(self, update): def process(self, update):
"""Processes an update object. This method is normally called by """Processes an update object. This method is normally called by
the library itself. the library itself.
""" """
if self._workers is None:
return # No processing needs to be done if nobody's working
with self._updates_lock:
if isinstance(update, tl.updates.State): if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state') __log__.debug('Saved new updates state')
self._state = update self._state = update
@ -128,13 +42,13 @@ class UpdateState:
# 1000 updates, the only duplicates received were users going # 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports. # online or offline. We can trust the server until new reports.
if isinstance(update, tl.UpdateShort): if isinstance(update, tl.UpdateShort):
self._updates.put(update.update) self.handle_update(update.update)
# Expand "Updates" into "Update", and pass these to callbacks. # Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we # Since .users and .chats have already been processed, we
# don't need to care about those either. # don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
for u in update.updates: for u in update.updates:
self._updates.put(u) self.handle_update(u)
# TODO Handle "tl.UpdatesTooLong" # TODO Handle "tl.UpdatesTooLong"
else: else:
self._updates.put(update) self.handle_update(update)

View File

@ -11,9 +11,9 @@ AUTO_GEN_NOTICE = \
AUTO_CASTS = { AUTO_CASTS = {
'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))', 'InputPeer': 'utils.get_input_peer(await client.get_input_entity({}))',
'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))', 'InputChannel': 'utils.get_input_channel(await client.get_input_entity({}))',
'InputUser': 'utils.get_input_user(client.get_input_entity({}))', 'InputUser': 'utils.get_input_user(await client.get_input_entity({}))',
'InputMedia': 'utils.get_input_media({})', 'InputMedia': 'utils.get_input_media({})',
'InputPhoto': 'utils.get_input_photo({})' 'InputPhoto': 'utils.get_input_photo({})'
} }
@ -289,7 +289,7 @@ class TLGenerator:
# Write the resolve(self, client, utils) method # Write the resolve(self, client, utils) method
if any(arg.type in AUTO_CASTS for arg in args): if any(arg.type in AUTO_CASTS for arg in args):
builder.writeln('def resolve(self, client, utils):') builder.writeln('async def resolve(self, client, utils):')
for arg in args: for arg in args:
ac = AUTO_CASTS.get(arg.type, None) ac = AUTO_CASTS.get(arg.type, None)
if ac: if ac: