Merge remote-tracking branch 'tulir/asyncio' into asyncio

This commit is contained in:
Lonami Exo 2018-03-03 17:03:27 +01:00
commit 563d731c95
17 changed files with 644 additions and 785 deletions

View File

@ -30,6 +30,7 @@ Creating a client
.. code:: python
import asyncio
from telethon import TelegramClient
# These example values won't work. You must get your own api_id and
@ -38,22 +39,28 @@ Creating a client
api_hash = '0123456789abcdef0123456789abcdef'
client = TelegramClient('session_name', api_id, api_hash)
client.start()
async def main():
await client.start()
asyncio.get_event_loop().run_until_complete(main())
Doing stuff
-----------
Note that this assumes you're inside an "async def" method. Check out the
`Python documentation <https://docs.python.org/3/library/asyncio-dev.html>`_
if you're new with ``asyncio``.
.. code:: python
print(client.get_me().stringify())
client.send_message('username', 'Hello! Talking to you from Telethon')
client.send_file('username', '/home/myself/Pictures/holidays.jpg')
await client.send_message('username', 'Hello! Talking to you from Telethon')
await client.send_file('username', '/home/myself/Pictures/holidays.jpg')
client.download_profile_photo('me')
messages = client.get_message_history('username')
client.download_media(messages[0])
await client.download_profile_photo('me')
messages = await client.get_message_history('username')
await client.download_media(messages[0])
Next steps
@ -61,5 +68,7 @@ Next steps
Do you like how Telethon looks? Check out
`Read The Docs <http://telethon.rtfd.io/>`_
for a more in-depth explanation, with examples,
troubleshooting issues, and more useful information.
for a more in-depth explanation, with examples, troubleshooting issues,
and more useful information. Note that the examples there are written for
the threaded version, not the one using asyncio. However, you just need to
await every remote call.

View File

