Use async def everywhere

This commit is contained in:
Lonami Exo 2017-10-06 21:02:41 +02:00
parent 9716d1d543
commit 77c99db066
7 changed files with 206 additions and 208 deletions

View File

@ -1,9 +1,12 @@
# Python rough implementation of a C# TCP client # Python rough implementation of a C# TCP client
import asyncio
import errno import errno
import socket import socket
from datetime import timedelta from datetime import timedelta
from io import BytesIO, BufferedWriter from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop()
class TcpClient: class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
@ -30,7 +33,7 @@ class TcpClient:
self._socket.settimeout(self.timeout) self._socket.settimeout(self.timeout)
def connect(self, ip, port): async def connect(self, ip, port):
"""Connects to the specified IP and port number. """Connects to the specified IP and port number.
'timeout' must be given in seconds 'timeout' must be given in seconds
""" """
@ -44,7 +47,7 @@ class TcpClient:
while not self._socket: while not self._socket:
self._recreate_socket(mode) self._recreate_socket(mode)
self._socket.connect(address) await loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect break # Successful connection, stop retrying to connect
except OSError as e: except OSError as e:
# There are some errors that we know how to handle, and # There are some errors that we know how to handle, and
@ -72,15 +75,13 @@ class TcpClient:
finally: finally:
self._socket = None self._socket = None
def write(self, data): async def write(self, data):
"""Writes (sends) the specified bytes to the connected peer""" """Writes (sends) the specified bytes to the connected peer"""
if self._socket is None: if self._socket is None:
raise ConnectionResetError() raise ConnectionResetError()
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try: try:
self._socket.sendall(data) await loop.sock_sendall(self._socket, data)
except socket.timeout as e: except socket.timeout as e:
raise TimeoutError() from e raise TimeoutError() from e
except BrokenPipeError: except BrokenPipeError:
@ -91,14 +92,9 @@ class TcpClient:
else: else:
raise raise
def read(self, size): async def read(self, size):
"""Reads (receives) a whole block of 'size bytes """Reads (receives) a whole block of 'size bytes
from the connected peer. from the connected peer.
A timeout can be specified, which will cancel the operation if
no data has been read in the specified time. If data was read
and it's waiting for more, the timeout will NOT cancel the
operation. Set to None for no timeout
""" """
if self._socket is None: if self._socket is None:
raise ConnectionResetError() raise ConnectionResetError()
@ -108,7 +104,7 @@ class TcpClient:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
try: try:
partial = self._socket.recv(bytes_left) partial = await loop.sock_recv(self._socket, bytes_left)
except socket.timeout as e: except socket.timeout as e:
raise TimeoutError() from e raise TimeoutError() from e
except OSError as e: except OSError as e:

View File

@ -17,21 +17,21 @@ from ..tl.functions import (
) )
def do_authentication(connection, retries=5): async def do_authentication(connection, retries=5):
if not retries or retries < 0: if not retries or retries < 0:
retries = 1 retries = 1
last_error = None last_error = None
while retries: while retries:
try: try:
return _do_authentication(connection) return await _do_authentication(connection)
except (SecurityError, AssertionError, NotImplementedError) as e: except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e last_error = e
retries -= 1 retries -= 1
raise last_error raise last_error
def _do_authentication(connection): async def _do_authentication(connection):
"""Executes the authentication process with the Telegram servers. """Executes the authentication process with the Telegram servers.
If no error is raised, returns both the authorization key and the If no error is raised, returns both the authorization key and the
time offset. time offset.
@ -42,8 +42,8 @@ def _do_authentication(connection):
req_pq_request = ReqPqRequest( req_pq_request = ReqPqRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True) nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
) )
sender.send(req_pq_request.to_bytes()) await sender.send(req_pq_request.to_bytes())
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
req_pq_request.on_response(reader) req_pq_request.on_response(reader)
res_pq = req_pq_request.result res_pq = req_pq_request.result
@ -90,10 +90,10 @@ def _do_authentication(connection):
public_key_fingerprint=target_fingerprint, public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text encrypted_data=cipher_text
) )
sender.send(req_dh_params.to_bytes()) await sender.send(req_dh_params.to_bytes())
# Step 2 response: DH Exchange # Step 2 response: DH Exchange
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
req_dh_params.on_response(reader) req_dh_params.on_response(reader)
server_dh_params = req_dh_params.result server_dh_params = req_dh_params.result
@ -157,10 +157,10 @@ def _do_authentication(connection):
server_nonce=res_pq.server_nonce, server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted, encrypted_data=client_dh_encrypted,
) )
sender.send(set_client_dh.to_bytes()) await sender.send(set_client_dh.to_bytes())
# Step 3 response: Complete DH Exchange # Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader: with BinaryReader(await sender.receive()) as reader:
set_client_dh.on_response(reader) set_client_dh.on_response(reader)
dh_gen = set_client_dh.result dh_gen = set_client_dh.result

View File

