Revisit codebase to add missing async/await

This commit is contained in:
Lonami Exo 2018-06-14 17:09:20 +02:00
parent 1247d050ab
commit 908dfa148b
9 changed files with 73 additions and 111 deletions

View File

@ -518,7 +518,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
entities=msg_entities, entities=msg_entities,
media=media media=media
) )
msg = self._get_response_message(request, self(request), entity) msg = self._get_response_message(request, await self(request), entity)
self._cache_media(msg, file, file_handle) self._cache_media(msg, file, file_handle)
return msg return msg

View File

@ -164,15 +164,16 @@ class ChatAction(EventBuilder):
self.action_message = custom.Message( self.action_message = custom.Message(
client, self.action_message, self._entities, None) client, self.action_message, self._entities, None)
def respond(self, *args, **kwargs): async def respond(self, *args, **kwargs):
""" """
Responds to the chat action message (not as a reply). Shorthand for Responds to the chat action message (not as a reply). Shorthand for
`telethon.telegram_client.TelegramClient.send_message` with `telethon.telegram_client.TelegramClient.send_message` with
``entity`` already set. ``entity`` already set.
""" """
return self._client.send_message(self.input_chat, *args, **kwargs) return await self._client.send_message(
await self.input_chat, *args, **kwargs)
def reply(self, *args, **kwargs): async def reply(self, *args, **kwargs):
""" """
Replies to the chat action message (as a reply). Shorthand for Replies to the chat action message (as a reply). Shorthand for
`telethon.telegram_client.TelegramClient.send_message` with `telethon.telegram_client.TelegramClient.send_message` with
@ -181,12 +182,13 @@ class ChatAction(EventBuilder):
Has the same effect as `respond` if there is no message. Has the same effect as `respond` if there is no message.
""" """
if not self.action_message: if not self.action_message:
return self.respond(*args, **kwargs) return await self.respond(*args, **kwargs)
kwargs['reply_to'] = self.action_message.id kwargs['reply_to'] = self.action_message.id
return self._client.send_message(self.input_chat, *args, **kwargs) return await self._client.send_message(
await self.input_chat, *args, **kwargs)
def delete(self, *args, **kwargs): async def delete(self, *args, **kwargs):
""" """
Deletes the chat action message. You're responsible for checking Deletes the chat action message. You're responsible for checking
whether you have the permission to do so, or to except the error whether you have the permission to do so, or to except the error
@ -196,13 +198,14 @@ class ChatAction(EventBuilder):
Does nothing if no message action triggered this event. Does nothing if no message action triggered this event.
""" """
if self.action_message: if not self.action_message:
return self._client.delete_messages(self.input_chat, return
[self.action_message],
*args, **kwargs) return await self._client.delete_messages(
await self.input_chat, [self.action_message], *args, **kwargs)
@property @property
def pinned_message(self): async def pinned_message(self):
""" """
If ``new_pin`` is ``True``, this returns the If ``new_pin`` is ``True``, this returns the
`telethon.tl.custom.message.Message` object that was pinned. `telethon.tl.custom.message.Message` object that was pinned.
@ -210,8 +213,8 @@ class ChatAction(EventBuilder):
if self._pinned_message == 0: if self._pinned_message == 0:
return None return None
if isinstance(self._pinned_message, int) and self.input_chat: if isinstance(self._pinned_message, int) and await self.input_chat:
r = self._client(functions.channels.GetMessagesRequest( r = await self._client(functions.channels.GetMessagesRequest(
self._input_chat, [self._pinned_message] self._input_chat, [self._pinned_message]
)) ))
try: try:
@ -227,50 +230,48 @@ class ChatAction(EventBuilder):
return self._pinned_message return self._pinned_message
@property @property
def added_by(self): async def added_by(self):
""" """
The user who added ``users``, if applicable (``None`` otherwise). The user who added ``users``, if applicable (``None`` otherwise).
""" """
if self._added_by and not isinstance(self._added_by, types.User): if self._added_by and not isinstance(self._added_by, types.User):
self._added_by =\ aby = self._entities.get(utils.get_peer_id(self._added_by))
self._entities.get(utils.get_peer_id(self._added_by)) if not aby:
aby = await self._client.get_entity(self._added_by)
if not self._added_by: self._added_by = aby
self._added_by = self._client.get_entity(self._added_by)
return self._added_by return self._added_by
@property @property
def kicked_by(self): async def kicked_by(self):
""" """
The user who kicked ``users``, if applicable (``None`` otherwise). The user who kicked ``users``, if applicable (``None`` otherwise).
""" """
if self._kicked_by and not isinstance(self._kicked_by, types.User): if self._kicked_by and not isinstance(self._kicked_by, types.User):
self._kicked_by =\ kby = self._entities.get(utils.get_peer_id(self._kicked_by))
self._entities.get(utils.get_peer_id(self._kicked_by)) if kby:
kby = await self._client.get_entity(self._kicked_by)
if not self._kicked_by: self._kicked_by = kby
self._kicked_by = self._client.get_entity(self._kicked_by)
return self._kicked_by return self._kicked_by
@property @property
def user(self): async def user(self):
""" """
The first user that takes part in this action (e.g. joined). The first user that takes part in this action (e.g. joined).
Might be ``None`` if the information can't be retrieved or Might be ``None`` if the information can't be retrieved or
there is no user taking part. there is no user taking part.
""" """
if self.users: if await self.users:
return self._users[0] return self._users[0]
@property @property
def input_user(self): async def input_user(self):
""" """
Input version of the ``self.user`` property. Input version of the ``self.user`` property.
""" """
if self.input_users: if await self.input_users:
return self._input_users[0] return self._input_users[0]
@property @property
@ -282,7 +283,7 @@ class ChatAction(EventBuilder):
return utils.get_peer_id(self._user_peers[0]) return utils.get_peer_id(self._user_peers[0])
@property @property
def users(self): async def users(self):
""" """
A list of users that take part in this action (e.g. joined). A list of users that take part in this action (e.g. joined).
@ -302,7 +303,7 @@ class ChatAction(EventBuilder):
missing.append(peer) missing.append(peer)
try: try:
missing = self._client.get_entity(missing) missing = await self._client.get_entity(missing)
except (TypeError, ValueError): except (TypeError, ValueError):
missing = [] missing = []
@ -311,7 +312,7 @@ class ChatAction(EventBuilder):
return self._users return self._users
@property @property
def input_users(self): async def input_users(self):
""" """
Input version of the ``self.users`` property. Input version of the ``self.users`` property.
""" """
@ -319,9 +320,9 @@ class ChatAction(EventBuilder):
self._input_users = [] self._input_users = []
for peer in self._user_peers: for peer in self._user_peers:
try: try:
self._input_users.append(self._client.get_input_entity( self._input_users.append(
peer await self._client.get_input_entity(peer)
)) )
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
return self._input_users return self._input_users

