Revisit codebase to add missing async/await

This commit is contained in:
Lonami Exo 2018-06-14 17:09:20 +02:00
parent 1247d050ab
commit 908dfa148b
9 changed files with 73 additions and 111 deletions

View File

@ -518,7 +518,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
entities=msg_entities,
media=media
)
msg = self._get_response_message(request, self(request), entity)
msg = self._get_response_message(request, await self(request), entity)
self._cache_media(msg, file, file_handle)
return msg

View File

@ -164,15 +164,16 @@ class ChatAction(EventBuilder):
self.action_message = custom.Message(
client, self.action_message, self._entities, None)
def respond(self, *args, **kwargs):
async def respond(self, *args, **kwargs):
"""
Responds to the chat action message (not as a reply). Shorthand for
`telethon.telegram_client.TelegramClient.send_message` with
``entity`` already set.
"""
return self._client.send_message(self.input_chat, *args, **kwargs)
return await self._client.send_message(
await self.input_chat, *args, **kwargs)
def reply(self, *args, **kwargs):
async def reply(self, *args, **kwargs):
"""
Replies to the chat action message (as a reply). Shorthand for
`telethon.telegram_client.TelegramClient.send_message` with
@ -181,12 +182,13 @@ class ChatAction(EventBuilder):
Has the same effect as `respond` if there is no message.
"""
if not self.action_message:
return self.respond(*args, **kwargs)
return await self.respond(*args, **kwargs)
kwargs['reply_to'] = self.action_message.id
return self._client.send_message(self.input_chat, *args, **kwargs)
return await self._client.send_message(
await self.input_chat, *args, **kwargs)
def delete(self, *args, **kwargs):
async def delete(self, *args, **kwargs):
"""
Deletes the chat action message. You're responsible for checking
whether you have the permission to do so, or to except the error
@ -196,13 +198,14 @@ class ChatAction(EventBuilder):
Does nothing if no message action triggered this event.
"""
if self.action_message:
return self._client.delete_messages(self.input_chat,
[self.action_message],
*args, **kwargs)
if not self.action_message:
return
return await self._client.delete_messages(
await self.input_chat, [self.action_message], *args, **kwargs)
@property
def pinned_message(self):
async def pinned_message(self):
"""
If ``new_pin`` is ``True``, this returns the
`telethon.tl.custom.message.Message` object that was pinned.
@ -210,8 +213,8 @@ class ChatAction(EventBuilder):
if self._pinned_message == 0:
return None
if isinstance(self._pinned_message, int) and self.input_chat:
r = self._client(functions.channels.GetMessagesRequest(
if isinstance(self._pinned_message, int) and await self.input_chat:
r = await self._client(functions.channels.GetMessagesRequest(
self._input_chat, [self._pinned_message]
))
try:
@ -227,50 +230,48 @@ class ChatAction(EventBuilder):
return self._pinned_message
@property
def added_by(self):
async def added_by(self):
"""
The user who added ``users``, if applicable (``None`` otherwise).
"""
if self._added_by and not isinstance(self._added_by, types.User):
self._added_by =\
self._entities.get(utils.get_peer_id(self._added_by))
if not self._added_by:
self._added_by = self._client.get_entity(self._added_by)
aby = self._entities.get(utils.get_peer_id(self._added_by))
if not aby:
aby = await self._client.get_entity(self._added_by)
self._added_by = aby
return self._added_by
@property
def kicked_by(self):
async def kicked_by(self):
"""
The user who kicked ``users``, if applicable (``None`` otherwise).
"""
if self._kicked_by and not isinstance(self._kicked_by, types.User):
self._kicked_by =\
self._entities.get(utils.get_peer_id(self._kicked_by))
if not self._kicked_by:
self._kicked_by = self._client.get_entity(self._kicked_by)
kby = self._entities.get(utils.get_peer_id(self._kicked_by))
if kby:
kby = await self._client.get_entity(self._kicked_by)
self._kicked_by = kby
return self._kicked_by
@property
def user(self):
async def user(self):
"""
The first user that takes part in this action (e.g. joined).
Might be ``None`` if the information can't be retrieved or
there is no user taking part.
"""
if self.users:
if await self.users:
return self._users[0]
@property
def input_user(self):
async def input_user(self):
"""
Input version of the ``self.user`` property.
"""
if self.input_users:
if await self.input_users:
return self._input_users[0]
@property
@ -282,7 +283,7 @@ class ChatAction(EventBuilder):
return utils.get_peer_id(self._user_peers[0])
@property
def users(self):
async def users(self):
"""
A list of users that take part in this action (e.g. joined).
@ -302,7 +303,7 @@ class ChatAction(EventBuilder):
missing.append(peer)
try:
missing = self._client.get_entity(missing)
missing = await self._client.get_entity(missing)
except (TypeError, ValueError):
missing = []
@ -311,7 +312,7 @@ class ChatAction(EventBuilder):
return self._users
@property
def input_users(self):
async def input_users(self):
"""
Input version of the ``self.users`` property.
"""
@ -319,9 +320,9 @@ class ChatAction(EventBuilder):
self._input_users = []
for peer in self._user_peers:
try:
self._input_users.append(self._client.get_input_entity(
peer
))
self._input_users.append(
await self._client.get_input_entity(peer)
)
except (TypeError, ValueError):
pass
return self._input_users