@ -1,14 +1,13 @@
import errno
import os import os
import struct import struct
from datetime import timedelta from datetime import timedelta
from zlib import crc32
from enum import Enum from enum import Enum
from zlib import crc32
import errno
from ..crypto import AESModeCTR from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError from ..errors import InvalidChecksumError
from ..extensions import TcpClient
class ConnectionMode(Enum): class ConnectionMode(Enum):
@ -74,9 +73,9 @@ class Connection:
setattr(self, 'write', self._write_plain) setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain) setattr(self, 'read', self._read_plain)
def connect(self, ip, port): async def connect(self, ip, port):
try: try:
self.conn.connect(ip, port) await self.conn.connect(ip, port)
except OSError as e: except OSError as e:
if e.errno == errno.EISCONN: if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up return # Already connected, no need to re-set everything up
@ -85,16 +84,16 @@ class Connection:
self._send_counter = 0 self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED: if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef') await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE: elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee') await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED: elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation() await self._setup_obfuscation()
def get_timeout(self): def get_timeout(self):
return self.conn.timeout return self.conn.timeout
def _setup_obfuscation(self): async def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these # Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True: while True:
@ -119,7 +118,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random)) await self.conn.write(bytes(random))
def is_connected(self): def is_connected(self):
return self.conn.connected return self.conn.connected
@ -135,20 +134,23 @@ class Connection:
# region Receive message implementations # region Receive message implementations
def recv(self): async def recv(self):
"""Receives and unpacks a message""" """Receives and unpacks a message"""
# Default implementation is just an error # Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self): async def _recv_tcp_full(self):
packet_length_bytes = self.read(4) # TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_length_bytes = await self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little') packet_length = int.from_bytes(packet_length_bytes, 'little')
seq_bytes = self.read(4) seq_bytes = await self.read(4)
seq = int.from_bytes(seq_bytes, 'little') seq = int.from_bytes(seq_bytes, 'little')
body = self.read(packet_length - 12) body = await self.read(packet_length - 12)
checksum = int.from_bytes(self.read(4), 'little') checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body) valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum: if checksum != valid_checksum:
@ -156,72 +158,70 @@ class Connection:
return body return body
def _recv_intermediate(self): async def _recv_intermediate(self):
return self.read(int.from_bytes(self.read(4), 'little')) return await self.read(int.from_bytes(self.read(4), 'little'))
def _recv_abridged(self): async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little') length = int.from_bytes(self.read(1), 'little')
if length >= 127: if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little') length = int.from_bytes(self.read(3) + b'\0', 'little')
return self.read(length << 2) return await self.read(length << 2)
# endregion # endregion
# region Send message implementations # region Send message implementations
def send(self, message): async def send(self, message):
"""Encapsulates and sends the given message""" """Encapsulates and sends the given message"""
# Default implementation is just an error # Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message): async def _send_tcp_full(self, message):
# https://core.telegram.org/mtproto#tcp-transport # https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32) # total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12 length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data)) crc = struct.pack('<I', crc32(data))
self._send_counter += 1 self._send_counter += 1
self.write(data + crc) await self.write(data + crc)
def _send_intermediate(self, message): async def _send_intermediate(self, message):
self.write(struct.pack('<i', len(message)) + message) await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message): async def _send_abridged(self, message):
length = len(message) >> 2 length = len(message) >> 2
if length < 127: if length < 127:
length = struct.pack('B', length) length = struct.pack('B', length)
else: else:
length = b'\x7f' + int.to_bytes(length, 3, 'little') length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message) await self.write(length + message)
# endregion # endregion
# region Read implementations # region Read implementations
def read(self, length): async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length): async def _read_plain(self, length):
return self.conn.read(length) return await self.conn.read(length)
def _read_obfuscated(self, length): async def _read_obfuscated(self, length):
return self._aes_decrypt.encrypt( return await self._aes_decrypt.encrypt(self.conn.read(length))
self.conn.read(length)
)
# endregion # endregion
# region Write implementations # region Write implementations
def write(self, data): async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode)) raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data): async def _write_plain(self, data):
self.conn.write(data) await self.conn.write(data)
def _write_obfuscated(self, data): async def _write_obfuscated(self, data):
self.conn.write(self._aes_encrypt.encrypt(data)) await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion # endregion

View File

@ -16,23 +16,23 @@ class MtProtoPlainSender:
self._last_msg_id = 0 self._last_msg_id = 0
self._connection = connection self._connection = connection
def connect(self): async def connect(self):
self._connection.connect() await self._connection.connect()
def disconnect(self): def disconnect(self):
self._connection.close() self._connection.close()
def send(self, data): async def send(self, data):
"""Sends a plain packet (auth_key_id = 0) containing the """Sends a plain packet (auth_key_id = 0) containing the
given message body (data) given message body (data)
""" """
self._connection.send( await self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
) )
def receive(self): async def receive(self):
"""Receives a plain packet, returning the body of the response""" """Receives a plain packet, returning the body of the response"""
body = self._connection.recv() body = await self._connection.recv()
if body == b'l\xfe\xff\xff': # -404 little endian signed if body == b'l\xfe\xff\xff': # -404 little endian signed
# Broken authorization, must reset the auth key # Broken authorization, must reset the auth key
raise BrokenAuthKeyError() raise BrokenAuthKeyError()

View File

@ -41,9 +41,9 @@ class MtProtoSender:
# Requests (as msg_id: Message) sent waiting to be received # Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {} self._pending_receive = {}
def connect(self): async def connect(self):
"""Connects to the server""" """Connects to the server"""
self.connection.connect(self.session.server_address, self.session.port) await self.connection.connect(self.session.server_address, self.session.port)
def is_connected(self): def is_connected(self):
return self.connection.is_connected() return self.connection.is_connected()
@ -60,7 +60,7 @@ class MtProtoSender:
# region Send and receive # region Send and receive
def send(self, *requests): async def send(self, *requests):
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation.""" which needed confirmation."""
@ -80,13 +80,13 @@ class MtProtoSender:
else: else:
message = TLMessage(self.session, MessageContainer(messages)) message = TLMessage(self.session, MessageContainer(messages))
self._send_message(message) await self._send_message(message)
def _send_acknowledge(self, msg_id): async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id""" """Sends a message acknowledge for the given msg_id"""
self._send_message(TLMessage(self.session, MsgsAck([msg_id]))) await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
def receive(self, update_state): async def receive(self, update_state):
"""Receives a single message from the connected endpoint. """Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts This method returns nothing, and will only affect other parts
@ -97,7 +97,7 @@ class MtProtoSender:
update_state.process(TLObject). update_state.process(TLObject).
""" """
try: try:
body = self.connection.recv() body = await self.connection.recv()
except (BufferError, InvalidChecksumError): except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause... # TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear # "No more bytes left"; something wrong happened, clear
@ -111,13 +111,13 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_seq, reader, update_state) await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
# endregion # endregion
# region Low level processing # region Low level processing
def _send_message(self, message): async def _send_message(self, message):
"""Sends the given Message(TLObject) encrypted through the network""" """Sends the given Message(TLObject) encrypted through the network"""
plain_text = \ plain_text = \
@ -130,7 +130,7 @@ class MtProtoSender:
cipher_text = AES.encrypt_ige(plain_text, key, iv) cipher_text = AES.encrypt_ige(plain_text, key, iv)
result = key_id + msg_key + cipher_text result = key_id + msg_key + cipher_text
self.connection.send(result) await self.connection.send(result)
def _decode_msg(self, body): def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes""" """Decodes an received encrypted message body bytes"""
@ -163,7 +163,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, state): async def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message. """Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't Returns True if the message was handled correctly and doesn't
@ -178,22 +178,22 @@ class MtProtoSender:
# The following codes are "parsed manually" # The following codes are "parsed manually"
if code == 0xf35c6d01: # rpc_result, (response of an RPC call) if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return self._handle_rpc_result(msg_id, sequence, reader) return await self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong if code == 0x347773c5: # pong
return self._handle_pong(msg_id, sequence, reader) return await self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, state) return await self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, state) return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader) return await self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification if code == 0xa7eff811: # bad_msg_notification
return self._handle_bad_msg_notification(msg_id, sequence, reader) return await self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted # msgs_ack, it may handle the request we wanted
if code == 0x62d6b459: if code == 0x62d6b459:
@ -247,7 +247,7 @@ class MtProtoSender:
r.confirm_received.set() r.confirm_received.set()
self._pending_receive.clear() self._pending_receive.clear()
def _handle_pong(self, msg_id, sequence, reader): async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
received_msg_id = reader.read_long() received_msg_id = reader.read_long()
@ -259,7 +259,7 @@ class MtProtoSender:
return True return True
def _handle_container(self, msg_id, sequence, reader, state): async def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container') self._logger.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader): for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position() begin_position = reader.tell_position()
@ -267,7 +267,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
try: try:
if not self._process_msg(inner_msg_id, sequence, reader, state): if not await self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_len) reader.set_position(begin_position + inner_len)
except: except:
# If any error is raised, something went wrong; skip the packet # If any error is raised, something went wrong; skip the packet
@ -276,7 +276,7 @@ class MtProtoSender:
return True return True
def _handle_bad_server_salt(self, msg_id, sequence, reader): async def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt') self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
bad_msg_id = reader.read_long() bad_msg_id = reader.read_long()
@ -287,11 +287,11 @@ class MtProtoSender:
request = self._pop_request(bad_msg_id) request = self._pop_request(bad_msg_id)
if request: if request:
self.send(request) await self.send(request)
return True return True
def _handle_bad_msg_notification(self, msg_id, sequence, reader): async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification') self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
reader.read_long() # request_id reader.read_long() # request_id
@ -318,7 +318,7 @@ class MtProtoSender:
else: else:
raise error raise error
def _handle_rpc_result(self, msg_id, sequence, reader): async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result') self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
request_id = reader.read_long() request_id = reader.read_long()
@ -338,7 +338,7 @@ class MtProtoSender:
) )
# Acknowledge that we received the error # Acknowledge that we received the error
self._send_acknowledge(request_id) await self._send_acknowledge(request_id)
if request: if request:
request.rpc_error = error request.rpc_error = error
@ -366,9 +366,9 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.') self._logger.debug('Lost request will be skipped.')
return False return False
def _handle_gzip_packed(self, msg_id, sequence, reader, state): async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader: with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, state) return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion # endregion

View File

@ -137,7 +137,7 @@ class TelegramBareClient:
# region Connecting # region Connecting
def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False): async def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
"""Connects to the Telegram servers, executing authentication if """Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which not the same as authenticating the desired user itself, which
@ -158,13 +158,13 @@ class TelegramBareClient:
centers won't be invoked. centers won't be invoked.
""" """
try: try:
self._sender.connect() await self._sender.connect()
if not self.session.auth_key: if not self.session.auth_key:
# New key, we need to tell the server we're going to use # New key, we need to tell the server we're going to use
# the latest layer # the latest layer
try: try:
self.session.auth_key, self.session.time_offset = \ self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(self._sender.connection) await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError: except BrokenAuthKeyError:
return False return False
@ -176,21 +176,21 @@ class TelegramBareClient:
if init_connection: if init_connection:
if _exported_auth is not None: if _exported_auth is not None:
self._init_connection(ImportAuthorizationRequest( await self._init_connection(ImportAuthorizationRequest(
_exported_auth.id, _exported_auth.bytes _exported_auth.id, _exported_auth.bytes
)) ))
elif not _cdn: elif not _cdn:
TelegramBareClient._dc_options = \ TelegramBareClient._dc_options = \
self._init_connection(GetConfigRequest()).dc_options (await self._init_connection(GetConfigRequest())).dc_options
elif _exported_auth is not None: elif _exported_auth is not None:
self(ImportAuthorizationRequest( await self(ImportAuthorizationRequest(
_exported_auth.id, _exported_auth.bytes _exported_auth.id, _exported_auth.bytes
)) ))
if TelegramBareClient._dc_options is None and not _cdn: if TelegramBareClient._dc_options is None and not _cdn:
TelegramBareClient._dc_options = \ TelegramBareClient._dc_options = \
self(GetConfigRequest()).dc_options (await self(GetConfigRequest())).dc_options
# Connection was successful! Try syncing the update state # Connection was successful! Try syncing the update state
# UNLESS '_sync_updates' is False (we probably are in # UNLESS '_sync_updates' is False (we probably are in
@ -199,7 +199,7 @@ class TelegramBareClient:
self._user_connected = True self._user_connected = True
if _sync_updates and not _cdn: if _sync_updates and not _cdn:
try: try:
self.sync_updates() await self.sync_updates()
self._set_connected_and_authorized() self._set_connected_and_authorized()
except UnauthorizedError: except UnauthorizedError:
self._authorized = False self._authorized = False
@ -227,8 +227,8 @@ class TelegramBareClient:
def is_connected(self): def is_connected(self):
return self._sender.is_connected() return self._sender.is_connected()
def _init_connection(self, query=None): async def _init_connection(self, query=None):
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest( result = await self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
api_id=self.api_id, api_id=self.api_id,
device_model=self.session.device_model, device_model=self.session.device_model,
system_version=self.session.system_version, system_version=self.session.system_version,
@ -249,7 +249,7 @@ class TelegramBareClient:
# TODO Shall we clear the _exported_sessions, or may be reused? # TODO Shall we clear the _exported_sessions, or may be reused?
pass pass
def _reconnect(self, new_dc=None): async def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made """If 'new_dc' is not set, only a call to .connect() will be made
since it's assumed that the connection has been lost and the since it's assumed that the connection has been lost and the
library is reconnecting. library is reconnecting.
@ -260,7 +260,7 @@ class TelegramBareClient:
""" """
if new_dc is None: if new_dc is None:
# Assume we are disconnected due to some error, so connect again # Assume we are disconnected due to some error, so connect again
return self.connect() return await self.connect()
else: else:
self.disconnect() self.disconnect()
self.session.auth_key = None # Force creating new auth_key self.session.auth_key = None # Force creating new auth_key
@ -269,23 +269,24 @@ class TelegramBareClient:
self.session.server_address = ip self.session.server_address = ip
self.session.port = dc.port self.session.port = dc.port
self.session.save() self.session.save()
return self.connect() return await self.connect()
# endregion # endregion
# region Working with different connections/Data Centers # region Working with different connections/Data Centers
def _get_dc(self, dc_id, ipv6=False, cdn=False): async def _get_dc(self, dc_id, ipv6=False, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'""" """Gets the Data Center (DC) associated to 'dc_id'"""
if TelegramBareClient._dc_options is None: if TelegramBareClient._dc_options is None:
raise ConnectionError( raise ConnectionError(
'Cannot determine the required data center IP address. ' 'Cannot determine the required data center IP address. '
'Stabilise a successful initial connection first.') 'Stabilise a successful initial connection first.'
)
try: try:
if cdn: if cdn:
# Ensure we have the latest keys for the CDNs # Ensure we have the latest keys for the CDNs
for pk in self(GetCdnConfigRequest()).public_keys: for pk in await (self(GetCdnConfigRequest())).public_keys:
rsa.add_key(pk.public_key) rsa.add_key(pk.public_key)
return next( return next(
@ -297,10 +298,10 @@ class TelegramBareClient:
raise raise
# New configuration, perhaps a new CDN was added? # New configuration, perhaps a new CDN was added?
TelegramBareClient._dc_options = self(GetConfigRequest()).dc_options TelegramBareClient._dc_options = await (self(GetConfigRequest())).dc_options
return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn) return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn)
def _get_exported_client(self, dc_id): async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC. """Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time calling the method with a given dc_id, If it's the first time calling the method with a given dc_id,
@ -317,10 +318,10 @@ class TelegramBareClient:
# TODO Add a lock, don't allow two threads to create an auth key # TODO Add a lock, don't allow two threads to create an auth key
# (when calling .connect() if there wasn't a previous session). # (when calling .connect() if there wasn't a previous session).
# for the same data center. # for the same data center.
dc = self._get_dc(dc_id) dc = await self._get_dc(dc_id)
# Export the current authorization to the new DC. # Export the current authorization to the new DC.
export_auth = self(ExportAuthorizationRequest(dc_id)) export_auth = await self(ExportAuthorizationRequest(dc_id))
# Create a temporary session for this IP address, which needs # Create a temporary session for this IP address, which needs
# to be different because each auth_key is unique per DC. # to be different because each auth_key is unique per DC.
@ -337,15 +338,15 @@ class TelegramBareClient:
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout()
) )
client.connect(_exported_auth=export_auth, _sync_updates=False) await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth client._authorized = True # We exported the auth, so we got auth
return client return client
def _get_cdn_client(self, cdn_redirect): async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs""" """Similar to ._get_exported_client, but for CDNs"""
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = Session(self.session) session = Session(self.session)
session.server_address = dc.ip_address session.server_address = dc.ip_address
session.port = dc.port session.port = dc.port
@ -361,7 +362,7 @@ class TelegramBareClient:
# #
# This relies on the fact that TelegramBareClient._dc_options is # This relies on the fact that TelegramBareClient._dc_options is
# static and it won't be called from this DC (it would fail). # static and it won't be called from this DC (it would fail).
client.connect(_cdn=True) # Avoid invoking non-CDN specific methods await client.connect(_cdn=True) # Avoid invoking non-CDN methods
client._authorized = self._authorized client._authorized = self._authorized
return client return client
@ -369,7 +370,7 @@ class TelegramBareClient:
# region Invoking Telegram requests # region Invoking Telegram requests
def __call__(self, *requests, retries=5): async def __call__(self, *requests, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) a MTProtoRequest and returns (receives) its result.
The invoke will be retried up to 'retries' times before raising The invoke will be retried up to 'retries' times before raising
@ -384,7 +385,7 @@ class TelegramBareClient:
try: try:
for _ in range(retries): for _ in range(retries):
result = self._invoke(sender, *requests) result = await self._invoke(sender, *requests)
if result: if result:
return result return result
@ -396,16 +397,16 @@ class TelegramBareClient:
# Let people use client.invoke(SomeRequest()) instead client(...) # Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__ invoke = __call__
def _invoke(self, sender, *requests): async def _invoke(self, sender, *requests):
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
for x in requests: for x in requests:
x.confirm_received.clear() x.confirm_received.clear()
x.rpc_error = None x.rpc_error = None
sender.send(*requests) await sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests): while not all(x.confirm_received.is_set() for x in requests):
sender.receive(update_state=self.updates) await sender.receive(update_state=self.updates)
except TimeoutError: except TimeoutError:
pass # We will just retry pass # We will just retry
@ -420,9 +421,9 @@ class TelegramBareClient:
if sender != self._sender: if sender != self._sender:
# TODO Try reconnecting forever too? # TODO Try reconnecting forever too?
sender.connect() await sender.connect()
else: else:
while self._user_connected and not self._reconnect(): while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request sleep(0.1) # Retry forever until we can send the request
finally: finally:
@ -449,8 +450,8 @@ class TelegramBareClient:
'attempting to reconnect at DC {}'.format(e.new_dc) 'attempting to reconnect at DC {}'.format(e.new_dc)
) )
self._reconnect(new_dc=e.new_dc) await self._reconnect(new_dc=e.new_dc)
return self._invoke(sender, *requests) return await self._invoke(sender, *requests)
except ServerError as e: except ServerError as e:
# Telegram is having some issues, just retry # Telegram is having some issues, just retry
@ -474,7 +475,7 @@ class TelegramBareClient:
# region Uploading media # region Uploading media
def upload_file(self, async def upload_file(self,
file, file,
part_size_kb=None, part_size_kb=None,
file_name=None, file_name=None,
@ -537,7 +538,7 @@ class TelegramBareClient:
else: else:
request = SaveFilePartRequest(file_id, part_index, part) request = SaveFilePartRequest(file_id, part_index, part)
result = self(request) result = await self(request)
if result: if result:
if not is_large: if not is_large:
# No need to update the hash if it's a large file # No need to update the hash if it's a large file
@ -568,7 +569,7 @@ class TelegramBareClient:
# region Downloading media # region Downloading media
def download_file(self, async def download_file(self,
input_location, input_location,
file, file,
part_size_kb=None, part_size_kb=None,
@ -616,18 +617,20 @@ class TelegramBareClient:
if cdn_decrypter: if cdn_decrypter:
result = cdn_decrypter.get_file() result = cdn_decrypter.get_file()
else: else:
result = client(GetFileRequest( result = await client(GetFileRequest(
input_location, offset, part_size input_location, offset, part_size
)) ))
if isinstance(result, FileCdnRedirect): if isinstance(result, FileCdnRedirect):
cdn_decrypter, result = \ cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter( CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result), result client,
await self._get_cdn_client(result),
result
) )
except FileMigrateError as e: except FileMigrateError as e:
client = self._get_exported_client(e.new_dc) client = await self._get_exported_client(e.new_dc)
continue continue
offset_index += 1 offset_index += 1
@ -657,12 +660,12 @@ class TelegramBareClient:
# region Updates handling # region Updates handling
def sync_updates(self): async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be """Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True, called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates. otherwise it should be called manually after enabling updates.
""" """
self.updates.process(self(GetStateRequest())) self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler): def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,

View File

@ -99,15 +99,15 @@ class TelegramClient(TelegramBareClient):
# region Authorization requests # region Authorization requests
def send_code_request(self, phone): async def send_code_request(self, phone):
"""Sends a code request to the specified phone number""" """Sends a code request to the specified phone number"""
phone = EntityDatabase.parse_phone(phone) or self._phone phone = EntityDatabase.parse_phone(phone) or self._phone
result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone self._phone = phone
self._phone_code_hash = result.phone_code_hash self._phone_code_hash = result.phone_code_hash
return result return result
def sign_in(self, phone=None, code=None, async def sign_in(self, phone=None, code=None,
password=None, bot_token=None, phone_code_hash=None): password=None, bot_token=None, phone_code_hash=None):
"""Completes the sign in process with the phone number + code pair. """Completes the sign in process with the phone number + code pair.
@ -132,7 +132,7 @@ class TelegramClient(TelegramBareClient):
""" """
if phone and not code: if phone and not code:
return self.send_code_request(phone) return await self.send_code_request(phone)
elif code: elif code:
phone = EntityDatabase.parse_phone(phone) or self._phone phone = EntityDatabase.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash phone_code_hash = phone_code_hash or self._phone_code_hash
@ -147,18 +147,18 @@ class TelegramClient(TelegramBareClient):
if isinstance(code, int): if isinstance(code, int):
code = str(code) code = str(code)
result = self(SignInRequest(phone, phone_code_hash, code)) result = await self(SignInRequest(phone, phone_code_hash, code))
except (PhoneCodeEmptyError, PhoneCodeExpiredError, except (PhoneCodeEmptyError, PhoneCodeExpiredError,
PhoneCodeHashEmptyError, PhoneCodeInvalidError): PhoneCodeHashEmptyError, PhoneCodeInvalidError):
return None return None
elif password: elif password:
salt = self(GetPasswordRequest()).current_salt salt = await self(GetPasswordRequest()).current_salt
result = self(CheckPasswordRequest( result = await self(CheckPasswordRequest(
helpers.get_password_hash(password, salt) helpers.get_password_hash(password, salt)
)) ))
elif bot_token: elif bot_token:
result = self(ImportBotAuthorizationRequest( result = await self(ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token, flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash api_id=self.api_id, api_hash=self.api_hash
)) ))
@ -171,9 +171,9 @@ class TelegramClient(TelegramBareClient):
self._set_connected_and_authorized() self._set_connected_and_authorized()
return result.user return result.user
def sign_up(self, code, first_name, last_name=''): async def sign_up(self, code, first_name, last_name=''):
"""Signs up to Telegram. Make sure you sent a code request first!""" """Signs up to Telegram. Make sure you sent a code request first!"""
result = self(SignUpRequest( result = await self(SignUpRequest(
phone_number=self._phone, phone_number=self._phone,
phone_code_hash=self._phone_code_hash, phone_code_hash=self._phone_code_hash,
phone_code=code, phone_code=code,
@ -184,11 +184,11 @@ class TelegramClient(TelegramBareClient):
self._set_connected_and_authorized() self._set_connected_and_authorized()
return result.user return result.user
def log_out(self): async def log_out(self):
"""Logs out and deletes the current session. """Logs out and deletes the current session.
Returns True if everything went okay.""" Returns True if everything went okay."""
try: try:
self(LogOutRequest()) await self(LogOutRequest())
except RPCError: except RPCError:
return False return False
@ -197,11 +197,11 @@ class TelegramClient(TelegramBareClient):
self.session = None self.session = None
return True return True
def get_me(self): async def get_me(self):
"""Gets "me" (the self user) which is currently authenticated, """Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated).""" or None if the request fails (hence, not authenticated)."""
try: try:
return self(GetUsersRequest([InputUserSelf()]))[0] return await self(GetUsersRequest([InputUserSelf()]))[0]
except UnauthorizedError: except UnauthorizedError:
return None return None
@ -209,7 +209,7 @@ class TelegramClient(TelegramBareClient):
# region Dialogs ("chats") requests # region Dialogs ("chats") requests
def get_dialogs(self, async def get_dialogs(self,
limit=10, limit=10,
offset_date=None, offset_date=None,
offset_id=0, offset_id=0,
@ -232,7 +232,7 @@ class TelegramClient(TelegramBareClient):
entities = {} entities = {}
while len(dialogs) < limit: while len(dialogs) < limit:
need = limit - len(dialogs) need = limit - len(dialogs)
r = self(GetDialogsRequest( r = await self(GetDialogsRequest(
offset_date=offset_date, offset_date=offset_date,
offset_id=offset_id, offset_id=offset_id,
offset_peer=offset_peer, offset_peer=offset_peer,
@ -281,18 +281,18 @@ class TelegramClient(TelegramBareClient):
# region Message requests # region Message requests
def send_message(self, async def send_message(self,
entity, entity,
message, message,
reply_to=None, reply_to=None,
link_preview=True): link_preview=True):
"""Sends a message to the given entity (or input peer) """Sends a message to the given entity (or input peer)
and returns the sent message as a Telegram object. and returns the sent message as a Telegram object.
If 'reply_to' is set to either a message or a message ID, If 'reply_to' is set to either a message or a message ID,
the sent message will be replying to such message. the sent message will be replying to such message.
""" """
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
request = SendMessageRequest( request = SendMessageRequest(
peer=entity, peer=entity,
message=message, message=message,
@ -300,7 +300,7 @@ class TelegramClient(TelegramBareClient):
no_webpage=not link_preview, no_webpage=not link_preview,
reply_to_msg_id=self._get_reply_to(reply_to) reply_to_msg_id=self._get_reply_to(reply_to)
) )
result = self(request) result = await self(request)
if isinstance(result, UpdateShortSentMessage): if isinstance(result, UpdateShortSentMessage):
return Message( return Message(
id=result.id, id=result.id,
@ -328,7 +328,7 @@ class TelegramClient(TelegramBareClient):
return None # Should not happen return None # Should not happen
def delete_messages(self, entity, message_ids, revoke=True): async def delete_messages(self, entity, message_ids, revoke=True):
""" """
Deletes a message from a chat, optionally "for everyone" with argument Deletes a message from a chat, optionally "for everyone" with argument
`revoke` set to `True`. `revoke` set to `True`.
@ -352,16 +352,16 @@ class TelegramClient(TelegramBareClient):
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids] message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
if entity is None: if entity is None:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
entity = self.get_input_entity(entity) entity = await self.get_input_entity(entity)
if isinstance(entity, InputPeerChannel): if isinstance(entity, InputPeerChannel):
return self(channels.DeleteMessagesRequest(entity, message_ids)) return await self(channels.DeleteMessagesRequest(entity, message_ids))
else: else:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke)) return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
def get_message_history(self, async def get_message_history(self,
entity, entity,
limit=20, limit=20,
offset_date=None, offset_date=None,
@ -386,8 +386,8 @@ class TelegramClient(TelegramBareClient):
The entity may be a phone or an username at the expense of The entity may be a phone or an username at the expense of
some performance loss. some performance loss.
""" """
result = self(GetHistoryRequest( result = await self(GetHistoryRequest(
peer=self.get_input_entity(entity), peer=await self.get_input_entity(entity),
limit=limit, limit=limit,
offset_date=offset_date, offset_date=offset_date,
offset_id=offset_id, offset_id=offset_id,
@ -413,7 +413,7 @@ class TelegramClient(TelegramBareClient):
return total_messages, result.messages, entities return total_messages, result.messages, entities
def send_read_acknowledge(self, entity, messages=None, max_id=None): async def send_read_acknowledge(self, entity, messages=None, max_id=None):
"""Sends a "read acknowledge" (i.e., notifying the given peer that we've """Sends a "read acknowledge" (i.e., notifying the given peer that we've
read their messages, also known as the "double check"). read their messages, also known as the "double check").
@ -435,8 +435,8 @@ class TelegramClient(TelegramBareClient):
else: else:
max_id = messages.id max_id = messages.id
return self(ReadHistoryRequest( return await self(ReadHistoryRequest(
peer=self.get_input_entity(entity), peer=await self.get_input_entity(entity),
max_id=max_id max_id=max_id
)) ))
@ -460,10 +460,10 @@ class TelegramClient(TelegramBareClient):
# region Uploading files # region Uploading files
def send_file(self, entity, file, caption='', async def send_file(self, entity, file, caption='',
force_document=False, progress_callback=None, force_document=False, progress_callback=None,
reply_to=None, reply_to=None,
**kwargs): **kwargs):
"""Sends a file to the specified entity. """Sends a file to the specified entity.
The file may either be a path, a byte array, or a stream. The file may either be a path, a byte array, or a stream.
@ -500,7 +500,7 @@ class TelegramClient(TelegramBareClient):
if file_hash in self._upload_cache: if file_hash in self._upload_cache:
file_handle = self._upload_cache[file_hash] file_handle = self._upload_cache[file_hash]
else: else:
self._upload_cache[file_hash] = file_handle = self.upload_file( self._upload_cache[file_hash] = file_handle = await self.upload_file(
file, progress_callback=progress_callback file, progress_callback=progress_callback
) )
@ -538,19 +538,19 @@ class TelegramClient(TelegramBareClient):
# Once the media type is properly specified and the file uploaded, # Once the media type is properly specified and the file uploaded,
# send the media message to the desired entity. # send the media message to the desired entity.
self(SendMediaRequest( await self(SendMediaRequest(
peer=self.get_input_entity(entity), peer=await self.get_input_entity(entity),
media=media, media=media,
reply_to_msg_id=self._get_reply_to(reply_to) reply_to_msg_id=self._get_reply_to(reply_to)
)) ))
def send_voice_note(self, entity, file, caption='', upload_progress=None, async def send_voice_note(self, entity, file, caption='', upload_progress=None,
reply_to=None): reply_to=None):
"""Wrapper method around .send_file() with is_voice_note=()""" """Wrapper method around .send_file() with is_voice_note=()"""
return self.send_file(entity, file, caption, return await self.send_file(entity, file, caption,
upload_progress=upload_progress, upload_progress=upload_progress,
reply_to=reply_to, reply_to=reply_to,
is_voice_note=()) # empty tuple is enough is_voice_note=()) # empty tuple is enough
def clear_file_cache(self): def clear_file_cache(self):
"""Calls to .send_file() will cache the remote location of the """Calls to .send_file() will cache the remote location of the
@ -564,7 +564,7 @@ class TelegramClient(TelegramBareClient):
# region Downloading media requests # region Downloading media requests
def download_profile_photo(self, entity, file=None, download_big=True): async def download_profile_photo(self, entity, file=None, download_big=True):
"""Downloads the profile photo for an user or a chat (channels too). """Downloads the profile photo for an user or a chat (channels too).
Returns None if no photo was provided, or if it was Empty. Returns None if no photo was provided, or if it was Empty.
@ -590,12 +590,12 @@ class TelegramClient(TelegramBareClient):
# The hexadecimal numbers above are simply: # The hexadecimal numbers above are simply:
# hex(crc32(x.encode('ascii'))) for x in # hex(crc32(x.encode('ascii'))) for x in
# ('User', 'Chat', 'UserFull', 'ChatFull') # ('User', 'Chat', 'UserFull', 'ChatFull')
entity = self.get_entity(entity) entity = await self.get_entity(entity)
if not hasattr(entity, 'photo'): if not hasattr(entity, 'photo'):
# Special case: may be a ChatFull with photo:Photo # Special case: may be a ChatFull with photo:Photo
# This is different from a normal UserProfilePhoto and Chat # This is different from a normal UserProfilePhoto and Chat
if hasattr(entity, 'chat_photo'): if hasattr(entity, 'chat_photo'):
return self._download_photo( return await self._download_photo(
entity.chat_photo, file, entity.chat_photo, file,
date=None, progress_callback=None date=None, progress_callback=None
) )
@ -623,7 +623,7 @@ class TelegramClient(TelegramBareClient):
) )
# Download the media with the largest size input file location # Download the media with the largest size input file location
self.download_file( await self.download_file(
InputFileLocation( InputFileLocation(
volume_id=photo_location.volume_id, volume_id=photo_location.volume_id,
local_id=photo_location.local_id, local_id=photo_location.local_id,
@ -633,7 +633,7 @@ class TelegramClient(TelegramBareClient):
) )
return file return file
def download_media(self, message, file=None, progress_callback=None): async def download_media(self, message, file=None, progress_callback=None):
"""Downloads the media from a specified Message (it can also be """Downloads the media from a specified Message (it can also be
the message.media) into the desired file (a stream or str), the message.media) into the desired file (a stream or str),
optionally finding its extension automatically. optionally finding its extension automatically.
@ -659,19 +659,19 @@ class TelegramClient(TelegramBareClient):
media = message media = message
if isinstance(media, MessageMediaPhoto): if isinstance(media, MessageMediaPhoto):
return self._download_photo( return await self._download_photo(
media, file, date, progress_callback media, file, date, progress_callback
) )
elif isinstance(media, MessageMediaDocument): elif isinstance(media, MessageMediaDocument):
return self._download_document( return await self._download_document(
media, file, date, progress_callback media, file, date, progress_callback
) )
elif isinstance(media, MessageMediaContact): elif isinstance(media, MessageMediaContact):
return self._download_contact( return await self._download_contact(
media, file media, file
) )
def _download_photo(self, mm_photo, file, date, progress_callback): async def _download_photo(self, mm_photo, file, date, progress_callback):
"""Specialized version of .download_media() for photos""" """Specialized version of .download_media() for photos"""
# Determine the photo and its largest size # Determine the photo and its largest size
@ -683,7 +683,7 @@ class TelegramClient(TelegramBareClient):
file = self._get_proper_filename(file, 'photo', '.jpg', date=date) file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
# Download the media with the largest size input file location # Download the media with the largest size input file location
self.download_file( await self.download_file(
InputFileLocation( InputFileLocation(
volume_id=largest_size.volume_id, volume_id=largest_size.volume_id,
local_id=largest_size.local_id, local_id=largest_size.local_id,
@ -695,7 +695,7 @@ class TelegramClient(TelegramBareClient):
) )
return file return file
def _download_document(self, mm_doc, file, date, progress_callback): async def _download_document(self, mm_doc, file, date, progress_callback):
"""Specialized version of .download_media() for documents""" """Specialized version of .download_media() for documents"""
document = mm_doc.document document = mm_doc.document
file_size = document.size file_size = document.size
@ -715,7 +715,7 @@ class TelegramClient(TelegramBareClient):
date=date, possible_names=possible_names date=date, possible_names=possible_names
) )
self.download_file( await self.download_file(
InputDocumentFileLocation( InputDocumentFileLocation(
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
@ -826,7 +826,7 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier # region Small utilities to make users' life easier
def get_entity(self, entity): async def get_entity(self, entity):
"""Turns an entity into a valid Telegram user or chat. """Turns an entity into a valid Telegram user or chat.
If "entity" is a string which can be converted to an integer, If "entity" is a string which can be converted to an integer,
or if it starts with '+' it will be resolved as if it or if it starts with '+' it will be resolved as if it
@ -851,15 +851,14 @@ class TelegramClient(TelegramBareClient):
isinstance(entity, TLObject) and isinstance(entity, TLObject) and
# crc32(b'InputPeer') and crc32(b'Peer') # crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = self.get_input_entity(entity) ie = await self.get_input_entity(entity)
result = None result = None
if isinstance(ie, InputPeerUser): if isinstance(ie, InputPeerUser):
result = self(GetUsersRequest([ie])) result = await self(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat): elif isinstance(ie, InputPeerChat):
result = self(GetChatsRequest([ie.chat_id])) result = await self(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel): elif isinstance(ie, InputPeerChannel):
result = self(GetChannelsRequest([ie])) result = await self(GetChannelsRequest([ie]))
if result: if result:
self.session.process_entities(result) self.session.process_entities(result)
try: try:
@ -868,23 +867,23 @@ class TelegramClient(TelegramBareClient):
pass pass
if isinstance(entity, str): if isinstance(entity, str):
return self._get_entity_from_string(entity) return await self._get_entity_from_string(entity)
raise ValueError( raise ValueError(
'Cannot turn "{}" into any entity (user or chat)'.format(entity) 'Cannot turn "{}" into any entity (user or chat)'.format(entity)
) )
def _get_entity_from_string(self, string): async def _get_entity_from_string(self, string):
"""Gets an entity from the given string, which may be a phone or """Gets an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session. an username, and processes all the found entities on the session.
""" """
phone = EntityDatabase.parse_phone(string) phone = EntityDatabase.parse_phone(string)
if phone: if phone:
entity = phone entity = phone
self.session.process_entities(self(GetContactsRequest(0))) self.session.process_entities(await self(GetContactsRequest(0)))
else: else:
entity = string.strip('@').lower() entity = string.strip('@').lower()
self.session.process_entities(self(ResolveUsernameRequest(entity))) self.session.process_entities(await self(ResolveUsernameRequest(entity)))
try: try:
return self.session.entities[entity] return self.session.entities[entity]
@ -893,7 +892,7 @@ class TelegramClient(TelegramBareClient):
'Could not find user with username {}'.format(entity) 'Could not find user with username {}'.format(entity)
) )
def get_input_entity(self, peer): async def get_input_entity(self, peer):
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel. """Gets the input entity given its PeerUser, PeerChat, PeerChannel.
If no Peer class is used, peer is assumed to be the integer ID If no Peer class is used, peer is assumed to be the integer ID
of an User. of an User.
@ -910,7 +909,7 @@ class TelegramClient(TelegramBareClient):
pass pass
if isinstance(peer, str): if isinstance(peer, str):
return utils.get_input_peer(self._get_entity_from_string(peer)) return utils.get_input_peer(await self._get_entity_from_string(peer))
is_peer = False is_peer = False
if isinstance(peer, int): if isinstance(peer, int):
@ -932,7 +931,7 @@ class TelegramClient(TelegramBareClient):
if self.session.save_entities: if self.session.save_entities:
# Not found, look in the dialogs (this will save the users) # Not found, look in the dialogs (this will save the users)
self.get_dialogs(limit=None) await self.get_dialogs(limit=None)
try: try:
return self.session.entities.get_input_entity(peer) return self.session.entities.get_input_entity(peer)