Use async def everywhere

This commit is contained in:
Lonami Exo 2017-10-06 21:02:41 +02:00
parent 9716d1d543
commit 77c99db066
7 changed files with 206 additions and 208 deletions

View File

@ -1,9 +1,12 @@
# Python rough implementation of a C# TCP client
import asyncio
import errno
import socket
from datetime import timedelta
from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop()
class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
@ -30,7 +33,7 @@ class TcpClient:
self._socket.settimeout(self.timeout)
def connect(self, ip, port):
async def connect(self, ip, port):
"""Connects to the specified IP and port number.
'timeout' must be given in seconds
"""
@ -44,7 +47,7 @@ class TcpClient:
while not self._socket:
self._recreate_socket(mode)
self._socket.connect(address)
await loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect
except OSError as e:
# There are some errors that we know how to handle, and
@ -72,15 +75,13 @@ class TcpClient:
finally:
self._socket = None
def write(self, data):
async def write(self, data):
"""Writes (sends) the specified bytes to the connected peer"""
if self._socket is None:
raise ConnectionResetError()
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try:
self._socket.sendall(data)
await loop.sock_sendall(self._socket, data)
except socket.timeout as e:
raise TimeoutError() from e
except BrokenPipeError:
@ -91,14 +92,9 @@ class TcpClient:
else:
raise
def read(self, size):
async def read(self, size):
"""Reads (receives) a whole block of 'size bytes
from the connected peer.
A timeout can be specified, which will cancel the operation if
no data has been read in the specified time. If data was read
and it's waiting for more, the timeout will NOT cancel the
operation. Set to None for no timeout
"""
if self._socket is None:
raise ConnectionResetError()
@ -108,7 +104,7 @@ class TcpClient:
bytes_left = size
while bytes_left != 0:
try:
partial = self._socket.recv(bytes_left)
partial = await loop.sock_recv(self._socket, bytes_left)
except socket.timeout as e:
raise TimeoutError() from e
except OSError as e:

View File