@ -30,7 +30,7 @@ class CdnDecrypter:
self.cdn_file_hashes = cdn_file_hashes
@staticmethod
def prepare_decrypter(client, cdn_client, cdn_redirect):
async def prepare_decrypter(client, cdn_client, cdn_redirect):
"""
Prepares a new CDN decrypter.
@ -52,14 +52,14 @@ class CdnDecrypter:
cdn_aes, cdn_redirect.cdn_file_hashes
)
cdn_file = cdn_client(GetCdnFileRequest(
cdn_file = await cdn_client(GetCdnFileRequest(
file_token=cdn_redirect.file_token,
offset=cdn_redirect.cdn_file_hashes[0].offset,
limit=cdn_redirect.cdn_file_hashes[0].limit
))
if isinstance(cdn_file, CdnFileReuploadNeeded):
# We need to use the original client here
client(ReuploadCdnFileRequest(
await client(ReuploadCdnFileRequest(
file_token=cdn_redirect.file_token,
request_token=cdn_file.request_token
))
@ -73,7 +73,7 @@ class CdnDecrypter:
return decrypter, cdn_file
def get_file(self):
async def get_file(self):
"""
Calls GetCdnFileRequest and decrypts its bytes.
Also ensures that the file hasn't been tampered.
@ -82,7 +82,7 @@ class CdnDecrypter:
"""
if self.cdn_file_hashes:
cdn_hash = self.cdn_file_hashes.pop(0)
cdn_file = self.client(GetCdnFileRequest(
cdn_file = await self.client(GetCdnFileRequest(
self.file_token, cdn_hash.offset, cdn_hash.limit
))
cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes)

View File

@ -9,7 +9,7 @@ from ..extensions import markdown
from ..tl import types, functions
def _into_id_set(client, chats):
async def _into_id_set(client, chats):
"""Helper util to turn the input chat or chats into a set of IDs."""
if chats is None:
return None
@ -19,9 +19,9 @@ def _into_id_set(client, chats):
result = set()
for chat in chats:
chat = client.get_input_entity(chat)
chat = await client.get_input_entity(chat)
if isinstance(chat, types.InputPeerSelf):
chat = client.get_me(input_peer=True)
chat = await client.get_me(input_peer=True)
result.add(utils.get_peer_id(chat))
return result
@ -48,10 +48,10 @@ class _EventBuilder(abc.ABC):
def build(self, update):
"""Builds an event for the given update if possible, or returns None"""
def resolve(self, client):
async def resolve(self, client):
"""Helper method to allow event builders to be resolved before usage"""
self.chats = _into_id_set(client, self.chats)
self._self_id = client.get_me(input_peer=True).user_id
self.chats = await _into_id_set(client, self.chats)
self._self_id = (await client.get_me(input_peer=True)).user_id
def _filter_event(self, event):
"""
@ -86,7 +86,7 @@ class _EventCommon(abc.ABC):
)
self.is_channel = isinstance(chat_peer, types.PeerChannel)
def _get_input_entity(self, msg_id, entity_id, chat=None):
async def _get_input_entity(self, msg_id, entity_id, chat=None):
"""
Helper function to call GetMessages on the give msg_id and
return the input entity whose ID is the given entity ID.
@ -95,11 +95,11 @@ class _EventCommon(abc.ABC):
"""
try:
if isinstance(chat, types.InputPeerChannel):
result = self._client(
result = await self._client(
functions.channels.GetMessagesRequest(chat, [msg_id])
)
else:
result = self._client(
result = await self._client(
functions.messages.GetMessagesRequest([msg_id])
)
except RPCError:
@ -113,7 +113,7 @@ class _EventCommon(abc.ABC):
return utils.get_input_peer(entity)
@property
def input_chat(self):
async def input_chat(self):
"""
The (:obj:`InputPeer`) (group, megagroup or channel) on which
the event occurred. This doesn't have the title or anything,
@ -125,7 +125,7 @@ class _EventCommon(abc.ABC):
if self._input_chat is None and self._chat_peer is not None:
try:
self._input_chat = self._client.get_input_entity(
self._input_chat = await self._client.get_input_entity(
self._chat_peer
)
except (ValueError, TypeError):
@ -134,22 +134,22 @@ class _EventCommon(abc.ABC):
# TODO For channels, getDifference? Maybe looking
# in the dialogs (which is already done) is enough.
if self._message_id is not None:
self._input_chat = self._get_input_entity(
self._input_chat = await self._get_input_entity(
self._message_id,
utils.get_peer_id(self._chat_peer)
)
return self._input_chat
@property
def chat(self):
async def chat(self):
"""
The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which
the event occurred. This property will make an API call the first time
to get the most up to date version of the chat, so use with care as
there is no caching besides local caching yet.
"""
if self._chat is None and self.input_chat:
self._chat = self._client.get_entity(self._input_chat)
if self._chat is None and await self.input_chat:
self._chat = await self._client.get_entity(await self._input_chat)
return self._chat
@ -157,7 +157,7 @@ class Raw(_EventBuilder):
"""
Represents a raw event. The event is the update itself.
"""
def resolve(self, client):
async def resolve(self, client):
pass
def build(self, update):
@ -304,23 +304,23 @@ class NewMessage(_EventBuilder):
self.is_reply = bool(message.reply_to_msg_id)
self._reply_message = None
def respond(self, *args, **kwargs):
async def respond(self, *args, **kwargs):
"""
Responds to the message (not as a reply). This is a shorthand for
``client.send_message(event.chat, ...)``.
"""
return self._client.send_message(self.input_chat, *args, **kwargs)
return await self._client.send_message(await self.input_chat, *args, **kwargs)
def reply(self, *args, **kwargs):
async def reply(self, *args, **kwargs):
"""
Replies to the message (as a reply). This is a shorthand for
``client.send_message(event.chat, ..., reply_to=event.message.id)``.
"""
return self._client.send_message(self.input_chat,
reply_to=self.message.id,
*args, **kwargs)
return await self._client.send_message(await self.input_chat,
reply_to=self.message.id,
*args, **kwargs)
def edit(self, *args, **kwargs):
async def edit(self, *args, **kwargs):
"""
Edits the message iff it's outgoing. This is a shorthand for
``client.edit_message(event.chat, event.message, ...)``.
@ -331,27 +331,27 @@ class NewMessage(_EventBuilder):
if not self.message.out:
if not isinstance(self.message.to_id, types.PeerUser):
return None
me = self._client.get_me(input_peer=True)
me = await self._client.get_me(input_peer=True)
if self.message.to_id.user_id != me.user_id:
return None
return self._client.edit_message(self.input_chat,
self.message,
*args, **kwargs)
return await self._client.edit_message(await self.input_chat,
self.message,
*args, **kwargs)
def delete(self, *args, **kwargs):
async def delete(self, *args, **kwargs):
"""
Deletes the message. You're responsible for checking whether you
have the permission to do so, or to except the error otherwise.
This is a shorthand for
``client.delete_messages(event.chat, event.message, ...)``.
"""
return self._client.delete_messages(self.input_chat,
[self.message],
*args, **kwargs)
return await self._client.delete_messages(await self.input_chat,
[self.message],
*args, **kwargs)
@property
def input_sender(self):
async def input_sender(self):
"""
This (:obj:`InputPeer`) is the input version of the user who
sent the message. Similarly to ``input_chat``, this doesn't have
@ -365,21 +365,21 @@ class NewMessage(_EventBuilder):
return None
try:
self._input_sender = self._client.get_input_entity(
self._input_sender = await self._client.get_input_entity(
self.message.from_id
)
except (ValueError, TypeError):
# We can rely on self.input_chat for this
self._input_sender = self._get_input_entity(
self._input_sender = await self._get_input_entity(
self.message.id,
self.message.from_id,
chat=self.input_chat
chat=await self.input_chat
)
return self._input_sender
@property
def sender(self):
async def sender(self):
"""
This (:obj:`User`) will make an API call the first time to get
the most up to date version of the sender, so use with care as
@ -387,8 +387,8 @@ class NewMessage(_EventBuilder):
``input_sender`` needs to be available (often the case).
"""
if self._sender is None and self.input_sender:
self._sender = self._client.get_entity(self._input_sender)
if self._sender is None and await self.input_sender:
self._sender = await self._client.get_entity(self._input_sender)
return self._sender
@property
@ -411,7 +411,7 @@ class NewMessage(_EventBuilder):
return self.message.message
@property
def reply_message(self):
async def reply_message(self):
"""
This (:obj:`Message`, optional) will make an API call the first
time to get the full ``Message`` object that one was replying to,
@ -421,12 +421,12 @@ class NewMessage(_EventBuilder):
return None
if self._reply_message is None:
if isinstance(self.input_chat, types.InputPeerChannel):
r = self._client(functions.channels.GetMessagesRequest(
self.input_chat, [self.message.reply_to_msg_id]
if isinstance(await self.input_chat, types.InputPeerChannel):
r = await self._client(functions.channels.GetMessagesRequest(
await self.input_chat, [self.message.reply_to_msg_id]
))
else:
r = self._client(functions.messages.GetMessagesRequest(
r = await self._client(functions.messages.GetMessagesRequest(
[self.message.reply_to_msg_id]
))
if not isinstance(r, types.messages.MessagesNotModified):
@ -610,7 +610,7 @@ class ChatAction(_EventBuilder):
self.new_title = new_title
@property
def pinned_message(self):
async def pinned_message(self):
"""
If ``new_pin`` is ``True``, this returns the (:obj:`Message`)
object that was pinned.
@ -618,8 +618,8 @@ class ChatAction(_EventBuilder):
if self._pinned_message == 0:
return None
if isinstance(self._pinned_message, int) and self.input_chat:
r = self._client(functions.channels.GetMessagesRequest(
if isinstance(self._pinned_message, int) and await self.input_chat:
r = await self._client(functions.channels.GetMessagesRequest(
self._input_chat, [self._pinned_message]
))
try:
@ -635,25 +635,25 @@ class ChatAction(_EventBuilder):
return self._pinned_message
@property
def added_by(self):
async def added_by(self):
"""
The user who added ``users``, if applicable (``None`` otherwise).
"""
if self._added_by and not isinstance(self._added_by, types.User):
self._added_by = self._client.get_entity(self._added_by)
self._added_by = await self._client.get_entity(self._added_by)
return self._added_by
@property
def kicked_by(self):
async def kicked_by(self):
"""
The user who kicked ``users``, if applicable (``None`` otherwise).
"""
if self._kicked_by and not isinstance(self._kicked_by, types.User):
self._kicked_by = self._client.get_entity(self._kicked_by)
self._kicked_by = await self._client.get_entity(self._kicked_by)
return self._kicked_by
@property
def user(self):
async def user(self):
"""
The single user that takes part in this action (e.g. joined).
@ -661,12 +661,12 @@ class ChatAction(_EventBuilder):
there is no user taking part.
"""
try:
return next(self.users)
return next(await self.users)
except (StopIteration, TypeError):
return None
@property
def users(self):
async def users(self):
"""
A list of users that take part in this action (e.g. joined).
@ -675,7 +675,7 @@ class ChatAction(_EventBuilder):
"""
if self._users is None and self._user_peers:
try:
self._users = self._client.get_entity(self._user_peers)
self._users = await self._client.get_entity(self._user_peers)
except (TypeError, ValueError):
self._users = []

View File

@ -194,9 +194,9 @@ def get_inner_text(text, entity):
"""
if isinstance(entity, TLObject):
entity = (entity,)
multiple = True
else:
multiple = False
else:
multiple = True
text = _add_surrogate(text)
result = []

View File

@ -1,13 +1,20 @@
"""
This module holds a rough implementation of the C# TCP client.
"""
# Python rough implementation of a C# TCP client
import asyncio
import errno
import logging
import socket
import time
from datetime import timedelta
from io import BytesIO, BufferedWriter
from threading import Lock
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
try:
import socks
@ -25,7 +32,7 @@ __log__ = logging.getLogger(__name__)
class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies."""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
"""
Initializes the TCP client.
@ -34,7 +41,7 @@ class TcpClient:
"""
self.proxy = proxy
self._socket = None
self._closing_lock = Lock()
self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
@ -54,9 +61,9 @@ class TcpClient:
else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout)
self._socket.setblocking(False)
def connect(self, ip, port):
async def connect(self, ip, port):
"""
Tries connecting forever to IP:port unless an OSError is raised.
@ -72,11 +79,15 @@ class TcpClient:
timeout = 1
while True:
try:
while not self._socket:
if not self._socket:
self._recreate_socket(mode)
self._socket.connect(address)
await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
except OSError as e:
__log__.info('OSError "%s" raised while connecting', e)
# Stop retrying to connect if proxy connection error occurred
@ -90,7 +101,7 @@ class TcpClient:
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
time.sleep(timeout)
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
else:
raise
@ -103,21 +114,16 @@ class TcpClient:
def close(self):
"""Closes the connection."""
if self._closing_lock.locked():
# Already closing, no need to close again (avoid None.close())
return
try:
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
with self._closing_lock:
try:
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
def write(self, data):
async def write(self, data):
"""
Writes (sends) the specified bytes to the connected peer.
@ -126,11 +132,13 @@ class TcpClient:
if self._socket is None:
self._raise_connection_reset(None)
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try:
self._socket.sendall(data)
except socket.timeout as e:
await asyncio.wait_for(
self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
__log__.debug('socket.timeout "%s" while writing data', e)
raise TimeoutError() from e
except ConnectionError as e:
@ -143,7 +151,7 @@ class TcpClient:
else:
raise
def read(self, size):
async def read(self, size):
"""
Reads (receives) a whole block of size bytes from the connected peer.
@ -153,13 +161,18 @@ class TcpClient:
if self._socket is None:
self._raise_connection_reset(None)
# TODO Remove the timeout from this method, always use previous one
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
while bytes_left != 0:
try:
partial = self._socket.recv(bytes_left)
except socket.timeout as e:
if self._socket is None:
self._raise_connection_reset()
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
# These are somewhat common if the server has nothing
# to send to us, so use a lower logging priority.
__log__.debug('socket.timeout "%s" while reading data', e)
@ -168,7 +181,7 @@ class TcpClient:
__log__.info('ConnectionError "%s" while reading data', e)
self._raise_connection_reset(e)
except OSError as e:
if e.errno != errno.EBADF and self._closing_lock.locked():
if e.errno != errno.EBADF:
# Ignore bad file descriptor while closing
__log__.info('OSError "%s" while reading data', e)
@ -190,5 +203,56 @@ class TcpClient:
def _raise_connection_reset(self, original):
"""Disconnects the client and raises ConnectionResetError."""
self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.')\
from original
raise ConnectionResetError('The server has closed the connection.') from original
# due to new https://github.com/python/cpython/pull/4386
def sock_recv(self, n):
fut = self._loop.create_future()
self._sock_recv(fut, None, n)
return fut
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled():
return
try:
data = self._socket.recv(n)
except (BlockingIOError, InterruptedError):
fd = self._socket.fileno()
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(data)
def sock_sendall(self, data):
fut = self._loop.create_future()
if data:
self._sock_sendall(fut, None, data)
else:
fut.set_result(None)
return fut
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled():
return
try:
n = self._socket.send(data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = self._socket.fileno()
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)

View File

@ -21,7 +21,7 @@ from ..tl.functions import (
)
def do_authentication(connection, retries=5):
async def do_authentication(connection, retries=5):
"""
Performs the authentication steps on the given connection.
Raises an error if all attempts fail.
@ -36,14 +36,14 @@ def do_authentication(connection, retries=5):
last_error = None
while retries:
try:
return _do_authentication(connection)
return await _do_authentication(connection)
except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e
retries -= 1
raise last_error
def _do_authentication(connection):
async def _do_authentication(connection):
"""
Executes the authentication process with the Telegram servers.
@ -56,8 +56,8 @@ def _do_authentication(connection):
req_pq_request = ReqPqMultiRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
)
sender.send(bytes(req_pq_request))
with BinaryReader(sender.receive()) as reader:
await sender.send(bytes(req_pq_request))
with BinaryReader(await sender.receive()) as reader:
req_pq_request.on_response(reader)
res_pq = req_pq_request.result
@ -104,10 +104,10 @@ def _do_authentication(connection):
public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text
)
sender.send(bytes(req_dh_params))
await sender.send(bytes(req_dh_params))
# Step 2 response: DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
req_dh_params.on_response(reader)
server_dh_params = req_dh_params.result
@ -174,10 +174,10 @@ def _do_authentication(connection):
server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted,
)
sender.send(bytes(set_client_dh))
await sender.send(bytes(set_client_dh))
# Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
set_client_dh.on_response(reader)
dh_gen = set_client_dh.result

View File

@ -2,18 +2,17 @@
This module holds both the Connection class and the ConnectionMode enum,
which specifies the protocol to be used by the Connection.
"""
import errno
import logging
import os
import struct
from datetime import timedelta
from zlib import crc32
from enum import Enum
import errno
from zlib import crc32
from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError
from ..extensions import TcpClient
__log__ = logging.getLogger(__name__)
@ -52,7 +51,7 @@ class Connection:
"""
def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)):
proxy=None, timeout=timedelta(seconds=5), loop=None):
"""
Initializes a new connection.
@ -65,7 +64,7 @@ class Connection:
self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout)
self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages
if mode == ConnectionMode.TCP_FULL:
@ -89,7 +88,7 @@ class Connection:
setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain)
def connect(self, ip, port):
async def connect(self, ip, port):
"""
Estabilishes a connection to IP:port.
@ -97,7 +96,7 @@ class Connection:
:param port: the port to connect to.
"""
try:
self.conn.connect(ip, port)
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
@ -106,17 +105,17 @@ class Connection:
self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef')
await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee')
await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation()
await self._setup_obfuscation()
def get_timeout(self):
"""Returns the timeout used by the connection."""
return self.conn.timeout
def _setup_obfuscation(self):
async def _setup_obfuscation(self):
"""
Sets up the obfuscated protocol.
"""
@ -144,7 +143,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random))
await self.conn.write(bytes(random))
def is_connected(self):
"""
@ -166,12 +165,12 @@ class Connection:
# region Receive message implementations
def recv(self):
async def recv(self):
"""Receives and unpacks a message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self):
async def _recv_tcp_full(self):
"""
Receives a message from the network,
internally encoded using the TCP full protocol.
@ -181,7 +180,10 @@ class Connection:
:return: the read message payload.
"""
packet_len_seq = self.read(8) # 4 and 4
# TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
# Sometimes Telegram seems to send a packet length of 0 (12)
@ -192,15 +194,15 @@ class Connection:
'reading data left:', packet_len)
while True:
try:
__log__.error(repr(self.read(1)))
__log__.error(repr(await self.read(1)))
except TimeoutError:
break
# Connection reset and hope it's fixed after
self.conn.close()
raise ConnectionResetError()
body = self.read(packet_len - 12)
checksum = struct.unpack('<I', self.read(4))[0]
body = await self.read(packet_len - 12)
checksum = struct.unpack('<I', await self.read(4))[0]
valid_checksum = crc32(packet_len_seq + body)
if checksum != valid_checksum:
@ -208,38 +210,38 @@ class Connection:
return body
def _recv_intermediate(self):
async def _recv_intermediate(self):
"""
Receives a message from the network,
internally encoded using the TCP intermediate protocol.
:return: the read message payload.
"""
return self.read(struct.unpack('<i', self.read(4))[0])
return await self.read(struct.unpack('<i', await self.read(4))[0])
def _recv_abridged(self):
async def _recv_abridged(self):
"""
Receives a message from the network,
internally encoded using the TCP abridged protocol.
:return: the read message payload.
"""
length = struct.unpack('<B', self.read(1))[0]
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:
length = struct.unpack('<i', self.read(3) + b'\0')[0]
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return self.read(length << 2)
return await self.read(length << 2)
# endregion
# region Send message implementations
def send(self, message):
async def send(self, message):
"""Encapsulates and sends the given message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message):
async def _send_tcp_full(self, message):
"""
Encapsulates and sends the given message payload
using the TCP full mode (length, sequence, message, crc32).
@ -252,18 +254,18 @@ class Connection:
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
await self.write(data + crc)
def _send_intermediate(self, message):
async def _send_intermediate(self, message):
"""
Encapsulates and sends the given message payload
using the TCP intermediate mode (length, message).
:param message: the message to be sent.
"""
self.write(struct.pack('<i', len(message)) + message)
await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message):
async def _send_abridged(self, message):
"""
Encapsulates and sends the given message payload
using the TCP abridged mode (short length, message).
@ -276,57 +278,55 @@ class Connection:
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
await self.write(length + message)
# endregion
# region Read implementations
def read(self, length):
async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length):
async def _read_plain(self, length):
"""
Reads data from the socket connection.
:param length: how many bytes should be read.
:return: a byte sequence with len(data) == length
"""
return self.conn.read(length)
return await self.conn.read(length)
def _read_obfuscated(self, length):
async def _read_obfuscated(self, length):
"""
Reads data and decrypts from the socket connection.
:param length: how many bytes should be read.
:return: the decrypted byte sequence with len(data) == length
"""
return self._aes_decrypt.encrypt(
self.conn.read(length)
)
return self._aes_decrypt.encrypt(await self.conn.read(length))
# endregion
# region Write implementations
def write(self, data):
async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data):
async def _write_plain(self, data):
"""
Writes the given data through the socket connection.
:param data: the data in bytes to be written.
"""
self.conn.write(data)
await self.conn.write(data)
def _write_obfuscated(self, data):
async def _write_obfuscated(self, data):
"""
Writes the given data through the socket connection,
using the obfuscated mode (AES encryption is applied on top).
:param data: the data in bytes to be written.
"""
self.conn.write(self._aes_encrypt.encrypt(data))
await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion

View File

@ -26,32 +26,32 @@ class MtProtoPlainSender:
self._last_msg_id = 0
self._connection = connection
def connect(self):
async def connect(self):
"""Connects to Telegram's servers."""
self._connection.connect()
await self._connection.connect()
def disconnect(self):
"""Disconnects from Telegram's servers."""
self._connection.close()
def send(self, data):
async def send(self, data):
"""
Sends a plain packet (auth_key_id = 0) containing the
given message body (data).
:param data: the data to be sent.
"""
self._connection.send(
await self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
)
def receive(self):
async def receive(self):
"""
Receives a plain packet from the network.
:return: the response body.
"""
body = self._connection.recv()
body = await self._connection.recv()
if body == b'l\xfe\xff\xff': # -404 little endian signed
# Broken authorization, must reset the auth key
raise BrokenAuthKeyError()

View File

@ -2,9 +2,10 @@
This module contains the class used to communicate with Telegram's servers
encrypting every packet, and relies on a valid AuthKey in the used Session.
"""
import asyncio
import gzip
import logging
from threading import Lock
from asyncio import Event
from .. import helpers as utils
from ..errors import (
@ -33,7 +34,7 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed.
"""
def __init__(self, session, connection):
def __init__(self, session, connection, loop=None):
"""
Initializes a new MTProto sender.
@ -42,22 +43,20 @@ class MtProtoSender:
port of the server, salt, ID, and AuthKey,
:param connection:
the Connection to be used.
:param loop:
the asyncio loop to be used, or the default one.
"""
self.session = session
self.connection = connection
# Message IDs that need confirmation
self._need_confirmation = set()
self._loop = loop if loop else asyncio.get_event_loop()
self._recv_lock = asyncio.Lock()
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
# Multithreading
self._send_lock = Lock()
def connect(self):
async def connect(self):
"""Connects to the server."""
self.connection.connect(self.session.server_address, self.session.port)
await self.connection.connect(self.session.server_address, self.session.port)
def is_connected(self):
"""
@ -70,18 +69,25 @@ class MtProtoSender:
def disconnect(self):
"""Disconnects from the server."""
self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending()
# region Send and receive
def send(self, *requests):
async def send(self, *requests):
"""
Sends the specified TLObject(s) (which must be requests),
and acknowledging any message which needed confirmation.
:param requests: the requests to be sent.
"""
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages})
@ -91,13 +97,6 @@ class MtProtoSender:
for m in messages
))
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
)
self._need_confirmation.clear()
if len(messages) == 1:
message = messages[0]
else:
@ -108,13 +107,13 @@ class MtProtoSender:
for m in messages:
m.container_msg_id = message.msg_id
self._send_message(message)
await self._send_message(message)
def _send_acknowledge(self, msg_id):
async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id."""
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
def receive(self, update_state):
async def receive(self, update_state):
"""
Receives a single message from the connected endpoint.
@ -130,7 +129,10 @@ class MtProtoSender:
Update and Updates objects.
"""
try:
body = self.connection.recv()
with await self._recv_lock:
# Receiving items is not an "atomic" operation since we
# need to read the length and then upcoming parts separated.
body = await self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
@ -147,20 +149,20 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion
# region Low level processing
def _send_message(self, message):
async def _send_message(self, message):
"""
Sends the given encrypted through the network.
:param message: the TLMessage to be sent.
"""
with self._send_lock:
self.connection.send(utils.pack_message(self.session, message))
await self.connection.send(utils.pack_message(self.session, message))
def _decode_msg(self, body):
"""
@ -178,7 +180,7 @@ class MtProtoSender:
with BinaryReader(body) as reader:
return utils.unpack_message(self.session, reader)
def _process_msg(self, msg_id, sequence, reader, state):
async def _process_msg(self, msg_id, sequence, reader, state):
"""
Processes the message read from the network inside reader.
@ -189,7 +191,6 @@ class MtProtoSender:
:return: true if the message was handled correctly, false otherwise.
"""
# TODO Check salt, session_id and sequence_number
self._need_confirmation.add(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
@ -197,15 +198,15 @@ class MtProtoSender:
# These are a bit of special case, not yet generated by the code gen
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
__log__.debug('Processing Remote Procedure Call result')
return self._handle_rpc_result(msg_id, sequence, reader)
return await self._handle_rpc_result(msg_id, sequence, reader)
if code == MessageContainer.CONSTRUCTOR_ID:
__log__.debug('Processing container result')
return self._handle_container(msg_id, sequence, reader, state)
return await self._handle_container(msg_id, sequence, reader, state)
if code == GzipPacked.CONSTRUCTOR_ID:
__log__.debug('Processing gzipped result')
return self._handle_gzip_packed(msg_id, sequence, reader, state)
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code not in tlobjects:
__log__.warning(
@ -218,22 +219,22 @@ class MtProtoSender:
__log__.debug('Processing %s result', type(obj).__name__)
if isinstance(obj, Pong):
return self._handle_pong(msg_id, sequence, obj)
return await self._handle_pong(msg_id, sequence, obj)
if isinstance(obj, BadServerSalt):
return self._handle_bad_server_salt(msg_id, sequence, obj)
return await self._handle_bad_server_salt(msg_id, sequence, obj)
if isinstance(obj, BadMsgNotification):
return self._handle_bad_msg_notification(msg_id, sequence, obj)
return await self._handle_bad_msg_notification(msg_id, sequence, obj)
if isinstance(obj, MsgDetailedInfo):
return self._handle_msg_detailed_info(msg_id, sequence, obj)
return await self._handle_msg_detailed_info(msg_id, sequence, obj)
if isinstance(obj, MsgNewDetailedInfo):
return self._handle_msg_new_detailed_info(msg_id, sequence, obj)
return await self._handle_msg_new_detailed_info(msg_id, sequence, obj)
if isinstance(obj, NewSessionCreated):
return self._handle_new_session_created(msg_id, sequence, obj)
return await self._handle_new_session_created(msg_id, sequence, obj)
if isinstance(obj, MsgsAck): # may handle the request we wanted
# Ignore every ack request *unless* when logging out, when it's
@ -310,7 +311,7 @@ class MtProtoSender:
r.request.confirm_received.set()
self._pending_receive.clear()
def _resend_request(self, msg_id):
async def _resend_request(self, msg_id):
"""
Re-sends the request that belongs to a certain msg_id. This may
also be the msg_id of a container if they were sent in one.
@ -319,12 +320,13 @@ class MtProtoSender:
"""
request = self._pop_request(msg_id)
if request:
return self.send(request)
await self.send(request)
return
requests = self._pop_requests_of_container(msg_id)
if requests:
return self.send(*requests)
await self.send(*requests)
def _handle_pong(self, msg_id, sequence, pong):
async def _handle_pong(self, msg_id, sequence, pong):
"""
Handles a Pong response.
@ -340,7 +342,7 @@ class MtProtoSender:
return True
def _handle_container(self, msg_id, sequence, reader, state):
async def _handle_container(self, msg_id, sequence, reader, state):
"""
Handles a MessageContainer response.
@ -355,7 +357,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not self._process_msg(inner_msg_id, sequence, reader, state):
if not await self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
@ -364,7 +366,7 @@ class MtProtoSender:
return True
def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
async def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
"""
Handles a BadServerSalt response.
@ -378,10 +380,11 @@ class MtProtoSender:
# "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it"
self._resend_request(bad_salt.bad_msg_id)
await self._resend_request(bad_salt.bad_msg_id)
return True
def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
"""
Handles a BadMessageError response.
@ -397,25 +400,25 @@ class MtProtoSender:
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
__log__.info('Attempting to use the correct time offset')
self._resend_request(bad_msg.bad_msg_id)
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session.sequence += 64
__log__.info('Attempting to set the right higher sequence')
self._resend_request(bad_msg.bad_msg_id)
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session.sequence -= 16
__log__.info('Attempting to set the right lower sequence')
self._resend_request(bad_msg.bad_msg_id)
await self._resend_request(bad_msg.bad_msg_id)
return True
else:
raise error
def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
async def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
"""
Handles a MsgDetailedInfo response.
@ -426,10 +429,10 @@ class MtProtoSender:
"""
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6
self._send_acknowledge(msg_new.answer_msg_id)
await self._send_acknowledge(msg_new.answer_msg_id)
return True
def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
async def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
"""
Handles a MsgNewDetailedInfo response.
@ -440,10 +443,10 @@ class MtProtoSender:
"""
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR
self._send_acknowledge(msg_new.answer_msg_id)
await self._send_acknowledge(msg_new.answer_msg_id)
return True
def _handle_new_session_created(self, msg_id, sequence, new_session):
async def _handle_new_session_created(self, msg_id, sequence, new_session):
"""
Handles a NewSessionCreated response.
@ -456,7 +459,7 @@ class MtProtoSender:
# TODO https://goo.gl/LMyN7A
return True
def _handle_rpc_result(self, msg_id, sequence, reader):
async def _handle_rpc_result(self, msg_id, sequence, reader):
"""
Handles a RPCResult response.
@ -484,9 +487,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string()
)
# Acknowledge that we received the error
self._send_acknowledge(request_id)
if request:
request.rpc_error = error
request.confirm_received.set()
@ -522,7 +522,7 @@ class MtProtoSender:
)
return False
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
"""
Handles a GzipPacked response.
@ -532,11 +532,6 @@ class MtProtoSender:
:return: the result of processing the packed message.
"""
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
# We are reentering process_msg, which seemingly the same msg_id
# to the self._need_confirmation set. Remove it from there first
# to avoid any future conflicts (i.e. if we "ignore" messages
# that we are already aware of, see 1a91c02 and old 63dfb1e)
self._need_confirmation -= {msg_id}
return self._process_msg(msg_id, sequence, compressed_reader, state)
return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion

View File

@ -3,7 +3,6 @@ import os
import sqlite3
from base64 import b64decode
from os.path import isfile as file_exists
from threading import Lock, RLock
from .memory import MemorySession, _SentFileType
from ..crypto import AuthKey
@ -39,11 +38,6 @@ class SQLiteSession(MemorySession):
if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION
# Cross-thread safety
self._seq_no_lock = Lock()
self._msg_id_lock = Lock()
self._db_lock = RLock()
# Migrating from .json -> SQL
entities = self._check_migrate_json()
@ -189,42 +183,37 @@ class SQLiteSession(MemorySession):
self._update_session_table()
def _update_session_table(self):
with self._db_lock:
c = self._cursor()
# While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't
# tell us which auth_key's are usable and will work. Needs
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
))
c.close()
c = self._cursor()
# While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't
# tell us which auth_key's are usable and will work. Needs
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
))
c.close()
def save(self):
"""Saves the current session object as session_user_id.session"""
with self._db_lock:
self._conn.commit()
self._conn.commit()
def _cursor(self):
"""Asserts that the connection is open and returns a cursor"""
with self._db_lock:
if self._conn is None:
self._conn = sqlite3.connect(self.filename,
check_same_thread=False)
return self._conn.cursor()
if self._conn is None:
self._conn = sqlite3.connect(self.filename)
return self._conn.cursor()
def close(self):
"""Closes the connection unless we're working in-memory"""
if self.filename != ':memory:':
with self._db_lock:
if self._conn is not None:
self._conn.close()
self._conn = None
if self._conn is not None:
self._conn.close()
self._conn = None
def delete(self):
"""Deletes the current session file"""
@ -259,11 +248,10 @@ class SQLiteSession(MemorySession):
if not rows:
return
with self._db_lock:
self._cursor().executemany(
'insert or replace into entities values (?,?,?,?,?)', rows
)
self.save()
self._cursor().executemany(
'insert or replace into entities values (?,?,?,?,?)', rows
)
self.save()
def _fetchone_entity(self, query, args):
c = self._cursor()
@ -302,11 +290,10 @@ class SQLiteSession(MemorySession):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
with self._db_lock:
self._cursor().execute(
'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size,
_SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash
))
self.save()
self._cursor().execute(
'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size,
_SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash
))
self.save()