View File

@ -109,40 +109,8 @@ class EventCommon(abc.ABC):
""" """
self._client = client self._client = client
def _get_entity(self, msg_id, entity_id, chat=None):
"""
Helper function to call :tl:`GetMessages` on the give msg_id and
return the input entity whose ID is the given entity ID.
If ``chat`` is present it must be an :tl:`InputPeer`.
Returns a tuple of ``(entity, input_peer)`` if it was found, or
a tuple of ``(None, None)`` if it couldn't be.
"""
try:
if isinstance(chat, types.InputPeerChannel):
result = self._client(
functions.channels.GetMessagesRequest(chat, [msg_id])
)
else:
result = self._client(
functions.messages.GetMessagesRequest([msg_id])
)
except RPCError:
return None, None
entity = {
utils.get_peer_id(x): x for x in itertools.chain(
getattr(result, 'chats', []),
getattr(result, 'users', []))
}.get(entity_id)
if entity:
return entity, utils.get_input_peer(entity)
else:
return None, None
@property @property
def input_chat(self): async def input_chat(self):
""" """
The (:tl:`InputPeer`) (group, megagroup or channel) on which The (:tl:`InputPeer`) (group, megagroup or channel) on which
the event occurred. This doesn't have the title or anything, the event occurred. This doesn't have the title or anything,
@ -153,19 +121,12 @@ class EventCommon(abc.ABC):
""" """
if self._input_chat is None and self._chat_peer is not None: if self._input_chat is None and self._chat_peer is not None:
try: try:
self._input_chat = self._client.get_input_entity( self._input_chat = await self._client.get_input_entity(
self._chat_peer self._chat_peer
) )
except (ValueError, TypeError): except ValueError:
# The library hasn't seen this chat, get the message pass
if not isinstance(self._chat_peer, types.PeerChannel):
# TODO For channels, getDifference? Maybe looking
# in the dialogs (which is already done) is enough.
if self._message_id is not None:
self._chat, self._input_chat = self._get_entity(
self._message_id,
utils.get_peer_id(self._chat_peer)
)
return self._input_chat return self._input_chat
@property @property
@ -173,7 +134,7 @@ class EventCommon(abc.ABC):
return self._client return self._client
@property @property
def chat(self): async def chat(self):
""" """
The (:tl:`User` | :tl:`Chat` | :tl:`Channel`, optional) on which The (:tl:`User` | :tl:`Chat` | :tl:`Channel`, optional) on which
the event occurred. This property may make an API call the first time the event occurred. This property may make an API call the first time
@ -184,10 +145,10 @@ class EventCommon(abc.ABC):
return None return None
if self._chat is None: if self._chat is None:
self._chat = self._entities.get(utils.get_peer_id(self._input_chat)) self._chat = self._entities.get(utils.get_peer_id(self._chat_peer))
if self._chat is None: if self._chat is None:
self._chat = self._client.get_entity(self._input_chat) self._chat = await self._client.get_entity(self._input_chat)
return self._chat return self._chat

View File

@ -89,7 +89,7 @@ class MessageRead(EventBuilder):
return self._message_ids return self._message_ids
@property @property
def messages(self): async def messages(self):
""" """
The list of `telethon.tl.custom.message.Message` The list of `telethon.tl.custom.message.Message`
**which contents'** were read. **which contents'** were read.
@ -98,11 +98,11 @@ class MessageRead(EventBuilder):
was read instead checking if it's in here. was read instead checking if it's in here.
""" """
if self._messages is None: if self._messages is None:
chat = self.input_chat chat = await self.input_chat
if not chat: if not chat:
self._messages = [] self._messages = []
else: else:
self._messages = self._client.get_messages( self._messages = await self._client.get_messages(
chat, ids=self._message_ids) chat, ids=self._message_ids)
return self._messages return self._messages

View File

@ -148,16 +148,16 @@ class UserUpdate(EventBuilder):
self.uploading = self.video = True self.uploading = self.video = True
@property @property
def user(self): async def user(self):
"""Alias around the chat (conversation).""" """Alias around the chat (conversation)."""
return self.chat return await self.chat
@property @property
def input_user(self): async def input_user(self):
"""Alias around the input chat.""" """Alias around the input chat."""
return self.input_chat return await self.input_chat
@property @property
def user_id(self): async def user_id(self):
"""Alias around `chat_id`.""" """Alias around `chat_id`."""
return self.chat_id return self.chat_id

View File

@ -9,26 +9,26 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
only require 1 byte if the packet length is less than only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common). 508 bytes (127 << 2, which is very common).
""" """
def connect(self, ip, port): async def connect(self, ip, port):
result = super().connect(ip, port) result = await super().connect(ip, port)
self.conn.write(b'\xef') await self.conn.write(b'\xef')
return result return result
def clone(self): def clone(self):
return ConnectionTcpAbridged(self._proxy, self._timeout) return ConnectionTcpAbridged(self._proxy, self._timeout)
def recv(self): async def recv(self):
length = struct.unpack('<B', self.read(1))[0] length = struct.unpack('<B', await self.read(1))[0]
if length >= 127: if length >= 127:
length = struct.unpack('<i', self.read(3) + b'\0')[0] length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return self.read(length << 2) return await self.read(length << 2)
def send(self, message): async def send(self, message):
length = len(message) >> 2 length = len(message) >> 2
if length < 127: if length < 127:
length = struct.pack('B', length) length = struct.pack('B', length)
else: else:
length = b'\x7f' + int.to_bytes(length, 3, 'little') length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message) await self.write(length + message)

View File

@ -8,16 +8,16 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`. Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length. Always sends 4 extra bytes for the packet length.
""" """
def connect(self, ip, port): async def connect(self, ip, port):
result = super().connect(ip, port) result = await super().connect(ip, port)
self.conn.write(b'\xee\xee\xee\xee') await self.conn.write(b'\xee\xee\xee\xee')
return result return result
def clone(self): def clone(self):
return ConnectionTcpIntermediate(self._proxy, self._timeout) return ConnectionTcpIntermediate(self._proxy, self._timeout)
def recv(self): async def recv(self):
return self.read(struct.unpack('<i', self.read(4))[0]) return await self.read(struct.unpack('<i', await self.read(4))[0])
def send(self, message): async def send(self, message):
self.write(struct.pack('<i', len(message)) + message) await self.write(struct.pack('<i', len(message)) + message)

View File

@ -18,8 +18,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s)) self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d)) self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
def connect(self, ip, port): async def connect(self, ip, port):
result = ConnectionTcpFull.connect(self, ip, port) result = await ConnectionTcpFull.connect(self, ip, port)
# Obfuscated messages secrets cannot start with any of these # Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True: while True:
@ -43,7 +43,7 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random)) await self.conn.write(bytes(random))
return result return result
def clone(self): def clone(self):

View File

@ -2,7 +2,7 @@ import itertools
import logging import logging
from datetime import datetime from datetime import datetime
from queue import Queue, Empty from queue import Queue, Empty
from threading import RLock, Thread from threading import RLock
from . import utils from . import utils
from .tl import types as tl from .tl import types as tl