View File

@ -109,40 +109,8 @@ class EventCommon(abc.ABC):
"""
self._client = client
def _get_entity(self, msg_id, entity_id, chat=None):
"""
Helper function to call :tl:`GetMessages` on the give msg_id and
return the input entity whose ID is the given entity ID.
If ``chat`` is present it must be an :tl:`InputPeer`.
Returns a tuple of ``(entity, input_peer)`` if it was found, or
a tuple of ``(None, None)`` if it couldn't be.
"""
try:
if isinstance(chat, types.InputPeerChannel):
result = self._client(
functions.channels.GetMessagesRequest(chat, [msg_id])
)
else:
result = self._client(
functions.messages.GetMessagesRequest([msg_id])
)
except RPCError:
return None, None
entity = {
utils.get_peer_id(x): x for x in itertools.chain(
getattr(result, 'chats', []),
getattr(result, 'users', []))
}.get(entity_id)
if entity:
return entity, utils.get_input_peer(entity)
else:
return None, None
@property
def input_chat(self):
async def input_chat(self):
"""
The (:tl:`InputPeer`) (group, megagroup or channel) on which
the event occurred. This doesn't have the title or anything,
@ -153,19 +121,12 @@ class EventCommon(abc.ABC):
"""
if self._input_chat is None and self._chat_peer is not None:
try:
self._input_chat = self._client.get_input_entity(
self._input_chat = await self._client.get_input_entity(
self._chat_peer
)
except (ValueError, TypeError):
# The library hasn't seen this chat, get the message
if not isinstance(self._chat_peer, types.PeerChannel):
# TODO For channels, getDifference? Maybe looking
# in the dialogs (which is already done) is enough.
if self._message_id is not None:
self._chat, self._input_chat = self._get_entity(
self._message_id,
utils.get_peer_id(self._chat_peer)
)
except ValueError:
pass
return self._input_chat
@property
@ -173,7 +134,7 @@ class EventCommon(abc.ABC):
return self._client
@property
def chat(self):
async def chat(self):
"""
The (:tl:`User` | :tl:`Chat` | :tl:`Channel`, optional) on which
the event occurred. This property may make an API call the first time
@ -184,10 +145,10 @@ class EventCommon(abc.ABC):
return None
if self._chat is None:
self._chat = self._entities.get(utils.get_peer_id(self._input_chat))
self._chat = self._entities.get(utils.get_peer_id(self._chat_peer))
if self._chat is None:
self._chat = self._client.get_entity(self._input_chat)
self._chat = await self._client.get_entity(self._input_chat)
return self._chat

View File

@ -89,7 +89,7 @@ class MessageRead(EventBuilder):
return self._message_ids
@property
def messages(self):
async def messages(self):
"""
The list of `telethon.tl.custom.message.Message`
**which contents'** were read.
@ -98,11 +98,11 @@ class MessageRead(EventBuilder):
was read instead checking if it's in here.
"""
if self._messages is None:
chat = self.input_chat
chat = await self.input_chat
if not chat:
self._messages = []
else:
self._messages = self._client.get_messages(
self._messages = await self._client.get_messages(
chat, ids=self._message_ids)
return self._messages

View File

@ -148,16 +148,16 @@ class UserUpdate(EventBuilder):
self.uploading = self.video = True
@property
def user(self):
async def user(self):
"""Alias around the chat (conversation)."""
return self.chat
return await self.chat
@property
def input_user(self):
async def input_user(self):
"""Alias around the input chat."""
return self.input_chat
return await self.input_chat
@property
def user_id(self):
async def user_id(self):
"""Alias around `chat_id`."""
return self.chat_id

View File

@ -9,26 +9,26 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
def connect(self, ip, port):
result = super().connect(ip, port)
self.conn.write(b'\xef')
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xef')
return result
def clone(self):
return ConnectionTcpAbridged(self._proxy, self._timeout)
def recv(self):
length = struct.unpack('<B', self.read(1))[0]
async def recv(self):
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:
length = struct.unpack('<i', self.read(3) + b'\0')[0]
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return self.read(length << 2)
return await self.read(length << 2)
def send(self, message):
async def send(self, message):
length = len(message) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
await self.write(length + message)

View File

@ -8,16 +8,16 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
def connect(self, ip, port):
result = super().connect(ip, port)
self.conn.write(b'\xee\xee\xee\xee')
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xee\xee\xee\xee')
return result
def clone(self):
return ConnectionTcpIntermediate(self._proxy, self._timeout)
def recv(self):
return self.read(struct.unpack('<i', self.read(4))[0])
async def recv(self):
return await self.read(struct.unpack('<i', await self.read(4))[0])
def send(self, message):
self.write(struct.pack('<i', len(message)) + message)
async def send(self, message):
await self.write(struct.pack('<i', len(message)) + message)

View File

@ -18,8 +18,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
def connect(self, ip, port):
result = ConnectionTcpFull.connect(self, ip, port)
async def connect(self, ip, port):
result = await ConnectionTcpFull.connect(self, ip, port)
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True:
@ -43,7 +43,7 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random))
await self.conn.write(bytes(random))
return result
def clone(self):

View File

@ -2,7 +2,7 @@ import itertools
import logging
from datetime import datetime
from queue import Queue, Empty
from threading import RLock, Thread
from threading import RLock
from . import utils
from .tl import types as tl