View File

@ -1,11 +1,9 @@
import asyncio
import logging
import os
from asyncio import Lock
from datetime import timedelta
import platform
import threading
from datetime import timedelta, datetime
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Lock
from time import sleep
from . import version, utils
from .crypto import rsa
from .errors import (
@ -70,8 +68,6 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL,
use_ipv6=False,
proxy=None,
update_workers=None,
spawn_read_thread=False,
timeout=timedelta(seconds=5),
loop=None,
device_model=None,
@ -95,6 +91,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.'
)
self._loop = loop if loop else asyncio.get_event_loop()
# ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or
(':' in session.server_address) != use_ipv6):
@ -112,13 +110,15 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection(
mode=connection_mode, proxy=proxy, timeout=timeout
))
self._sender = MtProtoSender(
self.session,
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
# Two threads may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock()
# Two coroutines may be calling reconnect() when the connection
# is lost, we only want one to actually perform the reconnection.
self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively
@ -127,7 +127,7 @@ class TelegramBareClient:
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=update_workers)
self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect
system = platform.uname()
@ -153,34 +153,25 @@ class TelegramBareClient:
# See https://core.telegram.org/api/invoking#saving-client-info.
self._first_request = True
# Constantly read for results and updates from within the main client,
# if the user has left enabled such option.
self._spawn_read_thread = spawn_read_thread
self._recv_thread = None
self._idling = threading.Event()
self._recv_loop = None
self._ping_loop = None
self._state_loop = None
self._idling = asyncio.Event()
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
# Also have another delay for GetStateRequest.
#
# If the connection is kept alive for long without invoking any
# high level request the server simply stops sending updates.
# TODO maybe we can have ._last_request instead if any req works?
self._last_state = datetime.now()
self._state_delay = timedelta(hours=1)
# Some errors are known but there's nothing we can do from the
# background thread. If any of these happens, call .disconnect(),
# and raise them next time .invoke() is tried to be called.
self._background_error = None
# endregion
# region Connecting
def connect(self, _sync_updates=True):
async def connect(self, _sync_updates=True):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
@ -197,10 +188,8 @@ class TelegramBareClient:
__log__.info('Connecting to %s:%d...',
self.session.server_address, self.session.port)
self._background_error = None # Clear previous errors
try:
self._sender.connect()
await self._sender.connect()
__log__.info('Connection success!')
# Connection was successful! Try syncing the update state
@ -210,12 +199,12 @@ class TelegramBareClient:
self._user_connected = True
if self._authorized is None and _sync_updates:
try:
self.sync_updates()
self._set_connected_and_authorized()
await self.sync_updates()
await self._set_connected_and_authorized()
except UnauthorizedError:
self._authorized = False
elif self._authorized:
self._set_connected_and_authorized()
await self._set_connected_and_authorized()
return True
@ -224,7 +213,7 @@ class TelegramBareClient:
__log__.warning('Connection failed, got unexpected type with ID '
'%s. Migrating?', hex(e.invalid_constructor_id))
self.disconnect()
return self.connect(_sync_updates=_sync_updates)
return await self.connect(_sync_updates=_sync_updates)
except (RPCError, ConnectionError) as e:
# Probably errors from the previous session, ignore them
@ -249,24 +238,15 @@ class TelegramBareClient:
))
def disconnect(self):
"""Disconnects from the Telegram server
and stops all the spawned threads"""
"""Disconnects from the Telegram server"""
__log__.info('Disconnecting...')
self._user_connected = False # This will stop recv_thread's loop
__log__.debug('Stopping all workers...')
self.updates.stop_workers()
# This will trigger a "ConnectionResetError" on the recv_thread,
# which won't attempt reconnecting as ._user_connected is False.
__log__.debug('Disconnecting the socket...')
self._user_connected = False
self._sender.disconnect()
# TODO Shall we clear the _exported_sessions, or may be reused?
self._first_request = True # On reconnect it will be first again
self.session.close()
def _reconnect(self, new_dc=None):
async def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made
since it's assumed that the connection has been lost and the
library is reconnecting.
@ -276,13 +256,14 @@ class TelegramBareClient:
connects to the new data center.
"""
if new_dc is None:
if self.is_connected():
__log__.info('Reconnection aborted: already connected')
return True
# Assume we are disconnected due to some error, so connect again
try:
if self.is_connected():
__log__.info('Reconnection aborted: already connected')
return True
__log__.info('Attempting reconnection...')
return self.connect()
return await self.connect()
except ConnectionResetError as e:
__log__.warning('Reconnection failed due to %s', e)
return False
@ -290,7 +271,7 @@ class TelegramBareClient:
# Since we're reconnecting possibly due to a UserMigrateError,
# we need to first know the Data Centers we can connect to. Do
# that before disconnecting.
dc = self._get_dc(new_dc)
dc = await self._get_dc(new_dc)
__log__.info('Reconnecting to new data center %s', dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port)
@ -299,7 +280,7 @@ class TelegramBareClient:
self.session.auth_key = None
self.session.save()
self.disconnect()
return self.connect()
return await self.connect()
def set_proxy(self, proxy):
"""Change the proxy used by the connections.
@ -312,19 +293,15 @@ class TelegramBareClient:
# region Working with different connections/Data Centers
def _on_read_thread(self):
return self._recv_thread is not None and \
threading.get_ident() == self._recv_thread.ident
def _get_dc(self, dc_id, cdn=False):
async def _get_dc(self, dc_id, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'"""
if not TelegramBareClient._config:
TelegramBareClient._config = self(GetConfigRequest())
TelegramBareClient._config = await self(GetConfigRequest())
try:
if cdn:
# Ensure we have the latest keys for the CDNs
for pk in self(GetCdnConfigRequest()).public_keys:
for pk in await (self(GetCdnConfigRequest())).public_keys:
rsa.add_key(pk.public_key)
return next(
@ -336,10 +313,10 @@ class TelegramBareClient:
raise
# New configuration, perhaps a new CDN was added?
TelegramBareClient._config = self(GetConfigRequest())
return self._get_dc(dc_id, cdn=cdn)
TelegramBareClient._config = await self(GetConfigRequest())
return await self._get_dc(dc_id, cdn=cdn)
def _get_exported_client(self, dc_id):
async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time calling the method with a given dc_id,
@ -356,11 +333,11 @@ class TelegramBareClient:
# TODO Add a lock, don't allow two threads to create an auth key
# (when calling .connect() if there wasn't a previous session).
# for the same data center.
dc = self._get_dc(dc_id)
dc = await self._get_dc(dc_id)
# Export the current authorization to the new DC.
__log__.info('Exporting authorization for data center %s', dc)
export_auth = self(ExportAuthorizationRequest(dc_id))
export_auth = await self(ExportAuthorizationRequest(dc_id))
# Create a temporary session for this IP address, which needs
# to be different because each auth_key is unique per DC.
@ -375,11 +352,12 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
client.connect(_sync_updates=False)
await client.connect(_sync_updates=False)
if isinstance(export_auth, ExportedAuthorization):
client(ImportAuthorizationRequest(
await client(ImportAuthorizationRequest(
id=export_auth.id, bytes=export_auth.bytes
))
elif export_auth is not None:
@ -388,11 +366,11 @@ class TelegramBareClient:
client._authorized = True # We exported the auth, so we got auth
return client
def _get_cdn_client(self, cdn_redirect):
async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs"""
session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session
@ -401,7 +379,8 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
# This will make use of the new RSA keys for this specific CDN.
@ -409,7 +388,7 @@ class TelegramBareClient:
# We won't be calling GetConfigRequest because it's only called
# when needed by ._get_dc, and also it's static so it's likely
# set already. Avoid invoking non-CDN methods by not syncing updates.
client.connect(_sync_updates=False)
await client.connect(_sync_updates=False)
client._authorized = self._authorized
return client
@ -417,7 +396,7 @@ class TelegramBareClient:
# region Invoking Telegram requests
def __call__(self, *requests, retries=5):
async def __call__(self, *requests, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
The invoke will be retried up to 'retries' times before raising
@ -427,11 +406,8 @@ class TelegramBareClient:
x.content_related for x in requests):
raise TypeError('You can only invoke requests, not types!')
if self._background_error:
raise self._background_error
for request in requests:
request.resolve(self, utils)
await request.resolve(self, utils)
# For logging purposes
if len(requests) == 1:
@ -440,26 +416,23 @@ class TelegramBareClient:
which = '{} requests ({})'.format(
len(requests), [type(x).__name__ for x in requests])
# Determine the sender to be used (main or a new connection)
__log__.debug('Invoking %s', which)
call_receive = \
not self._idling.is_set() or self._reconnect_lock.locked()
for retry in range(retries):
result = self._invoke(call_receive, *requests)
result = await self._invoke(call_receive, retry, *requests)
if result is not None:
return result
__log__.warning('Invoking %s failed %d times, '
'reconnecting and retrying',
[str(x) for x in requests], retry + 1)
sleep(1)
# The ReadThread has priority when attempting reconnection,
# since this thread is constantly running while __call__ is
# only done sometimes. Here try connecting only once/retry.
await asyncio.sleep(retry + 1, loop=self._loop)
if not self._reconnect_lock.locked():
with self._reconnect_lock:
self._reconnect()
with await self._reconnect_lock:
await self._reconnect()
raise RuntimeError('Number of retries reached 0 for {}.'.format(
[type(x).__name__ for x in requests]
@ -468,18 +441,17 @@ class TelegramBareClient:
# Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__
def _invoke(self, call_receive, *requests):
async def _invoke(self, call_receive, retry, *requests):
try:
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.confirm_received.clear()
x.rpc_error = None
if not self.session.auth_key:
__log__.info('Need to generate new auth key before invoking')
self._first_request = True
self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(self._sender.connection)
await authenticator.do_authentication(self._sender.connection)
if self._first_request:
__log__.info('Initializing a new connection while invoking')
@ -489,24 +461,21 @@ class TelegramBareClient:
# We need a SINGLE request (like GetConfig) to init conn.
# Once that's done, the N original requests will be
# invoked.
TelegramBareClient._config = self(
TelegramBareClient._config = await self(
self._wrap_init_connection(GetConfigRequest())
)
self._sender.send(*requests)
await self._sender.send(*requests)
if not call_receive:
# TODO This will be slightly troublesome if we allow
# switching between constant read or not on the fly.
# Must also watch out for calling .read() from two places,
# in which case a Lock would be required for .receive().
for x in requests:
x.confirm_received.wait(
self._sender.connection.get_timeout()
)
await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
else:
while not all(x.confirm_received.is_set() for x in requests):
self._sender.receive(update_state=self.updates)
await self._sender.receive(update_state=self.updates)
except BrokenAuthKeyError:
__log__.error('Authorization key seems broken and was invalid!')
@ -552,12 +521,8 @@ class TelegramBareClient:
except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e:
# TODO What happens with the background thread here?
# For normal use cases, this won't happen, because this will only
# be on the very first connection (not authorized, not running),
# but may be an issue for people who actually travel?
self._reconnect(new_dc=e.new_dc)
return self._invoke(call_receive, *requests)
await self._reconnect(new_dc=e.new_dc)
return await self._invoke(call_receive, retry, *requests)
except ServerError as e:
# Telegram is having some issues, just retry
@ -568,7 +533,8 @@ class TelegramBareClient:
if e.seconds > self.session.flood_sleep_threshold | 0:
raise
sleep(e.seconds)
await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality
@ -588,90 +554,69 @@ class TelegramBareClient:
# region Updates handling
def sync_updates(self):
async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates.
"""
self.updates.process(self(GetStateRequest()))
self._last_state = datetime.now()
self.updates.process(await self(GetStateRequest()))
# endregion
# region Constant read
# Constant read
def _set_connected_and_authorized(self):
# This is async so that the overrided version in TelegramClient can be
# async without problems.
async def _set_connected_and_authorized(self):
self._authorized = True
self.updates.setup_workers()
if self._spawn_read_thread and self._recv_thread is None:
self._recv_thread = threading.Thread(
name='ReadThread', daemon=True,
target=self._recv_thread_impl
)
self._recv_thread.start()
if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
if self._state_loop is None:
self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop)
def _signal_handler(self, signum, frame):
if self._user_connected:
self.disconnect()
else:
os._exit(1)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
self._ping_loop = None
def idle(self, stop_signals=(SIGINT, SIGTERM, SIGABRT)):
"""
Idles the program by looping forever and listening for updates
until one of the signals are received, which breaks the loop.
:param stop_signals:
Iterable containing signals from the signal module that will
be subscribed to TelegramClient.disconnect() (effectively
stopping the idle loop), which will be called on receiving one
of those signals.
:return:
"""
if self._spawn_read_thread and not self._on_read_thread():
raise RuntimeError('Can only idle if spawn_read_thread=False')
async def _state_loop_impl(self):
while self._user_connected:
await asyncio.sleep(self._state_delay.seconds, loop=self._loop)
await self._sender.send(GetStateRequest())
async def _recv_loop_impl(self):
__log__.info('Starting to wait for items from the network')
self._idling.set()
for sig in stop_signals:
signal(sig, self._signal_handler)
if self._on_read_thread():
__log__.info('Starting to wait for items from the network')
else:
__log__.info('Idling to receive items from the network')
need_reconnect = False
while self._user_connected:
try:
if datetime.now() > self._last_ping + self._ping_delay:
self._sender.send(PingRequest(
int.from_bytes(os.urandom(8), 'big', signed=True)
))
self._last_ping = datetime.now()
if need_reconnect:
__log__.info('Attempting reconnection from read loop')
need_reconnect = False
with await self._reconnect_lock:
while self._user_connected and not await self._reconnect():
# Retry forever, this is instant messaging
await asyncio.sleep(0.1, loop=self._loop)
if datetime.now() > self._last_state + self._state_delay:
self._sender.send(GetStateRequest())
self._last_state = datetime.now()
__log__.debug('Receiving items from the network...')
self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem
__log__.debug('Receiving items from the network timed out')
except ConnectionResetError:
if self._user_connected:
__log__.error('Connection was reset while receiving '
'items. Reconnecting')
with self._reconnect_lock:
while self._user_connected and not self._reconnect():
sleep(0.1) # Retry forever, this is instant messaging
if self.is_connected():
# Telegram seems to kick us every 1024 items received
# from the network not considering things like bad salt.
# We must execute some *high level* request (that's not
# a ping) if we want to receive updates again.
# TODO Test if getDifference works too (better alternative)
self._sender.send(GetStateRequest())
await self._sender.send(GetStateRequest())
__log__.debug('Receiving items from the network...')
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
__log__.debug('Receiving items from the network timed out')
except ConnectionError:
need_reconnect = True
__log__.error('Connection was reset while receiving items')
await asyncio.sleep(1, loop=self._loop)
except:
self._idling.clear()
raise
@ -679,39 +624,4 @@ class TelegramBareClient:
self._idling.clear()
__log__.info('Connection closed by the user, not reading anymore')
# By using this approach, another thread will be
# created and started upon connection to constantly read
# from the other end. Otherwise, manual calls to .receive()
# must be performed. The MtProtoSender cannot be connected,
# or an error will be thrown.
#
# This way, sending and receiving will be completely independent.
def _recv_thread_impl(self):
# This thread is "idle" (only listening for updates), but also
# excepts everything unlike the manual idle because it should
# not crash.
while self._user_connected:
try:
self.idle(stop_signals=tuple())
except Exception as error:
__log__.exception('Unknown exception in the read thread! '
'Disconnecting and leaving it to main thread')
# Unknown exception, pass it to the main thread
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a proxy error"
self._recv_thread = None
# endregion

View File

@ -1,3 +1,4 @@
import asyncio
import getpass
import hashlib
import io
@ -6,7 +7,6 @@ import logging
import os
import re
import sys
import time
import warnings
from collections import OrderedDict, UserList
from datetime import datetime, timedelta
@ -158,18 +158,16 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL,
use_ipv6=False,
proxy=None,
update_workers=None,
timeout=timedelta(seconds=5),
spawn_read_thread=True,
loop=None,
**kwargs):
super().__init__(
session, api_id, api_hash,
connection_mode=connection_mode,
use_ipv6=use_ipv6,
proxy=proxy,
update_workers=update_workers,
spawn_read_thread=spawn_read_thread,
timeout=timeout,
loop=loop,
**kwargs
)
@ -190,7 +188,7 @@ class TelegramClient(TelegramBareClient):
# region Authorization requests
def send_code_request(self, phone, force_sms=False):
async def send_code_request(self, phone, force_sms=False):
"""
Sends a code request to the specified phone number.
@ -208,7 +206,7 @@ class TelegramClient(TelegramBareClient):
phone_hash = self._phone_code_hash.get(phone)
if not phone_hash:
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone_code_hash[phone] = phone_hash = result.phone_code_hash
else:
force_sms = True
@ -216,22 +214,23 @@ class TelegramClient(TelegramBareClient):
self._phone = phone
if force_sms:
result = self(ResendCodeRequest(phone, phone_hash))
result = await self(ResendCodeRequest(phone, phone_hash))
self._phone_code_hash[phone] = result.phone_code_hash
return result
def start(self,
phone=lambda: input('Please enter your phone: '),
password=lambda: getpass.getpass('Please enter your password: '),
bot_token=None, force_sms=False, code_callback=None,
first_name='New User', last_name=''):
async def start(self,
phone=lambda: input('Please enter your phone: '),
password=lambda: getpass.getpass(
'Please enter your password: '),
bot_token=None, force_sms=False, code_callback=None,
first_name='New User', last_name=''):
"""
Convenience method to interactively connect and sign in if required,
also taking into consideration that 2FA may be enabled in the account.
Example usage:
>>> client = TelegramClient(session, api_id, api_hash).start(phone)
>>> client = await TelegramClient(session, api_id, api_hash).start(phone)
Please enter the code you received: 12345
Please enter your password: *******
(You are now logged in)
@ -286,14 +285,14 @@ class TelegramClient(TelegramBareClient):
'must only provide one of either')
if not self.is_connected():
self.connect()
await self.connect()
if self.is_user_authorized():
self._check_events_pending_resolve()
return self
if bot_token:
self.sign_in(bot_token=bot_token)
await self.sign_in(bot_token=bot_token)
return self
# Turn the callable into a valid phone number
@ -305,15 +304,15 @@ class TelegramClient(TelegramBareClient):
max_attempts = 3
two_step_detected = False
sent_code = self.send_code_request(phone, force_sms=force_sms)
sent_code = await self.send_code_request(phone, force_sms=force_sms)
sign_up = not sent_code.phone_registered
while attempts < max_attempts:
try:
if sign_up:
me = self.sign_up(code_callback(), first_name, last_name)
me = await self.sign_up(code_callback(), first_name, last_name)
else:
# Raises SessionPasswordNeededError if 2FA enabled
me = self.sign_in(phone, code_callback())
me = await self.sign_in(phone, code_callback())
break
except SessionPasswordNeededError:
two_step_detected = True
@ -342,15 +341,15 @@ class TelegramClient(TelegramBareClient):
# TODO If callable given make it retry on invalid
if callable(password):
password = password()
me = self.sign_in(phone=phone, password=password)
me = await self.sign_in(phone=phone, password=password)
# We won't reach here if any step failed (exit by exception)
print('Signed in successfully as', utils.get_display_name(me))
self._check_events_pending_resolve()
return self
def sign_in(self, phone=None, code=None,
password=None, bot_token=None, phone_code_hash=None):
async def sign_in(self, phone=None, code=None,
password=None, bot_token=None, phone_code_hash=None):
"""
Starts or completes the sign in process with the given phone number
or code that Telegram sent.
@ -385,7 +384,7 @@ class TelegramClient(TelegramBareClient):
return self.get_me()
if phone and not code and not password:
return self.send_code_request(phone)
return await self.send_code_request(phone)
elif code:
phone = utils.parse_phone(phone) or self._phone
phone_code_hash = \
@ -400,14 +399,14 @@ class TelegramClient(TelegramBareClient):
# May raise PhoneCodeEmptyError, PhoneCodeExpiredError,
# PhoneCodeHashEmptyError or PhoneCodeInvalidError.
result = self(SignInRequest(phone, phone_code_hash, str(code)))
result = await self(SignInRequest(phone, phone_code_hash, str(code)))
elif password:
salt = self(GetPasswordRequest()).current_salt
result = self(CheckPasswordRequest(
salt = (await self(GetPasswordRequest())).current_salt
result = await self(CheckPasswordRequest(
helpers.get_password_hash(password, salt)
))
elif bot_token:
result = self(ImportBotAuthorizationRequest(
result = await self(ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash
))
@ -420,10 +419,10 @@ class TelegramClient(TelegramBareClient):
self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False
)
self._set_connected_and_authorized()
await self._set_connected_and_authorized()
return result.user
def sign_up(self, code, first_name, last_name=''):
async def sign_up(self, code, first_name, last_name=''):
"""
Signs up to Telegram if you don't have an account yet.
You must call .send_code_request(phone) first.
@ -442,10 +441,10 @@ class TelegramClient(TelegramBareClient):
The new created user.
"""
if self.is_user_authorized():
self._check_events_pending_resolve()
return self.get_me()
await self._check_events_pending_resolve()
return await self.get_me()
result = self(SignUpRequest(
result = await self(SignUpRequest(
phone_number=self._phone,
phone_code_hash=self._phone_code_hash.get(self._phone, ''),
phone_code=str(code),
@ -456,10 +455,10 @@ class TelegramClient(TelegramBareClient):
self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False
)
self._set_connected_and_authorized()
await self._set_connected_and_authorized()
return result.user
def log_out(self):
async def log_out(self):
"""
Logs out Telegram and deletes the current ``*.session`` file.
@ -467,7 +466,7 @@ class TelegramClient(TelegramBareClient):
True if the operation was successful.
"""
try:
self(LogOutRequest())
await self(LogOutRequest())
except RPCError:
return False
@ -475,7 +474,7 @@ class TelegramClient(TelegramBareClient):
self.session.delete()
return True
def get_me(self, input_peer=False):
async def get_me(self, input_peer=False):
"""
Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated).
@ -491,9 +490,8 @@ class TelegramClient(TelegramBareClient):
"""
if input_peer and self._self_input_peer:
return self._self_input_peer
try:
me = self(GetUsersRequest([InputUserSelf()]))[0]
me = (await self(GetUsersRequest([InputUserSelf()])))[0]
if not self._self_input_peer:
self._self_input_peer = utils.get_input_peer(
me, allow_self=False
@ -507,8 +505,8 @@ class TelegramClient(TelegramBareClient):
# region Dialogs ("chats") requests
def get_dialogs(self, limit=10, offset_date=None, offset_id=0,
offset_peer=InputPeerEmpty()):
async def get_dialogs(self, limit=10, offset_date=None, offset_id=0,
offset_peer=InputPeerEmpty()):
"""
Gets N "dialogs" (open "chats" or conversations with other people).
@ -535,7 +533,7 @@ class TelegramClient(TelegramBareClient):
limit = float('inf') if limit is None else int(limit)
if limit == 0:
# Special case, get a single dialog and determine count
dialogs = self(GetDialogsRequest(
dialogs = await self(GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
@ -549,7 +547,7 @@ class TelegramClient(TelegramBareClient):
dialogs = OrderedDict() # Use peer id as identifier to avoid dupes
while len(dialogs) < limit:
real_limit = min(limit - len(dialogs), 100)
r = self(GetDialogsRequest(
r = await self(GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
@ -580,7 +578,7 @@ class TelegramClient(TelegramBareClient):
dialogs.total = total_count
return dialogs
def get_drafts(self): # TODO: Ability to provide a `filter`
async def get_drafts(self): # TODO: Ability to provide a `filter`
"""
Gets all open draft messages.
@ -589,7 +587,7 @@ class TelegramClient(TelegramBareClient):
You can call ``draft.set_message('text')`` to change the message,
or delete it through :meth:`draft.delete()`.
"""
response = self(GetAllDraftsRequest())
response = await self(GetAllDraftsRequest())
self.session.process_entities(response)
self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates]
@ -636,7 +634,7 @@ class TelegramClient(TelegramBareClient):
if request.id == update.message.id:
return update.message
def _parse_message_text(self, message, parse_mode):
async def _parse_message_text(self, message, parse_mode):
"""
Returns a (parsed message, entities) tuple depending on parse_mode.
"""
@ -657,7 +655,7 @@ class TelegramClient(TelegramBareClient):
if m:
try:
msg_entities[i] = InputMessageEntityMentionName(
e.offset, e.length, self.get_input_entity(
e.offset, e.length, await self.get_input_entity(
int(m.group(1)) if m.group(1) else e.url
)
)
@ -667,8 +665,8 @@ class TelegramClient(TelegramBareClient):
return message, msg_entities
def send_message(self, entity, message, reply_to=None, parse_mode='md',
link_preview=True):
async def send_message(self, entity, message, reply_to=None,
parse_mode='md', link_preview=True):
"""
Sends the given message to the specified entity (user/chat/channel).
@ -695,11 +693,12 @@ class TelegramClient(TelegramBareClient):
Returns:
the sent message
"""
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if isinstance(message, Message):
if (message.media
and not isinstance(message.media, MessageMediaWebPage)):
return self.send_file(entity, message.media)
return await self.send_file(entity, message.media)
if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id):
reply_id = message.reply_to_msg_id
@ -716,7 +715,7 @@ class TelegramClient(TelegramBareClient):
)
message = message.message
else:
message, msg_ent = self._parse_message_text(message, parse_mode)
message, msg_ent = await self._parse_message_text(message, parse_mode)
request = SendMessageRequest(
peer=entity,
message=message,
@ -725,7 +724,8 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_message_id(reply_to)
)
result = self(request)
result = await self(request)
if isinstance(result, UpdateShortSentMessage):
return Message(
id=result.id,
@ -739,8 +739,8 @@ class TelegramClient(TelegramBareClient):
return self._get_response_message(request, result)
def edit_message(self, entity, message_id, message=None, parse_mode='md',
link_preview=True):
async def edit_message(self, entity, message_id, message=None,
parse_mode='md', link_preview=True):
"""
Edits the given message ID (to change its contents or disable preview).
@ -773,18 +773,18 @@ class TelegramClient(TelegramBareClient):
Returns:
the edited message
"""
message, msg_entities = self._parse_message_text(message, parse_mode)
message, msg_entities = await self._parse_message_text(message, parse_mode)
request = EditMessageRequest(
peer=self.get_input_entity(entity),
peer=await self.get_input_entity(entity),
id=self._get_message_id(message_id),
message=message,
no_webpage=not link_preview,
entities=msg_entities
)
result = self(request)
result = await self(request)
return self._get_response_message(request, result)
def delete_messages(self, entity, message_ids, revoke=True):
async def delete_messages(self, entity, message_ids, revoke=True):
"""
Deletes a message from a chat, optionally "for everyone".
@ -812,18 +812,18 @@ class TelegramClient(TelegramBareClient):
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
if entity is None:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if isinstance(entity, InputPeerChannel):
return self(channels.DeleteMessagesRequest(entity, message_ids))
return await self(channels.DeleteMessagesRequest(entity, message_ids))
else:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
def get_message_history(self, entity, limit=20, offset_date=None,
offset_id=0, max_id=0, min_id=0, add_offset=0,
batch_size=100, wait_time=None):
async def get_message_history(self, entity, limit=20, offset_date=None,
offset_id=0, max_id=0, min_id=0, add_offset=0,
batch_size=100, wait_time=None):
"""
Gets the message history for the specified entity
@ -884,13 +884,12 @@ class TelegramClient(TelegramBareClient):
second is the default for this limit (or above). You may need
an higher limit, so you're free to set the ``batch_size`` that
you think may be good.
"""
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
limit = float('inf') if limit is None else int(limit)
if limit == 0:
# No messages, but we still need to know the total message count
result = self(GetHistoryRequest(
result = await self(GetHistoryRequest(
peer=entity, limit=1,
offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0
))
@ -906,7 +905,7 @@ class TelegramClient(TelegramBareClient):
while len(messages) < limit:
# Telegram has a hard limit of 100
real_limit = min(limit - len(messages), batch_size)
result = self(GetHistoryRequest(
result = await self(GetHistoryRequest(
peer=entity,
limit=real_limit,
offset_date=offset_date,
@ -931,7 +930,7 @@ class TelegramClient(TelegramBareClient):
offset_id = result.messages[-1].id
offset_date = result.messages[-1].date
time.sleep(wait_time)
await asyncio.sleep(wait_time)
# Add a few extra attributes to the Message to make it friendlier.
messages.total = total_messages
@ -959,8 +958,8 @@ class TelegramClient(TelegramBareClient):
return messages
def send_read_acknowledge(self, entity, message=None, max_id=None,
clear_mentions=False):
async def send_read_acknowledge(self, entity, message=None, max_id=None,
clear_mentions=False):
"""
Sends a "read acknowledge" (i.e., notifying the given peer that we've
read their messages, also known as the "double check").
@ -993,17 +992,17 @@ class TelegramClient(TelegramBareClient):
raise ValueError(
'Either a message list or a max_id must be provided.')
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if clear_mentions:
self(ReadMentionsRequest(entity))
await self(ReadMentionsRequest(entity))
if max_id is None:
return True
if max_id is not None:
if isinstance(entity, InputPeerChannel):
return self(channels.ReadHistoryRequest(entity, max_id=max_id))
return await self(channels.ReadHistoryRequest(entity, max_id=max_id))
else:
return self(messages.ReadHistoryRequest(entity, max_id=max_id))
return await self(messages.ReadHistoryRequest(entity, max_id=max_id))
return False
@ -1025,8 +1024,8 @@ class TelegramClient(TelegramBareClient):
raise TypeError('Invalid message type: {}'.format(type(message)))
def get_participants(self, entity, limit=None, search='',
aggressive=False):
async def get_participants(self, entity, limit=None, search='',
aggressive=False):
"""
Gets the list of participants from the specified entity.
@ -1054,12 +1053,12 @@ class TelegramClient(TelegramBareClient):
A list of participants with an additional .total variable on the
list indicating the total amount of members in this group/channel.
"""
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
limit = float('inf') if limit is None else int(limit)
if isinstance(entity, InputPeerChannel):
total = self(GetFullChannelRequest(
total = (await self(GetFullChannelRequest(
entity
)).full_chat.participants_count
))).full_chat.participants_count
all_participants = {}
if total > 10000 and aggressive:
@ -1091,9 +1090,9 @@ class TelegramClient(TelegramBareClient):
break
if len(requests) == 1:
results = (self(requests[0]),)
results = (await self(requests[0]),)
else:
results = self(*requests)
results = await self(*requests)
for i in reversed(range(len(requests))):
participants = results[i]
if not participants.users:
@ -1111,7 +1110,7 @@ class TelegramClient(TelegramBareClient):
users = UserList(values)
users.total = total
elif isinstance(entity, InputPeerChat):
users = self(GetFullChatRequest(entity.chat_id)).users
users = (await self(GetFullChatRequest(entity.chat_id))).users
if len(users) > limit:
users = users[:limit]
users = UserList(users)
@ -1125,14 +1124,14 @@ class TelegramClient(TelegramBareClient):
# region Uploading files
def send_file(self, entity, file, caption=None,
force_document=False, progress_callback=None,
reply_to=None,
attributes=None,
thumb=None,
allow_cache=True,
parse_mode='md',
**kwargs):
async def send_file(self, entity, file, caption=None,
force_document=False, progress_callback=None,
reply_to=None,
attributes=None,
thumb=None,
allow_cache=True,
parse_mode='md',
**kwargs):
"""
Sends a file to the specified entity.
@ -1201,14 +1200,14 @@ class TelegramClient(TelegramBareClient):
# Convert to tuple so we can iterate several times
file = tuple(x for x in file)
if all(utils.is_image(x) for x in file):
return self._send_album(
return await self._send_album(
entity, file, caption=caption,
progress_callback=progress_callback, reply_to=reply_to,
parse_mode=parse_mode
)
# Not all are images, so send all the files one by one
return [
self.send_file(
await self.send_file(
entity, x, allow_cache=False,
caption=caption, force_document=force_document,
progress_callback=progress_callback, reply_to=reply_to,
@ -1216,7 +1215,7 @@ class TelegramClient(TelegramBareClient):
) for x in file
]
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
reply_to = self._get_message_id(reply_to)
caption, msg_entities = self._parse_message_text(caption, parse_mode)
@ -1233,11 +1232,11 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=reply_to,
message=caption,
entities=msg_entities)
return self._get_response_message(request, self(request))
return self._get_response_message(request, await self(request))
as_image = utils.is_image(file) and not force_document
use_cache = InputPhoto if as_image else InputDocument
file_handle = self.upload_file(
file_handle = await self.upload_file(
file, progress_callback=progress_callback,
use_cache=use_cache if allow_cache else None
)
@ -1314,7 +1313,7 @@ class TelegramClient(TelegramBareClient):
input_kw = {}
if thumb:
input_kw['thumb'] = self.upload_file(thumb)
input_kw['thumb'] = await self.upload_file(thumb)
media = InputMediaUploadedDocument(
file=file_handle,
@ -1327,7 +1326,7 @@ class TelegramClient(TelegramBareClient):
# send the media message to the desired entity.
request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to,
message=caption, entities=msg_entities)
msg = self._get_response_message(request, self(request))
msg = self._get_response_message(request, await self(request))
if msg and isinstance(file_handle, InputSizedFile):
# There was a response message and we didn't use cached
# version, so cache whatever we just sent to the database.
@ -1345,7 +1344,7 @@ class TelegramClient(TelegramBareClient):
kwargs['is_voice_note'] = True
return self.send_file(*args, **kwargs)
def _send_album(self, entity, files, caption=None,
async def _send_album(self, entity, files, caption=None,
progress_callback=None, reply_to=None,
parse_mode='md'):
"""Specialized version of .send_file for albums"""
@ -1354,7 +1353,7 @@ class TelegramClient(TelegramBareClient):
# we need to produce right now to send albums (uploadMedia), and
# cache only makes a difference for documents where the user may
# want the attributes used on them to change.
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if not utils.is_list_like(caption):
caption = (caption,)
captions = [
@ -1367,11 +1366,11 @@ class TelegramClient(TelegramBareClient):
media = []
for file in files:
# fh will either be InputPhoto or a modified InputFile
fh = self.upload_file(file, use_cache=InputPhoto)
fh = await self.upload_file(file, use_cache=InputPhoto)
if not isinstance(fh, InputPhoto):
input_photo = utils.get_input_photo(self(UploadMediaRequest(
input_photo = utils.get_input_photo((await self(UploadMediaRequest(
entity, media=InputMediaUploadedPhoto(fh)
)).photo)
))).photo)
self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo
@ -1383,7 +1382,7 @@ class TelegramClient(TelegramBareClient):
entities=msg_entities))
# Now we can construct the multi-media request
result = self(SendMultiMediaRequest(
result = await self(SendMultiMediaRequest(
entity, reply_to_msg_id=reply_to, multi_media=media
))
return [
@ -1392,12 +1391,8 @@ class TelegramClient(TelegramBareClient):
if isinstance(update, UpdateMessageID)
]
def upload_file(self,
file,
part_size_kb=None,
file_name=None,
use_cache=None,
progress_callback=None):
async def upload_file(self, file, part_size_kb=None, file_name=None,
use_cache=None, progress_callback=None):
"""
Uploads the specified file and returns a handle (an instance of
InputFile or InputFileBig, as required) which can be later used
@ -1510,7 +1505,7 @@ class TelegramClient(TelegramBareClient):
else:
request = SaveFilePartRequest(file_id, part_index, part)
result = self(request)
result = await self(request)
if result:
__log__.debug('Uploaded %d/%d', part_index + 1,
part_count)
@ -1531,7 +1526,7 @@ class TelegramClient(TelegramBareClient):
# region Downloading media requests
def download_profile_photo(self, entity, file=None, download_big=True):
async def download_profile_photo(self, entity, file=None, download_big=True):
"""
Downloads the profile photo of the given entity (user/chat/channel).
@ -1565,12 +1560,12 @@ class TelegramClient(TelegramBareClient):
# The hexadecimal numbers above are simply:
# hex(crc32(x.encode('ascii'))) for x in
# ('User', 'Chat', 'UserFull', 'ChatFull')
entity = self.get_entity(entity)
entity = await self.get_entity(entity)
if not hasattr(entity, 'photo'):
# Special case: may be a ChatFull with photo:Photo
# This is different from a normal UserProfilePhoto and Chat
if hasattr(entity, 'chat_photo'):
return self._download_photo(
return await self._download_photo(
entity.chat_photo, file,
date=None, progress_callback=None
)
@ -1595,7 +1590,7 @@ class TelegramClient(TelegramBareClient):
# Download the media with the largest size input file location
try:
self.download_file(
await self.download_file(
InputFileLocation(
volume_id=photo_location.volume_id,
local_id=photo_location.local_id,
@ -1606,10 +1601,10 @@ class TelegramClient(TelegramBareClient):
except LocationInvalidError:
# See issue #500, Android app fails as of v4.6.0 (1155).
# The fix seems to be using the full channel chat photo.
ie = self.get_input_entity(entity)
ie = await self.get_input_entity(entity)
if isinstance(ie, InputPeerChannel):
full = self(GetFullChannelRequest(ie))
return self._download_photo(
full = await self(GetFullChannelRequest(ie))
return await self._download_photo(
full.full_chat.chat_photo, file,
date=None, progress_callback=None
)
@ -1618,7 +1613,7 @@ class TelegramClient(TelegramBareClient):
return None
return file
def download_media(self, message, file=None, progress_callback=None):
async def download_media(self, message, file=None, progress_callback=None):
"""
Downloads the given media, or the media from a specified Message.
@ -1646,19 +1641,19 @@ class TelegramClient(TelegramBareClient):
media = message
if isinstance(media, (MessageMediaPhoto, Photo)):
return self._download_photo(
return await self._download_photo(
media, file, date, progress_callback
)
elif isinstance(media, (MessageMediaDocument, Document)):
return self._download_document(
return await self._download_document(
media, file, date, progress_callback
)
elif isinstance(media, MessageMediaContact):
return self._download_contact(
return await self._download_contact(
media, file
)
def _download_photo(self, photo, file, date, progress_callback):
async def _download_photo(self, photo, file, date, progress_callback):
"""Specialized version of .download_media() for photos"""
# Determine the photo and its largest size
if isinstance(photo, MessageMediaPhoto):
@ -1673,7 +1668,7 @@ class TelegramClient(TelegramBareClient):
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
# Download the media with the largest size input file location
self.download_file(
await self.download_file(
InputFileLocation(
volume_id=largest_size.volume_id,
local_id=largest_size.local_id,
@ -1685,7 +1680,7 @@ class TelegramClient(TelegramBareClient):
)
return file
def _download_document(self, document, file, date, progress_callback):
async def _download_document(self, document, file, date, progress_callback):
"""Specialized version of .download_media() for documents"""
if isinstance(document, MessageMediaDocument):
document = document.document
@ -1718,7 +1713,7 @@ class TelegramClient(TelegramBareClient):
date=date, possible_names=possible_names
)
self.download_file(
await self.download_file(
InputDocumentFileLocation(
id=document.id,
access_hash=document.access_hash,
@ -1825,12 +1820,8 @@ class TelegramClient(TelegramBareClient):
return result
i += 1
def download_file(self,
input_location,
file,
part_size_kb=None,
file_size=None,
progress_callback=None):
async def download_file(self, input_location, file, part_size_kb=None,
file_size=None, progress_callback=None):
"""
Downloads the given input location to a file.
@ -1889,23 +1880,24 @@ class TelegramClient(TelegramBareClient):
while True:
try:
if cdn_decrypter:
result = cdn_decrypter.get_file()
result = await cdn_decrypter.get_file()
else:
result = client(GetFileRequest(
result = await client(GetFileRequest(
input_location, offset, part_size
))
if isinstance(result, FileCdnRedirect):
__log__.info('File lives in a CDN')
cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result),
await CdnDecrypter.prepare_decrypter(
client,
await self._get_cdn_client(result),
result
)
except FileMigrateError as e:
__log__.info('File lives in another DC')
client = self._get_exported_client(e.new_dc)
client = await self._get_exported_client(e.new_dc)
continue
offset += part_size
@ -1947,25 +1939,25 @@ class TelegramClient(TelegramBareClient):
The event builder class or instance to be used,
for instance ``events.NewMessage``.
"""
def decorator(f):
self.add_event_handler(f, event)
async def decorator(f):
await self.add_event_handler(f, event)
return f
return decorator
def _check_events_pending_resolve(self):
async def _check_events_pending_resolve(self):
if self._events_pending_resolve:
for event in self._events_pending_resolve:
event.resolve(self)
await event.resolve(self)
self._events_pending_resolve.clear()
def _on_handler(self, update):
async def _on_handler(self, update):
for builder, callback in self._event_builders:
event = builder.build(update)
if event:
event._client = self
try:
callback(event)
await callback(event)
except events.StopPropagation:
__log__.debug(
"Event handler '{}' stopped chain of "
@ -1974,7 +1966,7 @@ class TelegramClient(TelegramBareClient):
)
break
def add_event_handler(self, callback, event=None):
async def add_event_handler(self, callback, event=None):
"""
Registers the given callback to be called on the specified event.
@ -1989,12 +1981,6 @@ class TelegramClient(TelegramBareClient):
If left unspecified, ``events.Raw`` (the ``Update`` objects
with no further processing) will be passed instead.
"""
if self.updates.workers is None:
warnings.warn(
"You have not setup any workers, so you won't receive updates."
" Pass update_workers=1 when creating the TelegramClient,"
" or set client.self.updates.workers = 1"
)
self.updates.handler = self._on_handler
if isinstance(event, type):
@ -2003,8 +1989,8 @@ class TelegramClient(TelegramBareClient):
event = events.Raw()
if self.is_user_authorized():
event.resolve(self)
self._check_events_pending_resolve()
await event.resolve(self)
await self._check_events_pending_resolve()
else:
self._events_pending_resolve.append(event)
@ -2031,11 +2017,11 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier
def _set_connected_and_authorized(self):
super()._set_connected_and_authorized()
self._check_events_pending_resolve()
async def _set_connected_and_authorized(self):
await super()._set_connected_and_authorized()
await self._check_events_pending_resolve()
def get_entity(self, entity):
async def get_entity(self, entity):
"""
Turns the given entity into a valid Telegram user or chat.
@ -2069,7 +2055,7 @@ class TelegramClient(TelegramBareClient):
# input channels (get channels) to get the most entities
# in the less amount of calls possible.
inputs = [
x if isinstance(x, str) else self.get_input_entity(x)
x if isinstance(x, str) else await self.get_input_entity(x)
for x in entity
]
users = [x for x in inputs if isinstance(x, InputPeerUser)]
@ -2080,12 +2066,12 @@ class TelegramClient(TelegramBareClient):
tmp = []
while users:
curr, users = users[:200], users[200:]
tmp.extend(self(GetUsersRequest(curr)))
tmp.extend(await self(GetUsersRequest(curr)))
users = tmp
if chats: # TODO Handle chats slice?
chats = self(GetChatsRequest(chats)).chats
chats = (await self(GetChatsRequest(chats))).chats
if channels:
channels = self(GetChannelsRequest(channels)).chats
channels = (await self(GetChannelsRequest(channels))).chats
# Merge users, chats and channels into a single dictionary
id_entity = {
@ -2098,33 +2084,31 @@ class TelegramClient(TelegramBareClient):
# the amount of ResolveUsername calls, it would fail to catch
# username changes.
result = [
self._get_entity_from_string(x) if isinstance(x, str)
await self._get_entity_from_string(x) if isinstance(x, str)
else id_entity[utils.get_peer_id(x)]
for x in inputs
]
return result[0] if single else result
def _get_entity_from_string(self, string):
async def _get_entity_from_string(self, string):
"""
Gets a full entity from the given string, which may be a phone or
an username, and processes all the found entities on the session.
The string may also be a user link, or a channel/chat invite link.
This method has the side effect of adding the found users to the
session database, so it can be queried later without API calls,
if this option is enabled on the session.
Returns the found entity, or raises TypeError if not found.
"""
phone = utils.parse_phone(string)
if phone:
for user in self(GetContactsRequest(0)).users:
for user in (await self(GetContactsRequest(0))).users:
if user.phone == phone:
return user
else:
username, is_join_chat = utils.parse_username(string)
if is_join_chat:
invite = self(CheckChatInviteRequest(username))
invite = await self(CheckChatInviteRequest(username))
if isinstance(invite, ChatInvite):
raise ValueError(
'Cannot get entity from a channel '
@ -2134,14 +2118,15 @@ class TelegramClient(TelegramBareClient):
return invite.chat
elif username:
if username in ('me', 'self'):
return self.get_me()
result = self(ResolveUsernameRequest(username))
return await self.get_me()
result = await self(ResolveUsernameRequest(username))
for entity in itertools.chain(result.users, result.chats):
if entity.username.lower() == username:
return entity
try:
# Nobody with this username, maybe it's an exact name/title
return self.get_entity(self.session.get_input_entity(string))
return await self.get_entity(
self.session.get_input_entity(string))
except ValueError:
pass
@ -2149,24 +2134,20 @@ class TelegramClient(TelegramBareClient):
'Cannot turn "{}" into any entity (user or chat)'.format(string)
)
def get_input_entity(self, peer):
async def get_input_entity(self, peer):
"""
Turns the given peer into its input entity version. Most requests
use this kind of InputUser, InputChat and so on, so this is the
most suitable call to make for those cases.
entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`):
The integer ID of an user or otherwise either of a
``PeerUser``, ``PeerChat`` or ``PeerChannel``, for
which to get its ``Input*`` version.
If this ``Peer`` hasn't been seen before by the library, the top
dialogs will be loaded and their entities saved to the session
file (unless this feature was disabled explicitly).
If in the end the access hash required for the peer was not found,
a ValueError will be raised.
Returns:
``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``.
"""
@ -2179,7 +2160,7 @@ class TelegramClient(TelegramBareClient):
if isinstance(peer, str):
if peer in ('me', 'self'):
return InputPeerSelf()
return utils.get_input_peer(self._get_entity_from_string(peer))
return utils.get_input_peer(await self._get_entity_from_string(peer))
if isinstance(peer, int):
peer, kind = utils.resolve_id(peer)
@ -2206,7 +2187,7 @@ class TelegramClient(TelegramBareClient):
limit=100
)
while True:
result = self(req)
result = await self(req)
entities = {}
for x in itertools.chain(result.users, result.chats):
x_id = utils.get_peer_id(x)
@ -2222,7 +2203,7 @@ class TelegramClient(TelegramBareClient):
req.offset_peer = entities[utils.get_peer_id(
result.dialogs[-1].peer
)]
time.sleep(1)
asyncio.sleep(1)
raise TypeError(
'Could not find the input entity corresponding to "{}". '

View File

@ -26,9 +26,9 @@ class Dialog:
self.draft = Draft(client, dialog.peer, dialog.draft)
def send_message(self, *args, **kwargs):
async def send_message(self, *args, **kwargs):
"""
Sends a message to this dialog. This is just a wrapper around
client.send_message(dialog.input_entity, *args, **kwargs).
"""
return self._client.send_message(self.input_entity, *args, **kwargs)
return await self._client.send_message(self.input_entity, *args, **kwargs)

View File

@ -31,14 +31,14 @@ class Draft:
return cls(client=client, peer=update.peer, draft=update.draft)
@property
def entity(self):
return self._client.get_entity(self._peer)
async def entity(self):
return await self._client.get_entity(self._peer)
@property
def input_entity(self):
return self._client.get_input_entity(self._peer)
async def input_entity(self):
return await self._client.get_input_entity(self._peer)
def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
async def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
"""
Changes the draft message on the Telegram servers. The changes are
reflected in this object. Changing only individual attributes like for
@ -58,7 +58,7 @@ class Draft:
:param list entities: A list of formatting entities
:return bool: ``True`` on success
"""
result = self._client(SaveDraftRequest(
result = await self._client(SaveDraftRequest(
peer=self._peer,
message=text,
no_webpage=no_webpage,
@ -74,9 +74,9 @@ class Draft:
return result
def delete(self):
async def delete(self):
"""
Deletes this draft
:return bool: ``True`` on success
"""
return self.set_message(text='')
return await self.set_message(text='')

View File

@ -1,11 +1,10 @@
import struct
from datetime import datetime, date
from threading import Event
class TLObject:
def __init__(self):
self.confirm_received = Event()
self.confirm_received = None
self.rpc_error = None
self.result = None
@ -157,7 +156,7 @@ class TLObject:
return TLObject.pretty_format(self, indent=0)
# These should be overrode
def resolve(self, client, utils):
async def resolve(self, client, utils):
pass
def to_dict(self):

View File

@ -1,9 +1,8 @@
import logging
import pickle
import asyncio
from collections import deque
from queue import Queue, Empty
from datetime import datetime
from threading import RLock, Thread
from .tl import types as tl
@ -16,125 +15,40 @@ class UpdateState:
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke the self.handler.
"""
self._workers = workers
self._worker_threads = []
def __init__(self, loop=None):
self.handler = None
self._updates_lock = RLock()
self._updates = Queue()
self._loop = loop if loop else asyncio.get_event_loop()
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self):
"""Returns True if a call to .poll() won't lock"""
return not self._updates.empty()
def poll(self, timeout=None):
"""
Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
try:
return self._updates.get(timeout=timeout)
except Empty:
return None
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
if n is None:
self.stop_workers()
else:
self._workers = n
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""
Waits for all the worker threads to stop.
"""
# Put dummy ``None`` objects so that they don't need to timeout.
n = self._workers
self._workers = None
if n:
with self._updates_lock:
for _ in range(n):
self._updates.put(None)
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while self._workers is not None:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
if update and self.handler:
self.handler(update)
except StopIteration:
break
except:
# We don't want to crash a worker thread due to any reason
__log__.exception('Unhandled exception on worker %d', wid)
def handle_update(self, update):
if self.handler:
asyncio.ensure_future(self.handler(update), loop=self._loop)
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if self._workers is None:
return # No processing needs to be done if nobody's working
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
with self._updates_lock:
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
if hasattr(update, 'pts'):
self._state.pts = update.pts
if hasattr(update, 'pts'):
self._state.pts = update.pts
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
if isinstance(update, tl.UpdateShort):
self._updates.put(update.update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
for u in update.updates:
self._updates.put(u)
# TODO Handle "tl.UpdatesTooLong"
else:
self._updates.put(update)
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
if isinstance(update, tl.UpdateShort):
self.handle_update(update.update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
for u in update.updates:
self.handle_update(u)
# TODO Handle "tl.UpdatesTooLong"
else:
self.handle_update(update)

View File

@ -11,9 +11,9 @@ AUTO_GEN_NOTICE = \
AUTO_CASTS = {
'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))',
'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))',
'InputUser': 'utils.get_input_user(client.get_input_entity({}))',
'InputPeer': 'utils.get_input_peer(await client.get_input_entity({}))',
'InputChannel': 'utils.get_input_channel(await client.get_input_entity({}))',
'InputUser': 'utils.get_input_user(await client.get_input_entity({}))',
'InputMedia': 'utils.get_input_media({})',
'InputPhoto': 'utils.get_input_photo({})'
}
@ -289,7 +289,7 @@ class TLGenerator:
# Write the resolve(self, client, utils) method
if any(arg.type in AUTO_CASTS for arg in args):
builder.writeln('def resolve(self, client, utils):')
builder.writeln('async def resolve(self, client, utils):')
for arg in args:
ac = AUTO_CASTS.get(arg.type, None)
if ac: