mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-10 06:05:47 +03:00
Use async def everywhere
This commit is contained in:
parent
9716d1d543
commit
77c99db066
|
@ -1,9 +1,12 @@
|
|||
# Python rough implementation of a C# TCP client
|
||||
import asyncio
|
||||
import errno
|
||||
import socket
|
||||
from datetime import timedelta
|
||||
from io import BytesIO, BufferedWriter
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class TcpClient:
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
|
@ -30,7 +33,7 @@ class TcpClient:
|
|||
|
||||
self._socket.settimeout(self.timeout)
|
||||
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
"""Connects to the specified IP and port number.
|
||||
'timeout' must be given in seconds
|
||||
"""
|
||||
|
@ -44,7 +47,7 @@ class TcpClient:
|
|||
while not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
|
||||
self._socket.connect(address)
|
||||
await loop.sock_connect(self._socket, address)
|
||||
break # Successful connection, stop retrying to connect
|
||||
except OSError as e:
|
||||
# There are some errors that we know how to handle, and
|
||||
|
@ -72,15 +75,13 @@ class TcpClient:
|
|||
finally:
|
||||
self._socket = None
|
||||
|
||||
def write(self, data):
|
||||
async def write(self, data):
|
||||
"""Writes (sends) the specified bytes to the connected peer"""
|
||||
if self._socket is None:
|
||||
raise ConnectionResetError()
|
||||
|
||||
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
|
||||
# The socket timeout is now the maximum total duration to send all data.
|
||||
try:
|
||||
self._socket.sendall(data)
|
||||
await loop.sock_sendall(self._socket, data)
|
||||
except socket.timeout as e:
|
||||
raise TimeoutError() from e
|
||||
except BrokenPipeError:
|
||||
|
@ -91,14 +92,9 @@ class TcpClient:
|
|||
else:
|
||||
raise
|
||||
|
||||
def read(self, size):
|
||||
async def read(self, size):
|
||||
"""Reads (receives) a whole block of 'size bytes
|
||||
from the connected peer.
|
||||
|
||||
A timeout can be specified, which will cancel the operation if
|
||||
no data has been read in the specified time. If data was read
|
||||
and it's waiting for more, the timeout will NOT cancel the
|
||||
operation. Set to None for no timeout
|
||||
"""
|
||||
if self._socket is None:
|
||||
raise ConnectionResetError()
|
||||
|
@ -108,7 +104,7 @@ class TcpClient:
|
|||
bytes_left = size
|
||||
while bytes_left != 0:
|
||||
try:
|
||||
partial = self._socket.recv(bytes_left)
|
||||
partial = await loop.sock_recv(self._socket, bytes_left)
|
||||
except socket.timeout as e:
|
||||
raise TimeoutError() from e
|
||||
except OSError as e:
|
||||
|
|
|
@ -17,21 +17,21 @@ from ..tl.functions import (
|
|||
)
|
||||
|
||||
|
||||
def do_authentication(connection, retries=5):
|
||||
async def do_authentication(connection, retries=5):
|
||||
if not retries or retries < 0:
|
||||
retries = 1
|
||||
|
||||
last_error = None
|
||||
while retries:
|
||||
try:
|
||||
return _do_authentication(connection)
|
||||
return await _do_authentication(connection)
|
||||
except (SecurityError, AssertionError, NotImplementedError) as e:
|
||||
last_error = e
|
||||
retries -= 1
|
||||
raise last_error
|
||||
|
||||
|
||||
def _do_authentication(connection):
|
||||
async def _do_authentication(connection):
|
||||
"""Executes the authentication process with the Telegram servers.
|
||||
If no error is raised, returns both the authorization key and the
|
||||
time offset.
|
||||
|
@ -42,8 +42,8 @@ def _do_authentication(connection):
|
|||
req_pq_request = ReqPqRequest(
|
||||
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||
)
|
||||
sender.send(req_pq_request.to_bytes())
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
await sender.send(req_pq_request.to_bytes())
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
req_pq_request.on_response(reader)
|
||||
|
||||
res_pq = req_pq_request.result
|
||||
|
@ -90,10 +90,10 @@ def _do_authentication(connection):
|
|||
public_key_fingerprint=target_fingerprint,
|
||||
encrypted_data=cipher_text
|
||||
)
|
||||
sender.send(req_dh_params.to_bytes())
|
||||
await sender.send(req_dh_params.to_bytes())
|
||||
|
||||
# Step 2 response: DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
req_dh_params.on_response(reader)
|
||||
|
||||
server_dh_params = req_dh_params.result
|
||||
|
@ -157,10 +157,10 @@ def _do_authentication(connection):
|
|||
server_nonce=res_pq.server_nonce,
|
||||
encrypted_data=client_dh_encrypted,
|
||||
)
|
||||
sender.send(set_client_dh.to_bytes())
|
||||
await sender.send(set_client_dh.to_bytes())
|
||||
|
||||
# Step 3 response: Complete DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
set_client_dh.on_response(reader)
|
||||
|
||||
dh_gen = set_client_dh.result
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
import errno
|
||||
import os
|
||||
import struct
|
||||
from datetime import timedelta
|
||||
from zlib import crc32
|
||||
from enum import Enum
|
||||
|
||||
import errno
|
||||
from zlib import crc32
|
||||
|
||||
from ..crypto import AESModeCTR
|
||||
from ..extensions import TcpClient
|
||||
from ..errors import InvalidChecksumError
|
||||
from ..extensions import TcpClient
|
||||
|
||||
|
||||
class ConnectionMode(Enum):
|
||||
|
@ -74,9 +73,9 @@ class Connection:
|
|||
setattr(self, 'write', self._write_plain)
|
||||
setattr(self, 'read', self._read_plain)
|
||||
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
try:
|
||||
self.conn.connect(ip, port)
|
||||
await self.conn.connect(ip, port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
|
@ -85,16 +84,16 @@ class Connection:
|
|||
|
||||
self._send_counter = 0
|
||||
if self._mode == ConnectionMode.TCP_ABRIDGED:
|
||||
self.conn.write(b'\xef')
|
||||
await self.conn.write(b'\xef')
|
||||
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
|
||||
self.conn.write(b'\xee\xee\xee\xee')
|
||||
await self.conn.write(b'\xee\xee\xee\xee')
|
||||
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
|
||||
self._setup_obfuscation()
|
||||
await self._setup_obfuscation()
|
||||
|
||||
def get_timeout(self):
|
||||
return self.conn.timeout
|
||||
|
||||
def _setup_obfuscation(self):
|
||||
async def _setup_obfuscation(self):
|
||||
# Obfuscated messages secrets cannot start with any of these
|
||||
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
|
||||
while True:
|
||||
|
@ -119,7 +118,7 @@ class Connection:
|
|||
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
|
||||
|
||||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||
self.conn.write(bytes(random))
|
||||
await self.conn.write(bytes(random))
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.connected
|
||||
|
@ -135,20 +134,23 @@ class Connection:
|
|||
|
||||
# region Receive message implementations
|
||||
|
||||
def recv(self):
|
||||
async def recv(self):
|
||||
"""Receives and unpacks a message"""
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _recv_tcp_full(self):
|
||||
packet_length_bytes = self.read(4)
|
||||
async def _recv_tcp_full(self):
|
||||
# TODO We don't want another call to this method that could
|
||||
# potentially await on another self.read(n). Is this guaranteed
|
||||
# by asyncio?
|
||||
packet_length_bytes = await self.read(4)
|
||||
packet_length = int.from_bytes(packet_length_bytes, 'little')
|
||||
|
||||
seq_bytes = self.read(4)
|
||||
seq_bytes = await self.read(4)
|
||||
seq = int.from_bytes(seq_bytes, 'little')
|
||||
|
||||
body = self.read(packet_length - 12)
|
||||
checksum = int.from_bytes(self.read(4), 'little')
|
||||
body = await self.read(packet_length - 12)
|
||||
checksum = int.from_bytes(await self.read(4), 'little')
|
||||
|
||||
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
|
||||
if checksum != valid_checksum:
|
||||
|
@ -156,72 +158,70 @@ class Connection:
|
|||
|
||||
return body
|
||||
|
||||
def _recv_intermediate(self):
|
||||
return self.read(int.from_bytes(self.read(4), 'little'))
|
||||
async def _recv_intermediate(self):
|
||||
return await self.read(int.from_bytes(self.read(4), 'little'))
|
||||
|
||||
def _recv_abridged(self):
|
||||
async def _recv_abridged(self):
|
||||
length = int.from_bytes(self.read(1), 'little')
|
||||
if length >= 127:
|
||||
length = int.from_bytes(self.read(3) + b'\0', 'little')
|
||||
|
||||
return self.read(length << 2)
|
||||
return await self.read(length << 2)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Send message implementations
|
||||
|
||||
def send(self, message):
|
||||
async def send(self, message):
|
||||
"""Encapsulates and sends the given message"""
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _send_tcp_full(self, message):
|
||||
async def _send_tcp_full(self, message):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(message) + 12
|
||||
data = struct.pack('<ii', length, self._send_counter) + message
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
self.write(data + crc)
|
||||
await self.write(data + crc)
|
||||
|
||||
def _send_intermediate(self, message):
|
||||
self.write(struct.pack('<i', len(message)) + message)
|
||||
async def _send_intermediate(self, message):
|
||||
await self.write(struct.pack('<i', len(message)) + message)
|
||||
|
||||
def _send_abridged(self, message):
|
||||
async def _send_abridged(self, message):
|
||||
length = len(message) >> 2
|
||||
if length < 127:
|
||||
length = struct.pack('B', length)
|
||||
else:
|
||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||
|
||||
self.write(length + message)
|
||||
await self.write(length + message)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Read implementations
|
||||
|
||||
def read(self, length):
|
||||
async def read(self, length):
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _read_plain(self, length):
|
||||
return self.conn.read(length)
|
||||
async def _read_plain(self, length):
|
||||
return await self.conn.read(length)
|
||||
|
||||
def _read_obfuscated(self, length):
|
||||
return self._aes_decrypt.encrypt(
|
||||
self.conn.read(length)
|
||||
)
|
||||
async def _read_obfuscated(self, length):
|
||||
return await self._aes_decrypt.encrypt(self.conn.read(length))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Write implementations
|
||||
|
||||
def write(self, data):
|
||||
async def write(self, data):
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _write_plain(self, data):
|
||||
self.conn.write(data)
|
||||
async def _write_plain(self, data):
|
||||
await self.conn.write(data)
|
||||
|
||||
def _write_obfuscated(self, data):
|
||||
self.conn.write(self._aes_encrypt.encrypt(data))
|
||||
async def _write_obfuscated(self, data):
|
||||
await self.conn.write(self._aes_encrypt.encrypt(data))
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -16,23 +16,23 @@ class MtProtoPlainSender:
|
|||
self._last_msg_id = 0
|
||||
self._connection = connection
|
||||
|
||||
def connect(self):
|
||||
self._connection.connect()
|
||||
async def connect(self):
|
||||
await self._connection.connect()
|
||||
|
||||
def disconnect(self):
|
||||
self._connection.close()
|
||||
|
||||
def send(self, data):
|
||||
async def send(self, data):
|
||||
"""Sends a plain packet (auth_key_id = 0) containing the
|
||||
given message body (data)
|
||||
"""
|
||||
self._connection.send(
|
||||
await self._connection.send(
|
||||
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
|
||||
)
|
||||
|
||||
def receive(self):
|
||||
async def receive(self):
|
||||
"""Receives a plain packet, returning the body of the response"""
|
||||
body = self._connection.recv()
|
||||
body = await self._connection.recv()
|
||||
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
||||
# Broken authorization, must reset the auth key
|
||||
raise BrokenAuthKeyError()
|
||||
|
|
|
@ -41,9 +41,9 @@ class MtProtoSender:
|
|||
# Requests (as msg_id: Message) sent waiting to be received
|
||||
self._pending_receive = {}
|
||||
|
||||
def connect(self):
|
||||
async def connect(self):
|
||||
"""Connects to the server"""
|
||||
self.connection.connect(self.session.server_address, self.session.port)
|
||||
await self.connection.connect(self.session.server_address, self.session.port)
|
||||
|
||||
def is_connected(self):
|
||||
return self.connection.is_connected()
|
||||
|
@ -60,7 +60,7 @@ class MtProtoSender:
|
|||
|
||||
# region Send and receive
|
||||
|
||||
def send(self, *requests):
|
||||
async def send(self, *requests):
|
||||
"""Sends the specified MTProtoRequest, previously sending any message
|
||||
which needed confirmation."""
|
||||
|
||||
|
@ -80,13 +80,13 @@ class MtProtoSender:
|
|||
else:
|
||||
message = TLMessage(self.session, MessageContainer(messages))
|
||||
|
||||
self._send_message(message)
|
||||
await self._send_message(message)
|
||||
|
||||
def _send_acknowledge(self, msg_id):
|
||||
async def _send_acknowledge(self, msg_id):
|
||||
"""Sends a message acknowledge for the given msg_id"""
|
||||
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||
|
||||
def receive(self, update_state):
|
||||
async def receive(self, update_state):
|
||||
"""Receives a single message from the connected endpoint.
|
||||
|
||||
This method returns nothing, and will only affect other parts
|
||||
|
@ -97,7 +97,7 @@ class MtProtoSender:
|
|||
update_state.process(TLObject).
|
||||
"""
|
||||
try:
|
||||
body = self.connection.recv()
|
||||
body = await self.connection.recv()
|
||||
except (BufferError, InvalidChecksumError):
|
||||
# TODO BufferError, we should spot the cause...
|
||||
# "No more bytes left"; something wrong happened, clear
|
||||
|
@ -111,13 +111,13 @@ class MtProtoSender:
|
|||
|
||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
with BinaryReader(message) as reader:
|
||||
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Low level processing
|
||||
|
||||
def _send_message(self, message):
|
||||
async def _send_message(self, message):
|
||||
"""Sends the given Message(TLObject) encrypted through the network"""
|
||||
|
||||
plain_text = \
|
||||
|
@ -130,7 +130,7 @@ class MtProtoSender:
|
|||
cipher_text = AES.encrypt_ige(plain_text, key, iv)
|
||||
|
||||
result = key_id + msg_key + cipher_text
|
||||
self.connection.send(result)
|
||||
await self.connection.send(result)
|
||||
|
||||
def _decode_msg(self, body):
|
||||
"""Decodes an received encrypted message body bytes"""
|
||||
|
@ -163,7 +163,7 @@ class MtProtoSender:
|
|||
|
||||
return message, remote_msg_id, remote_sequence
|
||||
|
||||
def _process_msg(self, msg_id, sequence, reader, state):
|
||||
async def _process_msg(self, msg_id, sequence, reader, state):
|
||||
"""Processes and handles a Telegram message.
|
||||
|
||||
Returns True if the message was handled correctly and doesn't
|
||||
|
@ -178,22 +178,22 @@ class MtProtoSender:
|
|||
|
||||
# The following codes are "parsed manually"
|
||||
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
|
||||
return self._handle_rpc_result(msg_id, sequence, reader)
|
||||
return await self._handle_rpc_result(msg_id, sequence, reader)
|
||||
|
||||
if code == 0x347773c5: # pong
|
||||
return self._handle_pong(msg_id, sequence, reader)
|
||||
return await self._handle_pong(msg_id, sequence, reader)
|
||||
|
||||
if code == 0x73f1f8dc: # msg_container
|
||||
return self._handle_container(msg_id, sequence, reader, state)
|
||||
return await self._handle_container(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0x3072cfa1: # gzip_packed
|
||||
return self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0xedab447b: # bad_server_salt
|
||||
return self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||
return await self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||
|
||||
if code == 0xa7eff811: # bad_msg_notification
|
||||
return self._handle_bad_msg_notification(msg_id, sequence, reader)
|
||||
return await self._handle_bad_msg_notification(msg_id, sequence, reader)
|
||||
|
||||
# msgs_ack, it may handle the request we wanted
|
||||
if code == 0x62d6b459:
|
||||
|
@ -247,7 +247,7 @@ class MtProtoSender:
|
|||
r.confirm_received.set()
|
||||
self._pending_receive.clear()
|
||||
|
||||
def _handle_pong(self, msg_id, sequence, reader):
|
||||
async def _handle_pong(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling pong')
|
||||
reader.read_int(signed=False) # code
|
||||
received_msg_id = reader.read_long()
|
||||
|
@ -259,7 +259,7 @@ class MtProtoSender:
|
|||
|
||||
return True
|
||||
|
||||
def _handle_container(self, msg_id, sequence, reader, state):
|
||||
async def _handle_container(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling container')
|
||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
||||
begin_position = reader.tell_position()
|
||||
|
@ -267,7 +267,7 @@ class MtProtoSender:
|
|||
# Note that this code is IMPORTANT for skipping RPC results of
|
||||
# lost requests (i.e., ones from the previous connection session)
|
||||
try:
|
||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
if not await self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
reader.set_position(begin_position + inner_len)
|
||||
except:
|
||||
# If any error is raised, something went wrong; skip the packet
|
||||
|
@ -276,7 +276,7 @@ class MtProtoSender:
|
|||
|
||||
return True
|
||||
|
||||
def _handle_bad_server_salt(self, msg_id, sequence, reader):
|
||||
async def _handle_bad_server_salt(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling bad server salt')
|
||||
reader.read_int(signed=False) # code
|
||||
bad_msg_id = reader.read_long()
|
||||
|
@ -287,11 +287,11 @@ class MtProtoSender:
|
|||
|
||||
request = self._pop_request(bad_msg_id)
|
||||
if request:
|
||||
self.send(request)
|
||||
await self.send(request)
|
||||
|
||||
return True
|
||||
|
||||
def _handle_bad_msg_notification(self, msg_id, sequence, reader):
|
||||
async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling bad message notification')
|
||||
reader.read_int(signed=False) # code
|
||||
reader.read_long() # request_id
|
||||
|
@ -318,7 +318,7 @@ class MtProtoSender:
|
|||
else:
|
||||
raise error
|
||||
|
||||
def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||
async def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling RPC result')
|
||||
reader.read_int(signed=False) # code
|
||||
request_id = reader.read_long()
|
||||
|
@ -338,7 +338,7 @@ class MtProtoSender:
|
|||
)
|
||||
|
||||
# Acknowledge that we received the error
|
||||
self._send_acknowledge(request_id)
|
||||
await self._send_acknowledge(request_id)
|
||||
|
||||
if request:
|
||||
request.rpc_error = error
|
||||
|
@ -366,9 +366,9 @@ class MtProtoSender:
|
|||
self._logger.debug('Lost request will be skipped.')
|
||||
return False
|
||||
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling gzip packed data')
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
return await self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -137,7 +137,7 @@ class TelegramBareClient:
|
|||
|
||||
# region Connecting
|
||||
|
||||
def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
|
||||
async def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
|
||||
"""Connects to the Telegram servers, executing authentication if
|
||||
required. Note that authenticating to the Telegram servers is
|
||||
not the same as authenticating the desired user itself, which
|
||||
|
@ -158,13 +158,13 @@ class TelegramBareClient:
|
|||
centers won't be invoked.
|
||||
"""
|
||||
try:
|
||||
self._sender.connect()
|
||||
await self._sender.connect()
|
||||
if not self.session.auth_key:
|
||||
# New key, we need to tell the server we're going to use
|
||||
# the latest layer
|
||||
try:
|
||||
self.session.auth_key, self.session.time_offset = \
|
||||
authenticator.do_authentication(self._sender.connection)
|
||||
await authenticator.do_authentication(self._sender.connection)
|
||||
except BrokenAuthKeyError:
|
||||
return False
|
||||
|
||||
|
@ -176,21 +176,21 @@ class TelegramBareClient:
|
|||
|
||||
if init_connection:
|
||||
if _exported_auth is not None:
|
||||
self._init_connection(ImportAuthorizationRequest(
|
||||
await self._init_connection(ImportAuthorizationRequest(
|
||||
_exported_auth.id, _exported_auth.bytes
|
||||
))
|
||||
elif not _cdn:
|
||||
TelegramBareClient._dc_options = \
|
||||
self._init_connection(GetConfigRequest()).dc_options
|
||||
(await self._init_connection(GetConfigRequest())).dc_options
|
||||
|
||||
elif _exported_auth is not None:
|
||||
self(ImportAuthorizationRequest(
|
||||
await self(ImportAuthorizationRequest(
|
||||
_exported_auth.id, _exported_auth.bytes
|
||||
))
|
||||
|
||||
if TelegramBareClient._dc_options is None and not _cdn:
|
||||
TelegramBareClient._dc_options = \
|
||||
self(GetConfigRequest()).dc_options
|
||||
(await self(GetConfigRequest())).dc_options
|
||||
|
||||
# Connection was successful! Try syncing the update state
|
||||
# UNLESS '_sync_updates' is False (we probably are in
|
||||
|
@ -199,7 +199,7 @@ class TelegramBareClient:
|
|||
self._user_connected = True
|
||||
if _sync_updates and not _cdn:
|
||||
try:
|
||||
self.sync_updates()
|
||||
await self.sync_updates()
|
||||
self._set_connected_and_authorized()
|
||||
except UnauthorizedError:
|
||||
self._authorized = False
|
||||
|
@ -227,8 +227,8 @@ class TelegramBareClient:
|
|||
def is_connected(self):
|
||||
return self._sender.is_connected()
|
||||
|
||||
def _init_connection(self, query=None):
|
||||
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
|
||||
async def _init_connection(self, query=None):
|
||||
result = await self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
|
||||
api_id=self.api_id,
|
||||
device_model=self.session.device_model,
|
||||
system_version=self.session.system_version,
|
||||
|
@ -249,7 +249,7 @@ class TelegramBareClient:
|
|||
# TODO Shall we clear the _exported_sessions, or may be reused?
|
||||
pass
|
||||
|
||||
def _reconnect(self, new_dc=None):
|
||||
async def _reconnect(self, new_dc=None):
|
||||
"""If 'new_dc' is not set, only a call to .connect() will be made
|
||||
since it's assumed that the connection has been lost and the
|
||||
library is reconnecting.
|
||||
|
@ -260,7 +260,7 @@ class TelegramBareClient:
|
|||
"""
|
||||
if new_dc is None:
|
||||
# Assume we are disconnected due to some error, so connect again
|
||||
return self.connect()
|
||||
return await self.connect()
|
||||
else:
|
||||
self.disconnect()
|
||||
self.session.auth_key = None # Force creating new auth_key
|
||||
|
@ -269,23 +269,24 @@ class TelegramBareClient:
|
|||
self.session.server_address = ip
|
||||
self.session.port = dc.port
|
||||
self.session.save()
|
||||
return self.connect()
|
||||
return await self.connect()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Working with different connections/Data Centers
|
||||
|
||||
def _get_dc(self, dc_id, ipv6=False, cdn=False):
|
||||
async def _get_dc(self, dc_id, ipv6=False, cdn=False):
|
||||
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
||||
if TelegramBareClient._dc_options is None:
|
||||
raise ConnectionError(
|
||||
'Cannot determine the required data center IP address. '
|
||||
'Stabilise a successful initial connection first.')
|
||||
'Stabilise a successful initial connection first.'
|
||||
)
|
||||
|
||||
try:
|
||||
if cdn:
|
||||
# Ensure we have the latest keys for the CDNs
|
||||
for pk in self(GetCdnConfigRequest()).public_keys:
|
||||
for pk in await (self(GetCdnConfigRequest())).public_keys:
|
||||
rsa.add_key(pk.public_key)
|
||||
|
||||
return next(
|
||||
|
@ -297,10 +298,10 @@ class TelegramBareClient:
|
|||
raise
|
||||
|
||||
# New configuration, perhaps a new CDN was added?
|
||||
TelegramBareClient._dc_options = self(GetConfigRequest()).dc_options
|
||||
TelegramBareClient._dc_options = await (self(GetConfigRequest())).dc_options
|
||||
return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn)
|
||||
|
||||
def _get_exported_client(self, dc_id):
|
||||
async def _get_exported_client(self, dc_id):
|
||||
"""Creates and connects a new TelegramBareClient for the desired DC.
|
||||
|
||||
If it's the first time calling the method with a given dc_id,
|
||||
|
@ -317,10 +318,10 @@ class TelegramBareClient:
|
|||
# TODO Add a lock, don't allow two threads to create an auth key
|
||||
# (when calling .connect() if there wasn't a previous session).
|
||||
# for the same data center.
|
||||
dc = self._get_dc(dc_id)
|
||||
dc = await self._get_dc(dc_id)
|
||||
|
||||
# Export the current authorization to the new DC.
|
||||
export_auth = self(ExportAuthorizationRequest(dc_id))
|
||||
export_auth = await self(ExportAuthorizationRequest(dc_id))
|
||||
|
||||
# Create a temporary session for this IP address, which needs
|
||||
# to be different because each auth_key is unique per DC.
|
||||
|
@ -337,15 +338,15 @@ class TelegramBareClient:
|
|||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
client.connect(_exported_auth=export_auth, _sync_updates=False)
|
||||
await client.connect(_exported_auth=export_auth, _sync_updates=False)
|
||||
client._authorized = True # We exported the auth, so we got auth
|
||||
return client
|
||||
|
||||
def _get_cdn_client(self, cdn_redirect):
|
||||
async def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._get_exported_client, but for CDNs"""
|
||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
if not session:
|
||||
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
session = Session(self.session)
|
||||
session.server_address = dc.ip_address
|
||||
session.port = dc.port
|
||||
|
@ -361,7 +362,7 @@ class TelegramBareClient:
|
|||
#
|
||||
# This relies on the fact that TelegramBareClient._dc_options is
|
||||
# static and it won't be called from this DC (it would fail).
|
||||
client.connect(_cdn=True) # Avoid invoking non-CDN specific methods
|
||||
await client.connect(_cdn=True) # Avoid invoking non-CDN methods
|
||||
client._authorized = self._authorized
|
||||
return client
|
||||
|
||||
|
@ -369,7 +370,7 @@ class TelegramBareClient:
|
|||
|
||||
# region Invoking Telegram requests
|
||||
|
||||
def __call__(self, *requests, retries=5):
|
||||
async def __call__(self, *requests, retries=5):
|
||||
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
|
||||
|
||||
The invoke will be retried up to 'retries' times before raising
|
||||
|
@ -384,7 +385,7 @@ class TelegramBareClient:
|
|||
|
||||
try:
|
||||
for _ in range(retries):
|
||||
result = self._invoke(sender, *requests)
|
||||
result = await self._invoke(sender, *requests)
|
||||
if result:
|
||||
return result
|
||||
|
||||
|
@ -396,16 +397,16 @@ class TelegramBareClient:
|
|||
# Let people use client.invoke(SomeRequest()) instead client(...)
|
||||
invoke = __call__
|
||||
|
||||
def _invoke(self, sender, *requests):
|
||||
async def _invoke(self, sender, *requests):
|
||||
try:
|
||||
# Ensure that we start with no previous errors (i.e. resending)
|
||||
for x in requests:
|
||||
x.confirm_received.clear()
|
||||
x.rpc_error = None
|
||||
|
||||
sender.send(*requests)
|
||||
await sender.send(*requests)
|
||||
while not all(x.confirm_received.is_set() for x in requests):
|
||||
sender.receive(update_state=self.updates)
|
||||
await sender.receive(update_state=self.updates)
|
||||
|
||||
except TimeoutError:
|
||||
pass # We will just retry
|
||||
|
@ -420,9 +421,9 @@ class TelegramBareClient:
|
|||
|
||||
if sender != self._sender:
|
||||
# TODO Try reconnecting forever too?
|
||||
sender.connect()
|
||||
await sender.connect()
|
||||
else:
|
||||
while self._user_connected and not self._reconnect():
|
||||
while self._user_connected and not await self._reconnect():
|
||||
sleep(0.1) # Retry forever until we can send the request
|
||||
|
||||
finally:
|
||||
|
@ -449,8 +450,8 @@ class TelegramBareClient:
|
|||
'attempting to reconnect at DC {}'.format(e.new_dc)
|
||||
)
|
||||
|
||||
self._reconnect(new_dc=e.new_dc)
|
||||
return self._invoke(sender, *requests)
|
||||
await self._reconnect(new_dc=e.new_dc)
|
||||
return await self._invoke(sender, *requests)
|
||||
|
||||
except ServerError as e:
|
||||
# Telegram is having some issues, just retry
|
||||
|
@ -474,7 +475,7 @@ class TelegramBareClient:
|
|||
|
||||
# region Uploading media
|
||||
|
||||
def upload_file(self,
|
||||
async def upload_file(self,
|
||||
file,
|
||||
part_size_kb=None,
|
||||
file_name=None,
|
||||
|
@ -537,7 +538,7 @@ class TelegramBareClient:
|
|||
else:
|
||||
request = SaveFilePartRequest(file_id, part_index, part)
|
||||
|
||||
result = self(request)
|
||||
result = await self(request)
|
||||
if result:
|
||||
if not is_large:
|
||||
# No need to update the hash if it's a large file
|
||||
|
@ -568,7 +569,7 @@ class TelegramBareClient:
|
|||
|
||||
# region Downloading media
|
||||
|
||||
def download_file(self,
|
||||
async def download_file(self,
|
||||
input_location,
|
||||
file,
|
||||
part_size_kb=None,
|
||||
|
@ -616,18 +617,20 @@ class TelegramBareClient:
|
|||
if cdn_decrypter:
|
||||
result = cdn_decrypter.get_file()
|
||||
else:
|
||||
result = client(GetFileRequest(
|
||||
result = await client(GetFileRequest(
|
||||
input_location, offset, part_size
|
||||
))
|
||||
|
||||
if isinstance(result, FileCdnRedirect):
|
||||
cdn_decrypter, result = \
|
||||
CdnDecrypter.prepare_decrypter(
|
||||
client, self._get_cdn_client(result), result
|
||||
client,
|
||||
await self._get_cdn_client(result),
|
||||
result
|
||||
)
|
||||
|
||||
except FileMigrateError as e:
|
||||
client = self._get_exported_client(e.new_dc)
|
||||
client = await self._get_exported_client(e.new_dc)
|
||||
continue
|
||||
|
||||
offset_index += 1
|
||||
|
@ -657,12 +660,12 @@ class TelegramBareClient:
|
|||
|
||||
# region Updates handling
|
||||
|
||||
def sync_updates(self):
|
||||
async def sync_updates(self):
|
||||
"""Synchronizes self.updates to their initial state. Will be
|
||||
called automatically on connection if self.updates.enabled = True,
|
||||
otherwise it should be called manually after enabling updates.
|
||||
"""
|
||||
self.updates.process(self(GetStateRequest()))
|
||||
self.updates.process(await self(GetStateRequest()))
|
||||
|
||||
def add_update_handler(self, handler):
|
||||
"""Adds an update handler (a function which takes a TLObject,
|
||||
|
|
|
@ -99,15 +99,15 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Authorization requests
|
||||
|
||||
def send_code_request(self, phone):
|
||||
async def send_code_request(self, phone):
|
||||
"""Sends a code request to the specified phone number"""
|
||||
phone = EntityDatabase.parse_phone(phone) or self._phone
|
||||
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
|
||||
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
|
||||
self._phone = phone
|
||||
self._phone_code_hash = result.phone_code_hash
|
||||
return result
|
||||
|
||||
def sign_in(self, phone=None, code=None,
|
||||
async def sign_in(self, phone=None, code=None,
|
||||
password=None, bot_token=None, phone_code_hash=None):
|
||||
"""Completes the sign in process with the phone number + code pair.
|
||||
|
||||
|
@ -132,7 +132,7 @@ class TelegramClient(TelegramBareClient):
|
|||
"""
|
||||
|
||||
if phone and not code:
|
||||
return self.send_code_request(phone)
|
||||
return await self.send_code_request(phone)
|
||||
elif code:
|
||||
phone = EntityDatabase.parse_phone(phone) or self._phone
|
||||
phone_code_hash = phone_code_hash or self._phone_code_hash
|
||||
|
@ -147,18 +147,18 @@ class TelegramClient(TelegramBareClient):
|
|||
if isinstance(code, int):
|
||||
code = str(code)
|
||||
|
||||
result = self(SignInRequest(phone, phone_code_hash, code))
|
||||
result = await self(SignInRequest(phone, phone_code_hash, code))
|
||||
|
||||
except (PhoneCodeEmptyError, PhoneCodeExpiredError,
|
||||
PhoneCodeHashEmptyError, PhoneCodeInvalidError):
|
||||
return None
|
||||
elif password:
|
||||
salt = self(GetPasswordRequest()).current_salt
|
||||
result = self(CheckPasswordRequest(
|
||||
salt = await self(GetPasswordRequest()).current_salt
|
||||
result = await self(CheckPasswordRequest(
|
||||
helpers.get_password_hash(password, salt)
|
||||
))
|
||||
elif bot_token:
|
||||
result = self(ImportBotAuthorizationRequest(
|
||||
result = await self(ImportBotAuthorizationRequest(
|
||||
flags=0, bot_auth_token=bot_token,
|
||||
api_id=self.api_id, api_hash=self.api_hash
|
||||
))
|
||||
|
@ -171,9 +171,9 @@ class TelegramClient(TelegramBareClient):
|
|||
self._set_connected_and_authorized()
|
||||
return result.user
|
||||
|
||||
def sign_up(self, code, first_name, last_name=''):
|
||||
async def sign_up(self, code, first_name, last_name=''):
|
||||
"""Signs up to Telegram. Make sure you sent a code request first!"""
|
||||
result = self(SignUpRequest(
|
||||
result = await self(SignUpRequest(
|
||||
phone_number=self._phone,
|
||||
phone_code_hash=self._phone_code_hash,
|
||||
phone_code=code,
|
||||
|
@ -184,11 +184,11 @@ class TelegramClient(TelegramBareClient):
|
|||
self._set_connected_and_authorized()
|
||||
return result.user
|
||||
|
||||
def log_out(self):
|
||||
async def log_out(self):
|
||||
"""Logs out and deletes the current session.
|
||||
Returns True if everything went okay."""
|
||||
try:
|
||||
self(LogOutRequest())
|
||||
await self(LogOutRequest())
|
||||
except RPCError:
|
||||
return False
|
||||
|
||||
|
@ -197,11 +197,11 @@ class TelegramClient(TelegramBareClient):
|
|||
self.session = None
|
||||
return True
|
||||
|
||||
def get_me(self):
|
||||
async def get_me(self):
|
||||
"""Gets "me" (the self user) which is currently authenticated,
|
||||
or None if the request fails (hence, not authenticated)."""
|
||||
try:
|
||||
return self(GetUsersRequest([InputUserSelf()]))[0]
|
||||
return await self(GetUsersRequest([InputUserSelf()]))[0]
|
||||
except UnauthorizedError:
|
||||
return None
|
||||
|
||||
|
@ -209,7 +209,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Dialogs ("chats") requests
|
||||
|
||||
def get_dialogs(self,
|
||||
async def get_dialogs(self,
|
||||
limit=10,
|
||||
offset_date=None,
|
||||
offset_id=0,
|
||||
|
@ -232,7 +232,7 @@ class TelegramClient(TelegramBareClient):
|
|||
entities = {}
|
||||
while len(dialogs) < limit:
|
||||
need = limit - len(dialogs)
|
||||
r = self(GetDialogsRequest(
|
||||
r = await self(GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
|
@ -281,7 +281,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Message requests
|
||||
|
||||
def send_message(self,
|
||||
async def send_message(self,
|
||||
entity,
|
||||
message,
|
||||
reply_to=None,
|
||||
|
@ -292,7 +292,7 @@ class TelegramClient(TelegramBareClient):
|
|||
If 'reply_to' is set to either a message or a message ID,
|
||||
the sent message will be replying to such message.
|
||||
"""
|
||||
entity = self.get_input_entity(entity)
|
||||
entity = await self.get_input_entity(entity)
|
||||
request = SendMessageRequest(
|
||||
peer=entity,
|
||||
message=message,
|
||||
|
@ -300,7 +300,7 @@ class TelegramClient(TelegramBareClient):
|
|||
no_webpage=not link_preview,
|
||||
reply_to_msg_id=self._get_reply_to(reply_to)
|
||||
)
|
||||
result = self(request)
|
||||
result = await self(request)
|
||||
if isinstance(result, UpdateShortSentMessage):
|
||||
return Message(
|
||||
id=result.id,
|
||||
|
@ -328,7 +328,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
return None # Should not happen
|
||||
|
||||
def delete_messages(self, entity, message_ids, revoke=True):
|
||||
async def delete_messages(self, entity, message_ids, revoke=True):
|
||||
"""
|
||||
Deletes a message from a chat, optionally "for everyone" with argument
|
||||
`revoke` set to `True`.
|
||||
|
@ -352,16 +352,16 @@ class TelegramClient(TelegramBareClient):
|
|||
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
|
||||
|
||||
if entity is None:
|
||||
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||
|
||||
entity = self.get_input_entity(entity)
|
||||
entity = await self.get_input_entity(entity)
|
||||
|
||||
if isinstance(entity, InputPeerChannel):
|
||||
return self(channels.DeleteMessagesRequest(entity, message_ids))
|
||||
return await self(channels.DeleteMessagesRequest(entity, message_ids))
|
||||
else:
|
||||
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||
|
||||
def get_message_history(self,
|
||||
async def get_message_history(self,
|
||||
entity,
|
||||
limit=20,
|
||||
offset_date=None,
|
||||
|
@ -386,8 +386,8 @@ class TelegramClient(TelegramBareClient):
|
|||
The entity may be a phone or an username at the expense of
|
||||
some performance loss.
|
||||
"""
|
||||
result = self(GetHistoryRequest(
|
||||
peer=self.get_input_entity(entity),
|
||||
result = await self(GetHistoryRequest(
|
||||
peer=await self.get_input_entity(entity),
|
||||
limit=limit,
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
|
@ -413,7 +413,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
return total_messages, result.messages, entities
|
||||
|
||||
def send_read_acknowledge(self, entity, messages=None, max_id=None):
|
||||
async def send_read_acknowledge(self, entity, messages=None, max_id=None):
|
||||
"""Sends a "read acknowledge" (i.e., notifying the given peer that we've
|
||||
read their messages, also known as the "double check").
|
||||
|
||||
|
@ -435,8 +435,8 @@ class TelegramClient(TelegramBareClient):
|
|||
else:
|
||||
max_id = messages.id
|
||||
|
||||
return self(ReadHistoryRequest(
|
||||
peer=self.get_input_entity(entity),
|
||||
return await self(ReadHistoryRequest(
|
||||
peer=await self.get_input_entity(entity),
|
||||
max_id=max_id
|
||||
))
|
||||
|
||||
|
@ -460,7 +460,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Uploading files
|
||||
|
||||
def send_file(self, entity, file, caption='',
|
||||
async def send_file(self, entity, file, caption='',
|
||||
force_document=False, progress_callback=None,
|
||||
reply_to=None,
|
||||
**kwargs):
|
||||
|
@ -500,7 +500,7 @@ class TelegramClient(TelegramBareClient):
|
|||
if file_hash in self._upload_cache:
|
||||
file_handle = self._upload_cache[file_hash]
|
||||
else:
|
||||
self._upload_cache[file_hash] = file_handle = self.upload_file(
|
||||
self._upload_cache[file_hash] = file_handle = await self.upload_file(
|
||||
file, progress_callback=progress_callback
|
||||
)
|
||||
|
||||
|
@ -538,16 +538,16 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# Once the media type is properly specified and the file uploaded,
|
||||
# send the media message to the desired entity.
|
||||
self(SendMediaRequest(
|
||||
peer=self.get_input_entity(entity),
|
||||
await self(SendMediaRequest(
|
||||
peer=await self.get_input_entity(entity),
|
||||
media=media,
|
||||
reply_to_msg_id=self._get_reply_to(reply_to)
|
||||
))
|
||||
|
||||
def send_voice_note(self, entity, file, caption='', upload_progress=None,
|
||||
async def send_voice_note(self, entity, file, caption='', upload_progress=None,
|
||||
reply_to=None):
|
||||
"""Wrapper method around .send_file() with is_voice_note=()"""
|
||||
return self.send_file(entity, file, caption,
|
||||
return await self.send_file(entity, file, caption,
|
||||
upload_progress=upload_progress,
|
||||
reply_to=reply_to,
|
||||
is_voice_note=()) # empty tuple is enough
|
||||
|
@ -564,7 +564,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Downloading media requests
|
||||
|
||||
def download_profile_photo(self, entity, file=None, download_big=True):
|
||||
async def download_profile_photo(self, entity, file=None, download_big=True):
|
||||
"""Downloads the profile photo for an user or a chat (channels too).
|
||||
Returns None if no photo was provided, or if it was Empty.
|
||||
|
||||
|
@ -590,12 +590,12 @@ class TelegramClient(TelegramBareClient):
|
|||
# The hexadecimal numbers above are simply:
|
||||
# hex(crc32(x.encode('ascii'))) for x in
|
||||
# ('User', 'Chat', 'UserFull', 'ChatFull')
|
||||
entity = self.get_entity(entity)
|
||||
entity = await self.get_entity(entity)
|
||||
if not hasattr(entity, 'photo'):
|
||||
# Special case: may be a ChatFull with photo:Photo
|
||||
# This is different from a normal UserProfilePhoto and Chat
|
||||
if hasattr(entity, 'chat_photo'):
|
||||
return self._download_photo(
|
||||
return await self._download_photo(
|
||||
entity.chat_photo, file,
|
||||
date=None, progress_callback=None
|
||||
)
|
||||
|
@ -623,7 +623,7 @@ class TelegramClient(TelegramBareClient):
|
|||
)
|
||||
|
||||
# Download the media with the largest size input file location
|
||||
self.download_file(
|
||||
await self.download_file(
|
||||
InputFileLocation(
|
||||
volume_id=photo_location.volume_id,
|
||||
local_id=photo_location.local_id,
|
||||
|
@ -633,7 +633,7 @@ class TelegramClient(TelegramBareClient):
|
|||
)
|
||||
return file
|
||||
|
||||
def download_media(self, message, file=None, progress_callback=None):
|
||||
async def download_media(self, message, file=None, progress_callback=None):
|
||||
"""Downloads the media from a specified Message (it can also be
|
||||
the message.media) into the desired file (a stream or str),
|
||||
optionally finding its extension automatically.
|
||||
|
@ -659,19 +659,19 @@ class TelegramClient(TelegramBareClient):
|
|||
media = message
|
||||
|
||||
if isinstance(media, MessageMediaPhoto):
|
||||
return self._download_photo(
|
||||
return await self._download_photo(
|
||||
media, file, date, progress_callback
|
||||
)
|
||||
elif isinstance(media, MessageMediaDocument):
|
||||
return self._download_document(
|
||||
return await self._download_document(
|
||||
media, file, date, progress_callback
|
||||
)
|
||||
elif isinstance(media, MessageMediaContact):
|
||||
return self._download_contact(
|
||||
return await self._download_contact(
|
||||
media, file
|
||||
)
|
||||
|
||||
def _download_photo(self, mm_photo, file, date, progress_callback):
|
||||
async def _download_photo(self, mm_photo, file, date, progress_callback):
|
||||
"""Specialized version of .download_media() for photos"""
|
||||
|
||||
# Determine the photo and its largest size
|
||||
|
@ -683,7 +683,7 @@ class TelegramClient(TelegramBareClient):
|
|||
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
|
||||
|
||||
# Download the media with the largest size input file location
|
||||
self.download_file(
|
||||
await self.download_file(
|
||||
InputFileLocation(
|
||||
volume_id=largest_size.volume_id,
|
||||
local_id=largest_size.local_id,
|
||||
|
@ -695,7 +695,7 @@ class TelegramClient(TelegramBareClient):
|
|||
)
|
||||
return file
|
||||
|
||||
def _download_document(self, mm_doc, file, date, progress_callback):
|
||||
async def _download_document(self, mm_doc, file, date, progress_callback):
|
||||
"""Specialized version of .download_media() for documents"""
|
||||
document = mm_doc.document
|
||||
file_size = document.size
|
||||
|
@ -715,7 +715,7 @@ class TelegramClient(TelegramBareClient):
|
|||
date=date, possible_names=possible_names
|
||||
)
|
||||
|
||||
self.download_file(
|
||||
await self.download_file(
|
||||
InputDocumentFileLocation(
|
||||
id=document.id,
|
||||
access_hash=document.access_hash,
|
||||
|
@ -826,7 +826,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
# region Small utilities to make users' life easier
|
||||
|
||||
def get_entity(self, entity):
|
||||
async def get_entity(self, entity):
|
||||
"""Turns an entity into a valid Telegram user or chat.
|
||||
If "entity" is a string which can be converted to an integer,
|
||||
or if it starts with '+' it will be resolved as if it
|
||||
|
@ -851,15 +851,14 @@ class TelegramClient(TelegramBareClient):
|
|||
isinstance(entity, TLObject) and
|
||||
# crc32(b'InputPeer') and crc32(b'Peer')
|
||||
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
|
||||
ie = self.get_input_entity(entity)
|
||||
ie = await self.get_input_entity(entity)
|
||||
result = None
|
||||
if isinstance(ie, InputPeerUser):
|
||||
result = self(GetUsersRequest([ie]))
|
||||
result = await self(GetUsersRequest([ie]))
|
||||
elif isinstance(ie, InputPeerChat):
|
||||
result = self(GetChatsRequest([ie.chat_id]))
|
||||
result = await self(GetChatsRequest([ie.chat_id]))
|
||||
elif isinstance(ie, InputPeerChannel):
|
||||
result = self(GetChannelsRequest([ie]))
|
||||
|
||||
result = await self(GetChannelsRequest([ie]))
|
||||
if result:
|
||||
self.session.process_entities(result)
|
||||
try:
|
||||
|
@ -868,23 +867,23 @@ class TelegramClient(TelegramBareClient):
|
|||
pass
|
||||
|
||||
if isinstance(entity, str):
|
||||
return self._get_entity_from_string(entity)
|
||||
return await self._get_entity_from_string(entity)
|
||||
|
||||
raise ValueError(
|
||||
'Cannot turn "{}" into any entity (user or chat)'.format(entity)
|
||||
)
|
||||
|
||||
def _get_entity_from_string(self, string):
|
||||
async def _get_entity_from_string(self, string):
|
||||
"""Gets an entity from the given string, which may be a phone or
|
||||
an username, and processes all the found entities on the session.
|
||||
"""
|
||||
phone = EntityDatabase.parse_phone(string)
|
||||
if phone:
|
||||
entity = phone
|
||||
self.session.process_entities(self(GetContactsRequest(0)))
|
||||
self.session.process_entities(await self(GetContactsRequest(0)))
|
||||
else:
|
||||
entity = string.strip('@').lower()
|
||||
self.session.process_entities(self(ResolveUsernameRequest(entity)))
|
||||
self.session.process_entities(await self(ResolveUsernameRequest(entity)))
|
||||
|
||||
try:
|
||||
return self.session.entities[entity]
|
||||
|
@ -893,7 +892,7 @@ class TelegramClient(TelegramBareClient):
|
|||
'Could not find user with username {}'.format(entity)
|
||||
)
|
||||
|
||||
def get_input_entity(self, peer):
|
||||
async def get_input_entity(self, peer):
|
||||
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel.
|
||||
If no Peer class is used, peer is assumed to be the integer ID
|
||||
of an User.
|
||||
|
@ -910,7 +909,7 @@ class TelegramClient(TelegramBareClient):
|
|||
pass
|
||||
|
||||
if isinstance(peer, str):
|
||||
return utils.get_input_peer(self._get_entity_from_string(peer))
|
||||
return utils.get_input_peer(await self._get_entity_from_string(peer))
|
||||
|
||||
is_peer = False
|
||||
if isinstance(peer, int):
|
||||
|
@ -932,7 +931,7 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
if self.session.save_entities:
|
||||
# Not found, look in the dialogs (this will save the users)
|
||||
self.get_dialogs(limit=None)
|
||||
await self.get_dialogs(limit=None)
|
||||
|
||||
try:
|
||||
return self.session.entities.get_input_entity(peer)
|
||||
|
|
Loading…
Reference in New Issue
Block a user