@ -17,21 +17,21 @@ from ..tl.functions import (
)
def do_authentication(connection, retries=5):
async def do_authentication(connection, retries=5):
if not retries or retries < 0:
retries = 1
last_error = None
while retries:
try:
return _do_authentication(connection)
return await _do_authentication(connection)
except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e
retries -= 1
raise last_error
def _do_authentication(connection):
async def _do_authentication(connection):
"""Executes the authentication process with the Telegram servers.
If no error is raised, returns both the authorization key and the
time offset.
@ -42,8 +42,8 @@ def _do_authentication(connection):
req_pq_request = ReqPqRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
)
sender.send(req_pq_request.to_bytes())
with BinaryReader(sender.receive()) as reader:
await sender.send(req_pq_request.to_bytes())
with BinaryReader(await sender.receive()) as reader:
req_pq_request.on_response(reader)
res_pq = req_pq_request.result
@ -90,10 +90,10 @@ def _do_authentication(connection):
public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text
)
sender.send(req_dh_params.to_bytes())
await sender.send(req_dh_params.to_bytes())
# Step 2 response: DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
req_dh_params.on_response(reader)
server_dh_params = req_dh_params.result
@ -157,10 +157,10 @@ def _do_authentication(connection):
server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted,
)
sender.send(set_client_dh.to_bytes())
await sender.send(set_client_dh.to_bytes())
# Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
set_client_dh.on_response(reader)
dh_gen = set_client_dh.result

View File

@ -1,14 +1,13 @@
import errno
import os
import struct
from datetime import timedelta
from zlib import crc32
from enum import Enum
import errno
from zlib import crc32
from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError
from ..extensions import TcpClient
class ConnectionMode(Enum):
@ -74,9 +73,9 @@ class Connection:
setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain)
def connect(self, ip, port):
async def connect(self, ip, port):
try:
self.conn.connect(ip, port)
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
@ -85,16 +84,16 @@ class Connection:
self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef')
await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee')
await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation()
await self._setup_obfuscation()
def get_timeout(self):
return self.conn.timeout
def _setup_obfuscation(self):
async def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True:
@ -119,7 +118,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random))
await self.conn.write(bytes(random))
def is_connected(self):
return self.conn.connected
@ -135,20 +134,23 @@ class Connection:
# region Receive message implementations
def recv(self):
async def recv(self):
"""Receives and unpacks a message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self):
packet_length_bytes = self.read(4)
async def _recv_tcp_full(self):
# TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_length_bytes = await self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little')
seq_bytes = self.read(4)
seq_bytes = await self.read(4)
seq = int.from_bytes(seq_bytes, 'little')
body = self.read(packet_length - 12)
checksum = int.from_bytes(self.read(4), 'little')
body = await self.read(packet_length - 12)
checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum:
@ -156,72 +158,70 @@ class Connection:
return body
def _recv_intermediate(self):
return self.read(int.from_bytes(self.read(4), 'little'))
async def _recv_intermediate(self):
return await self.read(int.from_bytes(self.read(4), 'little'))
def _recv_abridged(self):
async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little')
if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little')
return self.read(length << 2)
return await self.read(length << 2)
# endregion
# region Send message implementations
def send(self, message):
async def send(self, message):
"""Encapsulates and sends the given message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message):
async def _send_tcp_full(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
await self.write(data + crc)
def _send_intermediate(self, message):
self.write(struct.pack('<i', len(message)) + message)
async def _send_intermediate(self, message):
await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message):
async def _send_abridged(self, message):
length = len(message) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
await self.write(length + message)
# endregion
# region Read implementations
def read(self, length):
async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length):
return self.conn.read(length)
async def _read_plain(self, length):
return await self.conn.read(length)
def _read_obfuscated(self, length):
return self._aes_decrypt.encrypt(
self.conn.read(length)
)
async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length))
# endregion
# region Write implementations
def write(self, data):
async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data):
self.conn.write(data)
async def _write_plain(self, data):
await self.conn.write(data)
def _write_obfuscated(self, data):
self.conn.write(self._aes_encrypt.encrypt(data))
async def _write_obfuscated(self, data):
await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion

View File

@ -16,23 +16,23 @@ class MtProtoPlainSender:
self._last_msg_id = 0
self._connection = connection
def connect(self):
self._connection.connect()
async def connect(self):
await self._connection.connect()
def disconnect(self):
self._connection.close()
def send(self, data):
async def send(self, data):
"""Sends a plain packet (auth_key_id = 0) containing the
given message body (data)
"""
self._connection.send(
await self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
)
def receive(self):
async def receive(self):
"""Receives a plain packet, returning the body of the response"""
body = self._connection.recv()
body = await self._connection.recv()
if body == b'l\xfe\xff\xff': # -404 little endian signed
# Broken authorization, must reset the auth key
raise BrokenAuthKeyError()

View File

@ -41,9 +41,9 @@ class MtProtoSender:
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
def connect(self):
async def connect(self):
"""Connects to the server"""
self.connection.connect(self.session.server_address, self.session.port)
await self.connection.connect(self.session.server_address, self.session.port)
def is_connected(self):
return self.connection.is_connected()
@ -60,7 +60,7 @@ class MtProtoSender:
# region Send and receive
def send(self, *requests):
async def send(self, *requests):
"""Sends the specified MTProtoRequest, previously sending any message
which needed confirmation."""
@ -80,13 +80,13 @@ class MtProtoSender:
else:
message = TLMessage(self.session, MessageContainer(messages))
self._send_message(message)
await self._send_message(message)
def _send_acknowledge(self, msg_id):
async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id"""
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
def receive(self, update_state):
async def receive(self, update_state):
"""Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts
@ -97,7 +97,7 @@ class MtProtoSender:
update_state.process(TLObject).
"""
try:
body = self.connection.recv()
body = await self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
@ -111,13 +111,13 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
# endregion
# region Low level processing
def _send_message(self, message):
async def _send_message(self, message):
"""Sends the given Message(TLObject) encrypted through the network"""
plain_text = \
@ -130,7 +130,7 @@ class MtProtoSender:
cipher_text = AES.encrypt_ige(plain_text, key, iv)
result = key_id + msg_key + cipher_text
self.connection.send(result)
await self.connection.send(result)
def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes"""
@ -163,7 +163,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, state):
async def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't
@ -178,22 +178,22 @@ class MtProtoSender:
# The following codes are "parsed manually"
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return self._handle_rpc_result(msg_id, sequence, reader)
return await self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong
return self._handle_pong(msg_id, sequence, reader)
return await self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, state)
return await self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, state)
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader)
return await self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification
return self._handle_bad_msg_notification(msg_id, sequence, reader)
return await self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted
if code == 0x62d6b459:
@ -247,7 +247,7 @@ class MtProtoSender:
r.confirm_received.set()
self._pending_receive.clear()
def _handle_pong(self, msg_id, sequence, reader):
async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
reader.read_int(signed=False) # code
received_msg_id = reader.read_long()
@ -259,7 +259,7 @@ class MtProtoSender:
return True
def _handle_container(self, msg_id, sequence, reader, state):
async def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position()
@ -267,7 +267,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not self._process_msg(inner_msg_id, sequence, reader, state):
if not await self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
@ -276,7 +276,7 @@ class MtProtoSender:
return True
def _handle_bad_server_salt(self, msg_id, sequence, reader):
async def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code
bad_msg_id = reader.read_long()
@ -287,11 +287,11 @@ class MtProtoSender:
request = self._pop_request(bad_msg_id)
if request:
self.send(request)
await self.send(request)
return True
def _handle_bad_msg_notification(self, msg_id, sequence, reader):
async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code
reader.read_long() # request_id
@ -318,7 +318,7 @@ class MtProtoSender:
else:
raise error
def _handle_rpc_result(self, msg_id, sequence, reader):
async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code
request_id = reader.read_long()
@ -338,7 +338,7 @@ class MtProtoSender:
)
# Acknowledge that we received the error
self._send_acknowledge(request_id)
await self._send_acknowledge(request_id)
if request:
request.rpc_error = error
@ -366,9 +366,9 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.')
return False
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, state)
return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion

View File

@ -137,7 +137,7 @@ class TelegramBareClient:
# region Connecting
def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
async def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
@ -158,13 +158,13 @@ class TelegramBareClient:
centers won't be invoked.
"""
try:
self._sender.connect()
await self._sender.connect()
if not self.session.auth_key:
# New key, we need to tell the server we're going to use
# the latest layer
try:
self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(self._sender.connection)
await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError:
return False
@ -176,21 +176,21 @@ class TelegramBareClient:
if init_connection:
if _exported_auth is not None:
self._init_connection(ImportAuthorizationRequest(
await self._init_connection(ImportAuthorizationRequest(
_exported_auth.id, _exported_auth.bytes
))
elif not _cdn:
TelegramBareClient._dc_options = \
self._init_connection(GetConfigRequest()).dc_options
(await self._init_connection(GetConfigRequest())).dc_options
elif _exported_auth is not None:
self(ImportAuthorizationRequest(
await self(ImportAuthorizationRequest(
_exported_auth.id, _exported_auth.bytes
))
if TelegramBareClient._dc_options is None and not _cdn:
TelegramBareClient._dc_options = \
self(GetConfigRequest()).dc_options
(await self(GetConfigRequest())).dc_options
# Connection was successful! Try syncing the update state
# UNLESS '_sync_updates' is False (we probably are in
@ -199,7 +199,7 @@ class TelegramBareClient:
self._user_connected = True
if _sync_updates and not _cdn:
try:
self.sync_updates()
await self.sync_updates()
self._set_connected_and_authorized()
except UnauthorizedError:
self._authorized = False
@ -227,8 +227,8 @@ class TelegramBareClient:
def is_connected(self):
return self._sender.is_connected()
def _init_connection(self, query=None):
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
async def _init_connection(self, query=None):
result = await self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
api_id=self.api_id,
device_model=self.session.device_model,
system_version=self.session.system_version,
@ -249,7 +249,7 @@ class TelegramBareClient:
# TODO Shall we clear the _exported_sessions, or may be reused?
pass
def _reconnect(self, new_dc=None):
async def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made
since it's assumed that the connection has been lost and the
library is reconnecting.
@ -260,7 +260,7 @@ class TelegramBareClient:
"""
if new_dc is None:
# Assume we are disconnected due to some error, so connect again
return self.connect()
return await self.connect()
else:
self.disconnect()
self.session.auth_key = None # Force creating new auth_key
@ -269,23 +269,24 @@ class TelegramBareClient:
self.session.server_address = ip
self.session.port = dc.port
self.session.save()
return self.connect()
return await self.connect()
# endregion
# region Working with different connections/Data Centers
def _get_dc(self, dc_id, ipv6=False, cdn=False):
async def _get_dc(self, dc_id, ipv6=False, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'"""
if TelegramBareClient._dc_options is None:
raise ConnectionError(
'Cannot determine the required data center IP address. '
'Stabilise a successful initial connection first.')
'Stabilise a successful initial connection first.'
)
try:
if cdn:
# Ensure we have the latest keys for the CDNs
for pk in self(GetCdnConfigRequest()).public_keys:
for pk in await (self(GetCdnConfigRequest())).public_keys:
rsa.add_key(pk.public_key)
return next(
@ -297,10 +298,10 @@ class TelegramBareClient:
raise
# New configuration, perhaps a new CDN was added?
TelegramBareClient._dc_options = self(GetConfigRequest()).dc_options
TelegramBareClient._dc_options = await (self(GetConfigRequest())).dc_options
return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn)
def _get_exported_client(self, dc_id):
async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time calling the method with a given dc_id,
@ -317,10 +318,10 @@ class TelegramBareClient:
# TODO Add a lock, don't allow two threads to create an auth key
# (when calling .connect() if there wasn't a previous session).
# for the same data center.
dc = self._get_dc(dc_id)
dc = await self._get_dc(dc_id)
# Export the current authorization to the new DC.
export_auth = self(ExportAuthorizationRequest(dc_id))
export_auth = await self(ExportAuthorizationRequest(dc_id))
# Create a temporary session for this IP address, which needs
# to be different because each auth_key is unique per DC.
@ -337,15 +338,15 @@ class TelegramBareClient:
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
)
client.connect(_exported_auth=export_auth, _sync_updates=False)
await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth
return client
def _get_cdn_client(self, cdn_redirect):
async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs"""
session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = Session(self.session)
session.server_address = dc.ip_address
session.port = dc.port
@ -361,7 +362,7 @@ class TelegramBareClient:
#
# This relies on the fact that TelegramBareClient._dc_options is
# static and it won't be called from this DC (it would fail).
client.connect(_cdn=True) # Avoid invoking non-CDN specific methods
await client.connect(_cdn=True) # Avoid invoking non-CDN methods
client._authorized = self._authorized
return client
@ -369,7 +370,7 @@ class TelegramBareClient:
# region Invoking Telegram requests
def __call__(self, *requests, retries=5):
async def __call__(self, *requests, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
The invoke will be retried up to 'retries' times before raising
@ -384,7 +385,7 @@ class TelegramBareClient:
try:
for _ in range(retries):
result = self._invoke(sender, *requests)
result = await self._invoke(sender, *requests)
if result:
return result
@ -396,16 +397,16 @@ class TelegramBareClient:
# Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__
def _invoke(self, sender, *requests):
async def _invoke(self, sender, *requests):
try:
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.confirm_received.clear()
x.rpc_error = None
sender.send(*requests)
await sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests):
sender.receive(update_state=self.updates)
await sender.receive(update_state=self.updates)
except TimeoutError:
pass # We will just retry
@ -420,9 +421,9 @@ class TelegramBareClient:
if sender != self._sender:
# TODO Try reconnecting forever too?
sender.connect()
await sender.connect()
else:
while self._user_connected and not self._reconnect():
while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request
finally:
@ -449,8 +450,8 @@ class TelegramBareClient:
'attempting to reconnect at DC {}'.format(e.new_dc)
)
self._reconnect(new_dc=e.new_dc)
return self._invoke(sender, *requests)
await self._reconnect(new_dc=e.new_dc)
return await self._invoke(sender, *requests)
except ServerError as e:
# Telegram is having some issues, just retry
@ -474,7 +475,7 @@ class TelegramBareClient:
# region Uploading media
def upload_file(self,
async def upload_file(self,
file,
part_size_kb=None,
file_name=None,
@ -537,7 +538,7 @@ class TelegramBareClient:
else:
request = SaveFilePartRequest(file_id, part_index, part)
result = self(request)
result = await self(request)
if result:
if not is_large:
# No need to update the hash if it's a large file
@ -568,7 +569,7 @@ class TelegramBareClient:
# region Downloading media
def download_file(self,
async def download_file(self,
input_location,
file,
part_size_kb=None,
@ -616,18 +617,20 @@ class TelegramBareClient:
if cdn_decrypter:
result = cdn_decrypter.get_file()
else:
result = client(GetFileRequest(
result = await client(GetFileRequest(
input_location, offset, part_size
))
if isinstance(result, FileCdnRedirect):
cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result), result
client,
await self._get_cdn_client(result),
result
)
except FileMigrateError as e:
client = self._get_exported_client(e.new_dc)
client = await self._get_exported_client(e.new_dc)
continue
offset_index += 1
@ -657,12 +660,12 @@ class TelegramBareClient:
# region Updates handling
def sync_updates(self):
async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates.
"""
self.updates.process(self(GetStateRequest()))
self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject,

View File

@ -99,15 +99,15 @@ class TelegramClient(TelegramBareClient):
# region Authorization requests
def send_code_request(self, phone):
async def send_code_request(self, phone):
"""Sends a code request to the specified phone number"""
phone = EntityDatabase.parse_phone(phone) or self._phone
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone
self._phone_code_hash = result.phone_code_hash
return result
def sign_in(self, phone=None, code=None,
async def sign_in(self, phone=None, code=None,
password=None, bot_token=None, phone_code_hash=None):
"""Completes the sign in process with the phone number + code pair.
@ -132,7 +132,7 @@ class TelegramClient(TelegramBareClient):
"""
if phone and not code:
return self.send_code_request(phone)
return await self.send_code_request(phone)
elif code:
phone = EntityDatabase.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash
@ -147,18 +147,18 @@ class TelegramClient(TelegramBareClient):
if isinstance(code, int):
code = str(code)
result = self(SignInRequest(phone, phone_code_hash, code))
result = await self(SignInRequest(phone, phone_code_hash, code))
except (PhoneCodeEmptyError, PhoneCodeExpiredError,
PhoneCodeHashEmptyError, PhoneCodeInvalidError):
return None
elif password:
salt = self(GetPasswordRequest()).current_salt
result = self(CheckPasswordRequest(
salt = await self(GetPasswordRequest()).current_salt
result = await self(CheckPasswordRequest(
helpers.get_password_hash(password, salt)
))
elif bot_token:
result = self(ImportBotAuthorizationRequest(
result = await self(ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash
))
@ -171,9 +171,9 @@ class TelegramClient(TelegramBareClient):
self._set_connected_and_authorized()
return result.user
def sign_up(self, code, first_name, last_name=''):
async def sign_up(self, code, first_name, last_name=''):
"""Signs up to Telegram. Make sure you sent a code request first!"""
result = self(SignUpRequest(
result = await self(SignUpRequest(
phone_number=self._phone,
phone_code_hash=self._phone_code_hash,
phone_code=code,
@ -184,11 +184,11 @@ class TelegramClient(TelegramBareClient):
self._set_connected_and_authorized()
return result.user
def log_out(self):
async def log_out(self):
"""Logs out and deletes the current session.
Returns True if everything went okay."""
try:
self(LogOutRequest())
await self(LogOutRequest())
except RPCError:
return False
@ -197,11 +197,11 @@ class TelegramClient(TelegramBareClient):
self.session = None
return True
def get_me(self):
async def get_me(self):
"""Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated)."""
try:
return self(GetUsersRequest([InputUserSelf()]))[0]
return await self(GetUsersRequest([InputUserSelf()]))[0]
except UnauthorizedError:
return None
@ -209,7 +209,7 @@ class TelegramClient(TelegramBareClient):
# region Dialogs ("chats") requests
def get_dialogs(self,
async def get_dialogs(self,
limit=10,
offset_date=None,
offset_id=0,
@ -232,7 +232,7 @@ class TelegramClient(TelegramBareClient):
entities = {}
while len(dialogs) < limit:
need = limit - len(dialogs)
r = self(GetDialogsRequest(
r = await self(GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
@ -281,18 +281,18 @@ class TelegramClient(TelegramBareClient):
# region Message requests
def send_message(self,
entity,
message,
reply_to=None,
link_preview=True):
async def send_message(self,
entity,
message,
reply_to=None,
link_preview=True):
"""Sends a message to the given entity (or input peer)
and returns the sent message as a Telegram object.
If 'reply_to' is set to either a message or a message ID,
the sent message will be replying to such message.
"""
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
request = SendMessageRequest(
peer=entity,
message=message,
@ -300,7 +300,7 @@ class TelegramClient(TelegramBareClient):
no_webpage=not link_preview,
reply_to_msg_id=self._get_reply_to(reply_to)
)
result = self(request)
result = await self(request)
if isinstance(result, UpdateShortSentMessage):
return Message(
id=result.id,
@ -328,7 +328,7 @@ class TelegramClient(TelegramBareClient):
return None # Should not happen
def delete_messages(self, entity, message_ids, revoke=True):
async def delete_messages(self, entity, message_ids, revoke=True):
"""
Deletes a message from a chat, optionally "for everyone" with argument
`revoke` set to `True`.
@ -352,16 +352,16 @@ class TelegramClient(TelegramBareClient):
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
if entity is None:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
entity = self.get_input_entity(entity)
entity = await self.get_input_entity(entity)
if isinstance(entity, InputPeerChannel):
return self(channels.DeleteMessagesRequest(entity, message_ids))
return await self(channels.DeleteMessagesRequest(entity, message_ids))
else:
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
def get_message_history(self,
async def get_message_history(self,
entity,
limit=20,
offset_date=None,
@ -386,8 +386,8 @@ class TelegramClient(TelegramBareClient):
The entity may be a phone or an username at the expense of
some performance loss.
"""
result = self(GetHistoryRequest(
peer=self.get_input_entity(entity),
result = await self(GetHistoryRequest(
peer=await self.get_input_entity(entity),
limit=limit,
offset_date=offset_date,
offset_id=offset_id,
@ -413,7 +413,7 @@ class TelegramClient(TelegramBareClient):
return total_messages, result.messages, entities
def send_read_acknowledge(self, entity, messages=None, max_id=None):
async def send_read_acknowledge(self, entity, messages=None, max_id=None):
"""Sends a "read acknowledge" (i.e., notifying the given peer that we've
read their messages, also known as the "double check").
@ -435,8 +435,8 @@ class TelegramClient(TelegramBareClient):
else:
max_id = messages.id
return self(ReadHistoryRequest(
peer=self.get_input_entity(entity),
return await self(ReadHistoryRequest(
peer=await self.get_input_entity(entity),
max_id=max_id
))
@ -460,10 +460,10 @@ class TelegramClient(TelegramBareClient):
# region Uploading files
def send_file(self, entity, file, caption='',
force_document=False, progress_callback=None,
reply_to=None,
**kwargs):
async def send_file(self, entity, file, caption='',
force_document=False, progress_callback=None,
reply_to=None,
**kwargs):
"""Sends a file to the specified entity.
The file may either be a path, a byte array, or a stream.
@ -500,7 +500,7 @@ class TelegramClient(TelegramBareClient):
if file_hash in self._upload_cache:
file_handle = self._upload_cache[file_hash]
else:
self._upload_cache[file_hash] = file_handle = self.upload_file(
self._upload_cache[file_hash] = file_handle = await self.upload_file(
file, progress_callback=progress_callback
)
@ -538,19 +538,19 @@ class TelegramClient(TelegramBareClient):
# Once the media type is properly specified and the file uploaded,
# send the media message to the desired entity.
self(SendMediaRequest(
peer=self.get_input_entity(entity),
await self(SendMediaRequest(
peer=await self.get_input_entity(entity),
media=media,
reply_to_msg_id=self._get_reply_to(reply_to)
))
def send_voice_note(self, entity, file, caption='', upload_progress=None,
async def send_voice_note(self, entity, file, caption='', upload_progress=None,
reply_to=None):
"""Wrapper method around .send_file() with is_voice_note=()"""
return self.send_file(entity, file, caption,
upload_progress=upload_progress,
reply_to=reply_to,
is_voice_note=()) # empty tuple is enough
return await self.send_file(entity, file, caption,
upload_progress=upload_progress,
reply_to=reply_to,
is_voice_note=()) # empty tuple is enough
def clear_file_cache(self):
"""Calls to .send_file() will cache the remote location of the
@ -564,7 +564,7 @@ class TelegramClient(TelegramBareClient):
# region Downloading media requests
def download_profile_photo(self, entity, file=None, download_big=True):
async def download_profile_photo(self, entity, file=None, download_big=True):
"""Downloads the profile photo for an user or a chat (channels too).
Returns None if no photo was provided, or if it was Empty.
@ -590,12 +590,12 @@ class TelegramClient(TelegramBareClient):
# The hexadecimal numbers above are simply:
# hex(crc32(x.encode('ascii'))) for x in
# ('User', 'Chat', 'UserFull', 'ChatFull')
entity = self.get_entity(entity)
entity = await self.get_entity(entity)
if not hasattr(entity, 'photo'):
# Special case: may be a ChatFull with photo:Photo
# This is different from a normal UserProfilePhoto and Chat
if hasattr(entity, 'chat_photo'):
return self._download_photo(
return await self._download_photo(
entity.chat_photo, file,
date=None, progress_callback=None
)
@ -623,7 +623,7 @@ class TelegramClient(TelegramBareClient):
)
# Download the media with the largest size input file location
self.download_file(
await self.download_file(
InputFileLocation(
volume_id=photo_location.volume_id,
local_id=photo_location.local_id,
@ -633,7 +633,7 @@ class TelegramClient(TelegramBareClient):
)
return file
def download_media(self, message, file=None, progress_callback=None):
async def download_media(self, message, file=None, progress_callback=None):
"""Downloads the media from a specified Message (it can also be
the message.media) into the desired file (a stream or str),
optionally finding its extension automatically.
@ -659,19 +659,19 @@ class TelegramClient(TelegramBareClient):
media = message
if isinstance(media, MessageMediaPhoto):
return self._download_photo(
return await self._download_photo(
media, file, date, progress_callback
)
elif isinstance(media, MessageMediaDocument):
return self._download_document(
return await self._download_document(
media, file, date, progress_callback
)
elif isinstance(media, MessageMediaContact):
return self._download_contact(
return await self._download_contact(
media, file
)
def _download_photo(self, mm_photo, file, date, progress_callback):
async def _download_photo(self, mm_photo, file, date, progress_callback):
"""Specialized version of .download_media() for photos"""
# Determine the photo and its largest size
@ -683,7 +683,7 @@ class TelegramClient(TelegramBareClient):
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
# Download the media with the largest size input file location
self.download_file(
await self.download_file(
InputFileLocation(
volume_id=largest_size.volume_id,
local_id=largest_size.local_id,
@ -695,7 +695,7 @@ class TelegramClient(TelegramBareClient):
)
return file
def _download_document(self, mm_doc, file, date, progress_callback):
async def _download_document(self, mm_doc, file, date, progress_callback):
"""Specialized version of .download_media() for documents"""
document = mm_doc.document
file_size = document.size
@ -715,7 +715,7 @@ class TelegramClient(TelegramBareClient):
date=date, possible_names=possible_names
)
self.download_file(
await self.download_file(
InputDocumentFileLocation(
id=document.id,
access_hash=document.access_hash,
@ -826,7 +826,7 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier
def get_entity(self, entity):
async def get_entity(self, entity):
"""Turns an entity into a valid Telegram user or chat.
If "entity" is a string which can be converted to an integer,
or if it starts with '+' it will be resolved as if it
@ -851,15 +851,14 @@ class TelegramClient(TelegramBareClient):
isinstance(entity, TLObject) and
# crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = self.get_input_entity(entity)
ie = await self.get_input_entity(entity)
result = None
if isinstance(ie, InputPeerUser):
result = self(GetUsersRequest([ie]))
result = await self(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat):
result = self(GetChatsRequest([ie.chat_id]))
result = await self(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel):
result = self(GetChannelsRequest([ie]))
result = await self(GetChannelsRequest([ie]))
if result:
self.session.process_entities(result)
try:
@ -868,23 +867,23 @@ class TelegramClient(TelegramBareClient):
pass
if isinstance(entity, str):
return self._get_entity_from_string(entity)
return await self._get_entity_from_string(entity)
raise ValueError(
'Cannot turn "{}" into any entity (user or chat)'.format(entity)
)
def _get_entity_from_string(self, string):
async def _get_entity_from_string(self, string):
"""Gets an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session.
"""
phone = EntityDatabase.parse_phone(string)
if phone:
entity = phone
self.session.process_entities(self(GetContactsRequest(0)))
self.session.process_entities(await self(GetContactsRequest(0)))
else:
entity = string.strip('@').lower()
self.session.process_entities(self(ResolveUsernameRequest(entity)))
self.session.process_entities(await self(ResolveUsernameRequest(entity)))
try:
return self.session.entities[entity]
@ -893,7 +892,7 @@ class TelegramClient(TelegramBareClient):
'Could not find user with username {}'.format(entity)
)
def get_input_entity(self, peer):
async def get_input_entity(self, peer):
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel.
If no Peer class is used, peer is assumed to be the integer ID
of an User.
@ -910,7 +909,7 @@ class TelegramClient(TelegramBareClient):
pass
if isinstance(peer, str):
return utils.get_input_peer(self._get_entity_from_string(peer))
return utils.get_input_peer(await self._get_entity_from_string(peer))
is_peer = False
if isinstance(peer, int):
@ -932,7 +931,7 @@ class TelegramClient(TelegramBareClient):
if self.session.save_entities:
# Not found, look in the dialogs (this will save the users)
self.get_dialogs(limit=None)
await self.get_dialogs(limit=None)
try:
return self.session.entities.get_input_entity(peer)