mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-04 04:00:18 +03:00
Merge remote-tracking branch 'tulir/asyncio' into asyncio
This commit is contained in:
commit
563d731c95
25
README.rst
25
README.rst
|
@ -30,6 +30,7 @@ Creating a client
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from telethon import TelegramClient
|
from telethon import TelegramClient
|
||||||
|
|
||||||
# These example values won't work. You must get your own api_id and
|
# These example values won't work. You must get your own api_id and
|
||||||
|
@ -38,22 +39,28 @@ Creating a client
|
||||||
api_hash = '0123456789abcdef0123456789abcdef'
|
api_hash = '0123456789abcdef0123456789abcdef'
|
||||||
|
|
||||||
client = TelegramClient('session_name', api_id, api_hash)
|
client = TelegramClient('session_name', api_id, api_hash)
|
||||||
client.start()
|
async def main():
|
||||||
|
await client.start()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(main())
|
||||||
|
|
||||||
Doing stuff
|
Doing stuff
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
|
Note that this assumes you're inside an "async def" method. Check out the
|
||||||
|
`Python documentation <https://docs.python.org/3/library/asyncio-dev.html>`_
|
||||||
|
if you're new with ``asyncio``.
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
print(client.get_me().stringify())
|
print(client.get_me().stringify())
|
||||||
|
|
||||||
client.send_message('username', 'Hello! Talking to you from Telethon')
|
await client.send_message('username', 'Hello! Talking to you from Telethon')
|
||||||
client.send_file('username', '/home/myself/Pictures/holidays.jpg')
|
await client.send_file('username', '/home/myself/Pictures/holidays.jpg')
|
||||||
|
|
||||||
client.download_profile_photo('me')
|
await client.download_profile_photo('me')
|
||||||
messages = client.get_message_history('username')
|
messages = await client.get_message_history('username')
|
||||||
client.download_media(messages[0])
|
await client.download_media(messages[0])
|
||||||
|
|
||||||
|
|
||||||
Next steps
|
Next steps
|
||||||
|
@ -61,5 +68,7 @@ Next steps
|
||||||
|
|
||||||
Do you like how Telethon looks? Check out
|
Do you like how Telethon looks? Check out
|
||||||
`Read The Docs <http://telethon.rtfd.io/>`_
|
`Read The Docs <http://telethon.rtfd.io/>`_
|
||||||
for a more in-depth explanation, with examples,
|
for a more in-depth explanation, with examples, troubleshooting issues,
|
||||||
troubleshooting issues, and more useful information.
|
and more useful information. Note that the examples there are written for
|
||||||
|
the threaded version, not the one using asyncio. However, you just need to
|
||||||
|
await every remote call.
|
||||||
|
|
|
@ -30,7 +30,7 @@ class CdnDecrypter:
|
||||||
self.cdn_file_hashes = cdn_file_hashes
|
self.cdn_file_hashes = cdn_file_hashes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_decrypter(client, cdn_client, cdn_redirect):
|
async def prepare_decrypter(client, cdn_client, cdn_redirect):
|
||||||
"""
|
"""
|
||||||
Prepares a new CDN decrypter.
|
Prepares a new CDN decrypter.
|
||||||
|
|
||||||
|
@ -52,14 +52,14 @@ class CdnDecrypter:
|
||||||
cdn_aes, cdn_redirect.cdn_file_hashes
|
cdn_aes, cdn_redirect.cdn_file_hashes
|
||||||
)
|
)
|
||||||
|
|
||||||
cdn_file = cdn_client(GetCdnFileRequest(
|
cdn_file = await cdn_client(GetCdnFileRequest(
|
||||||
file_token=cdn_redirect.file_token,
|
file_token=cdn_redirect.file_token,
|
||||||
offset=cdn_redirect.cdn_file_hashes[0].offset,
|
offset=cdn_redirect.cdn_file_hashes[0].offset,
|
||||||
limit=cdn_redirect.cdn_file_hashes[0].limit
|
limit=cdn_redirect.cdn_file_hashes[0].limit
|
||||||
))
|
))
|
||||||
if isinstance(cdn_file, CdnFileReuploadNeeded):
|
if isinstance(cdn_file, CdnFileReuploadNeeded):
|
||||||
# We need to use the original client here
|
# We need to use the original client here
|
||||||
client(ReuploadCdnFileRequest(
|
await client(ReuploadCdnFileRequest(
|
||||||
file_token=cdn_redirect.file_token,
|
file_token=cdn_redirect.file_token,
|
||||||
request_token=cdn_file.request_token
|
request_token=cdn_file.request_token
|
||||||
))
|
))
|
||||||
|
@ -73,7 +73,7 @@ class CdnDecrypter:
|
||||||
|
|
||||||
return decrypter, cdn_file
|
return decrypter, cdn_file
|
||||||
|
|
||||||
def get_file(self):
|
async def get_file(self):
|
||||||
"""
|
"""
|
||||||
Calls GetCdnFileRequest and decrypts its bytes.
|
Calls GetCdnFileRequest and decrypts its bytes.
|
||||||
Also ensures that the file hasn't been tampered.
|
Also ensures that the file hasn't been tampered.
|
||||||
|
@ -82,7 +82,7 @@ class CdnDecrypter:
|
||||||
"""
|
"""
|
||||||
if self.cdn_file_hashes:
|
if self.cdn_file_hashes:
|
||||||
cdn_hash = self.cdn_file_hashes.pop(0)
|
cdn_hash = self.cdn_file_hashes.pop(0)
|
||||||
cdn_file = self.client(GetCdnFileRequest(
|
cdn_file = await self.client(GetCdnFileRequest(
|
||||||
self.file_token, cdn_hash.offset, cdn_hash.limit
|
self.file_token, cdn_hash.offset, cdn_hash.limit
|
||||||
))
|
))
|
||||||
cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes)
|
cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..extensions import markdown
|
||||||
from ..tl import types, functions
|
from ..tl import types, functions
|
||||||
|
|
||||||
|
|
||||||
def _into_id_set(client, chats):
|
async def _into_id_set(client, chats):
|
||||||
"""Helper util to turn the input chat or chats into a set of IDs."""
|
"""Helper util to turn the input chat or chats into a set of IDs."""
|
||||||
if chats is None:
|
if chats is None:
|
||||||
return None
|
return None
|
||||||
|
@ -19,9 +19,9 @@ def _into_id_set(client, chats):
|
||||||
|
|
||||||
result = set()
|
result = set()
|
||||||
for chat in chats:
|
for chat in chats:
|
||||||
chat = client.get_input_entity(chat)
|
chat = await client.get_input_entity(chat)
|
||||||
if isinstance(chat, types.InputPeerSelf):
|
if isinstance(chat, types.InputPeerSelf):
|
||||||
chat = client.get_me(input_peer=True)
|
chat = await client.get_me(input_peer=True)
|
||||||
result.add(utils.get_peer_id(chat))
|
result.add(utils.get_peer_id(chat))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -48,10 +48,10 @@ class _EventBuilder(abc.ABC):
|
||||||
def build(self, update):
|
def build(self, update):
|
||||||
"""Builds an event for the given update if possible, or returns None"""
|
"""Builds an event for the given update if possible, or returns None"""
|
||||||
|
|
||||||
def resolve(self, client):
|
async def resolve(self, client):
|
||||||
"""Helper method to allow event builders to be resolved before usage"""
|
"""Helper method to allow event builders to be resolved before usage"""
|
||||||
self.chats = _into_id_set(client, self.chats)
|
self.chats = await _into_id_set(client, self.chats)
|
||||||
self._self_id = client.get_me(input_peer=True).user_id
|
self._self_id = (await client.get_me(input_peer=True)).user_id
|
||||||
|
|
||||||
def _filter_event(self, event):
|
def _filter_event(self, event):
|
||||||
"""
|
"""
|
||||||
|
@ -86,7 +86,7 @@ class _EventCommon(abc.ABC):
|
||||||
)
|
)
|
||||||
self.is_channel = isinstance(chat_peer, types.PeerChannel)
|
self.is_channel = isinstance(chat_peer, types.PeerChannel)
|
||||||
|
|
||||||
def _get_input_entity(self, msg_id, entity_id, chat=None):
|
async def _get_input_entity(self, msg_id, entity_id, chat=None):
|
||||||
"""
|
"""
|
||||||
Helper function to call GetMessages on the give msg_id and
|
Helper function to call GetMessages on the give msg_id and
|
||||||
return the input entity whose ID is the given entity ID.
|
return the input entity whose ID is the given entity ID.
|
||||||
|
@ -95,11 +95,11 @@ class _EventCommon(abc.ABC):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(chat, types.InputPeerChannel):
|
if isinstance(chat, types.InputPeerChannel):
|
||||||
result = self._client(
|
result = await self._client(
|
||||||
functions.channels.GetMessagesRequest(chat, [msg_id])
|
functions.channels.GetMessagesRequest(chat, [msg_id])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = self._client(
|
result = await self._client(
|
||||||
functions.messages.GetMessagesRequest([msg_id])
|
functions.messages.GetMessagesRequest([msg_id])
|
||||||
)
|
)
|
||||||
except RPCError:
|
except RPCError:
|
||||||
|
@ -113,7 +113,7 @@ class _EventCommon(abc.ABC):
|
||||||
return utils.get_input_peer(entity)
|
return utils.get_input_peer(entity)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_chat(self):
|
async def input_chat(self):
|
||||||
"""
|
"""
|
||||||
The (:obj:`InputPeer`) (group, megagroup or channel) on which
|
The (:obj:`InputPeer`) (group, megagroup or channel) on which
|
||||||
the event occurred. This doesn't have the title or anything,
|
the event occurred. This doesn't have the title or anything,
|
||||||
|
@ -125,7 +125,7 @@ class _EventCommon(abc.ABC):
|
||||||
|
|
||||||
if self._input_chat is None and self._chat_peer is not None:
|
if self._input_chat is None and self._chat_peer is not None:
|
||||||
try:
|
try:
|
||||||
self._input_chat = self._client.get_input_entity(
|
self._input_chat = await self._client.get_input_entity(
|
||||||
self._chat_peer
|
self._chat_peer
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
|
@ -134,22 +134,22 @@ class _EventCommon(abc.ABC):
|
||||||
# TODO For channels, getDifference? Maybe looking
|
# TODO For channels, getDifference? Maybe looking
|
||||||
# in the dialogs (which is already done) is enough.
|
# in the dialogs (which is already done) is enough.
|
||||||
if self._message_id is not None:
|
if self._message_id is not None:
|
||||||
self._input_chat = self._get_input_entity(
|
self._input_chat = await self._get_input_entity(
|
||||||
self._message_id,
|
self._message_id,
|
||||||
utils.get_peer_id(self._chat_peer)
|
utils.get_peer_id(self._chat_peer)
|
||||||
)
|
)
|
||||||
return self._input_chat
|
return self._input_chat
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat(self):
|
async def chat(self):
|
||||||
"""
|
"""
|
||||||
The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which
|
The (:obj:`User` | :obj:`Chat` | :obj:`Channel`, optional) on which
|
||||||
the event occurred. This property will make an API call the first time
|
the event occurred. This property will make an API call the first time
|
||||||
to get the most up to date version of the chat, so use with care as
|
to get the most up to date version of the chat, so use with care as
|
||||||
there is no caching besides local caching yet.
|
there is no caching besides local caching yet.
|
||||||
"""
|
"""
|
||||||
if self._chat is None and self.input_chat:
|
if self._chat is None and await self.input_chat:
|
||||||
self._chat = self._client.get_entity(self._input_chat)
|
self._chat = await self._client.get_entity(await self._input_chat)
|
||||||
return self._chat
|
return self._chat
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class Raw(_EventBuilder):
|
||||||
"""
|
"""
|
||||||
Represents a raw event. The event is the update itself.
|
Represents a raw event. The event is the update itself.
|
||||||
"""
|
"""
|
||||||
def resolve(self, client):
|
async def resolve(self, client):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def build(self, update):
|
def build(self, update):
|
||||||
|
@ -304,23 +304,23 @@ class NewMessage(_EventBuilder):
|
||||||
self.is_reply = bool(message.reply_to_msg_id)
|
self.is_reply = bool(message.reply_to_msg_id)
|
||||||
self._reply_message = None
|
self._reply_message = None
|
||||||
|
|
||||||
def respond(self, *args, **kwargs):
|
async def respond(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Responds to the message (not as a reply). This is a shorthand for
|
Responds to the message (not as a reply). This is a shorthand for
|
||||||
``client.send_message(event.chat, ...)``.
|
``client.send_message(event.chat, ...)``.
|
||||||
"""
|
"""
|
||||||
return self._client.send_message(self.input_chat, *args, **kwargs)
|
return await self._client.send_message(await self.input_chat, *args, **kwargs)
|
||||||
|
|
||||||
def reply(self, *args, **kwargs):
|
async def reply(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Replies to the message (as a reply). This is a shorthand for
|
Replies to the message (as a reply). This is a shorthand for
|
||||||
``client.send_message(event.chat, ..., reply_to=event.message.id)``.
|
``client.send_message(event.chat, ..., reply_to=event.message.id)``.
|
||||||
"""
|
"""
|
||||||
return self._client.send_message(self.input_chat,
|
return await self._client.send_message(await self.input_chat,
|
||||||
reply_to=self.message.id,
|
reply_to=self.message.id,
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
def edit(self, *args, **kwargs):
|
async def edit(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Edits the message iff it's outgoing. This is a shorthand for
|
Edits the message iff it's outgoing. This is a shorthand for
|
||||||
``client.edit_message(event.chat, event.message, ...)``.
|
``client.edit_message(event.chat, event.message, ...)``.
|
||||||
|
@ -331,27 +331,27 @@ class NewMessage(_EventBuilder):
|
||||||
if not self.message.out:
|
if not self.message.out:
|
||||||
if not isinstance(self.message.to_id, types.PeerUser):
|
if not isinstance(self.message.to_id, types.PeerUser):
|
||||||
return None
|
return None
|
||||||
me = self._client.get_me(input_peer=True)
|
me = await self._client.get_me(input_peer=True)
|
||||||
if self.message.to_id.user_id != me.user_id:
|
if self.message.to_id.user_id != me.user_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._client.edit_message(self.input_chat,
|
return await self._client.edit_message(await self.input_chat,
|
||||||
self.message,
|
self.message,
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
def delete(self, *args, **kwargs):
|
async def delete(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Deletes the message. You're responsible for checking whether you
|
Deletes the message. You're responsible for checking whether you
|
||||||
have the permission to do so, or to except the error otherwise.
|
have the permission to do so, or to except the error otherwise.
|
||||||
This is a shorthand for
|
This is a shorthand for
|
||||||
``client.delete_messages(event.chat, event.message, ...)``.
|
``client.delete_messages(event.chat, event.message, ...)``.
|
||||||
"""
|
"""
|
||||||
return self._client.delete_messages(self.input_chat,
|
return await self._client.delete_messages(await self.input_chat,
|
||||||
[self.message],
|
[self.message],
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_sender(self):
|
async def input_sender(self):
|
||||||
"""
|
"""
|
||||||
This (:obj:`InputPeer`) is the input version of the user who
|
This (:obj:`InputPeer`) is the input version of the user who
|
||||||
sent the message. Similarly to ``input_chat``, this doesn't have
|
sent the message. Similarly to ``input_chat``, this doesn't have
|
||||||
|
@ -365,21 +365,21 @@ class NewMessage(_EventBuilder):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._input_sender = self._client.get_input_entity(
|
self._input_sender = await self._client.get_input_entity(
|
||||||
self.message.from_id
|
self.message.from_id
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# We can rely on self.input_chat for this
|
# We can rely on self.input_chat for this
|
||||||
self._input_sender = self._get_input_entity(
|
self._input_sender = await self._get_input_entity(
|
||||||
self.message.id,
|
self.message.id,
|
||||||
self.message.from_id,
|
self.message.from_id,
|
||||||
chat=self.input_chat
|
chat=await self.input_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._input_sender
|
return self._input_sender
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sender(self):
|
async def sender(self):
|
||||||
"""
|
"""
|
||||||
This (:obj:`User`) will make an API call the first time to get
|
This (:obj:`User`) will make an API call the first time to get
|
||||||
the most up to date version of the sender, so use with care as
|
the most up to date version of the sender, so use with care as
|
||||||
|
@ -387,8 +387,8 @@ class NewMessage(_EventBuilder):
|
||||||
|
|
||||||
``input_sender`` needs to be available (often the case).
|
``input_sender`` needs to be available (often the case).
|
||||||
"""
|
"""
|
||||||
if self._sender is None and self.input_sender:
|
if self._sender is None and await self.input_sender:
|
||||||
self._sender = self._client.get_entity(self._input_sender)
|
self._sender = await self._client.get_entity(self._input_sender)
|
||||||
return self._sender
|
return self._sender
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -411,7 +411,7 @@ class NewMessage(_EventBuilder):
|
||||||
return self.message.message
|
return self.message.message
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reply_message(self):
|
async def reply_message(self):
|
||||||
"""
|
"""
|
||||||
This (:obj:`Message`, optional) will make an API call the first
|
This (:obj:`Message`, optional) will make an API call the first
|
||||||
time to get the full ``Message`` object that one was replying to,
|
time to get the full ``Message`` object that one was replying to,
|
||||||
|
@ -421,12 +421,12 @@ class NewMessage(_EventBuilder):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self._reply_message is None:
|
if self._reply_message is None:
|
||||||
if isinstance(self.input_chat, types.InputPeerChannel):
|
if isinstance(await self.input_chat, types.InputPeerChannel):
|
||||||
r = self._client(functions.channels.GetMessagesRequest(
|
r = await self._client(functions.channels.GetMessagesRequest(
|
||||||
self.input_chat, [self.message.reply_to_msg_id]
|
await self.input_chat, [self.message.reply_to_msg_id]
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
r = self._client(functions.messages.GetMessagesRequest(
|
r = await self._client(functions.messages.GetMessagesRequest(
|
||||||
[self.message.reply_to_msg_id]
|
[self.message.reply_to_msg_id]
|
||||||
))
|
))
|
||||||
if not isinstance(r, types.messages.MessagesNotModified):
|
if not isinstance(r, types.messages.MessagesNotModified):
|
||||||
|
@ -610,7 +610,7 @@ class ChatAction(_EventBuilder):
|
||||||
self.new_title = new_title
|
self.new_title = new_title
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pinned_message(self):
|
async def pinned_message(self):
|
||||||
"""
|
"""
|
||||||
If ``new_pin`` is ``True``, this returns the (:obj:`Message`)
|
If ``new_pin`` is ``True``, this returns the (:obj:`Message`)
|
||||||
object that was pinned.
|
object that was pinned.
|
||||||
|
@ -618,8 +618,8 @@ class ChatAction(_EventBuilder):
|
||||||
if self._pinned_message == 0:
|
if self._pinned_message == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(self._pinned_message, int) and self.input_chat:
|
if isinstance(self._pinned_message, int) and await self.input_chat:
|
||||||
r = self._client(functions.channels.GetMessagesRequest(
|
r = await self._client(functions.channels.GetMessagesRequest(
|
||||||
self._input_chat, [self._pinned_message]
|
self._input_chat, [self._pinned_message]
|
||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
|
@ -635,25 +635,25 @@ class ChatAction(_EventBuilder):
|
||||||
return self._pinned_message
|
return self._pinned_message
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def added_by(self):
|
async def added_by(self):
|
||||||
"""
|
"""
|
||||||
The user who added ``users``, if applicable (``None`` otherwise).
|
The user who added ``users``, if applicable (``None`` otherwise).
|
||||||
"""
|
"""
|
||||||
if self._added_by and not isinstance(self._added_by, types.User):
|
if self._added_by and not isinstance(self._added_by, types.User):
|
||||||
self._added_by = self._client.get_entity(self._added_by)
|
self._added_by = await self._client.get_entity(self._added_by)
|
||||||
return self._added_by
|
return self._added_by
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kicked_by(self):
|
async def kicked_by(self):
|
||||||
"""
|
"""
|
||||||
The user who kicked ``users``, if applicable (``None`` otherwise).
|
The user who kicked ``users``, if applicable (``None`` otherwise).
|
||||||
"""
|
"""
|
||||||
if self._kicked_by and not isinstance(self._kicked_by, types.User):
|
if self._kicked_by and not isinstance(self._kicked_by, types.User):
|
||||||
self._kicked_by = self._client.get_entity(self._kicked_by)
|
self._kicked_by = await self._client.get_entity(self._kicked_by)
|
||||||
return self._kicked_by
|
return self._kicked_by
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user(self):
|
async def user(self):
|
||||||
"""
|
"""
|
||||||
The single user that takes part in this action (e.g. joined).
|
The single user that takes part in this action (e.g. joined).
|
||||||
|
|
||||||
|
@ -661,12 +661,12 @@ class ChatAction(_EventBuilder):
|
||||||
there is no user taking part.
|
there is no user taking part.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return next(self.users)
|
return next(await self.users)
|
||||||
except (StopIteration, TypeError):
|
except (StopIteration, TypeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def users(self):
|
async def users(self):
|
||||||
"""
|
"""
|
||||||
A list of users that take part in this action (e.g. joined).
|
A list of users that take part in this action (e.g. joined).
|
||||||
|
|
||||||
|
@ -675,7 +675,7 @@ class ChatAction(_EventBuilder):
|
||||||
"""
|
"""
|
||||||
if self._users is None and self._user_peers:
|
if self._users is None and self._user_peers:
|
||||||
try:
|
try:
|
||||||
self._users = self._client.get_entity(self._user_peers)
|
self._users = await self._client.get_entity(self._user_peers)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
self._users = []
|
self._users = []
|
||||||
|
|
||||||
|
|
|
@ -194,9 +194,9 @@ def get_inner_text(text, entity):
|
||||||
"""
|
"""
|
||||||
if isinstance(entity, TLObject):
|
if isinstance(entity, TLObject):
|
||||||
entity = (entity,)
|
entity = (entity,)
|
||||||
multiple = True
|
|
||||||
else:
|
|
||||||
multiple = False
|
multiple = False
|
||||||
|
else:
|
||||||
|
multiple = True
|
||||||
|
|
||||||
text = _add_surrogate(text)
|
text = _add_surrogate(text)
|
||||||
result = []
|
result = []
|
||||||
|
|
|
@ -1,13 +1,20 @@
|
||||||
"""
|
"""
|
||||||
This module holds a rough implementation of the C# TCP client.
|
This module holds a rough implementation of the C# TCP client.
|
||||||
"""
|
"""
|
||||||
|
# Python rough implementation of a C# TCP client
|
||||||
|
import asyncio
|
||||||
import errno
|
import errno
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from io import BytesIO, BufferedWriter
|
from io import BytesIO, BufferedWriter
|
||||||
from threading import Lock
|
|
||||||
|
MAX_TIMEOUT = 15 # in seconds
|
||||||
|
CONN_RESET_ERRNOS = {
|
||||||
|
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||||
|
errno.EINVAL, errno.ENOTCONN
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import socks
|
import socks
|
||||||
|
@ -25,7 +32,7 @@ __log__ = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TcpClient:
|
class TcpClient:
|
||||||
"""A simple TCP client to ease the work with sockets and proxies."""
|
"""A simple TCP client to ease the work with sockets and proxies."""
|
||||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||||
"""
|
"""
|
||||||
Initializes the TCP client.
|
Initializes the TCP client.
|
||||||
|
|
||||||
|
@ -34,7 +41,7 @@ class TcpClient:
|
||||||
"""
|
"""
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self._socket = None
|
self._socket = None
|
||||||
self._closing_lock = Lock()
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
|
|
||||||
if isinstance(timeout, timedelta):
|
if isinstance(timeout, timedelta):
|
||||||
self.timeout = timeout.seconds
|
self.timeout = timeout.seconds
|
||||||
|
@ -54,9 +61,9 @@ class TcpClient:
|
||||||
else: # tuple, list, etc.
|
else: # tuple, list, etc.
|
||||||
self._socket.set_proxy(*self.proxy)
|
self._socket.set_proxy(*self.proxy)
|
||||||
|
|
||||||
self._socket.settimeout(self.timeout)
|
self._socket.setblocking(False)
|
||||||
|
|
||||||
def connect(self, ip, port):
|
async def connect(self, ip, port):
|
||||||
"""
|
"""
|
||||||
Tries connecting forever to IP:port unless an OSError is raised.
|
Tries connecting forever to IP:port unless an OSError is raised.
|
||||||
|
|
||||||
|
@ -72,11 +79,15 @@ class TcpClient:
|
||||||
timeout = 1
|
timeout = 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
while not self._socket:
|
if not self._socket:
|
||||||
self._recreate_socket(mode)
|
self._recreate_socket(mode)
|
||||||
|
|
||||||
self._socket.connect(address)
|
await self._loop.sock_connect(self._socket, address)
|
||||||
break # Successful connection, stop retrying to connect
|
break # Successful connection, stop retrying to connect
|
||||||
|
except ConnectionError:
|
||||||
|
self._socket = None
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
timeout = min(timeout * 2, MAX_TIMEOUT)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
__log__.info('OSError "%s" raised while connecting', e)
|
__log__.info('OSError "%s" raised while connecting', e)
|
||||||
# Stop retrying to connect if proxy connection error occurred
|
# Stop retrying to connect if proxy connection error occurred
|
||||||
|
@ -90,7 +101,7 @@ class TcpClient:
|
||||||
# Bad file descriptor, i.e. socket was closed, set it
|
# Bad file descriptor, i.e. socket was closed, set it
|
||||||
# to none to recreate it on the next iteration
|
# to none to recreate it on the next iteration
|
||||||
self._socket = None
|
self._socket = None
|
||||||
time.sleep(timeout)
|
await asyncio.sleep(timeout)
|
||||||
timeout = min(timeout * 2, MAX_TIMEOUT)
|
timeout = min(timeout * 2, MAX_TIMEOUT)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
@ -103,21 +114,16 @@ class TcpClient:
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Closes the connection."""
|
"""Closes the connection."""
|
||||||
if self._closing_lock.locked():
|
try:
|
||||||
# Already closing, no need to close again (avoid None.close())
|
if self._socket is not None:
|
||||||
return
|
self._socket.shutdown(socket.SHUT_RDWR)
|
||||||
|
self._socket.close()
|
||||||
|
except OSError:
|
||||||
|
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||||
|
finally:
|
||||||
|
self._socket = None
|
||||||
|
|
||||||
with self._closing_lock:
|
async def write(self, data):
|
||||||
try:
|
|
||||||
if self._socket is not None:
|
|
||||||
self._socket.shutdown(socket.SHUT_RDWR)
|
|
||||||
self._socket.close()
|
|
||||||
except OSError:
|
|
||||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
|
||||||
finally:
|
|
||||||
self._socket = None
|
|
||||||
|
|
||||||
def write(self, data):
|
|
||||||
"""
|
"""
|
||||||
Writes (sends) the specified bytes to the connected peer.
|
Writes (sends) the specified bytes to the connected peer.
|
||||||
|
|
||||||
|
@ -126,11 +132,13 @@ class TcpClient:
|
||||||
if self._socket is None:
|
if self._socket is None:
|
||||||
self._raise_connection_reset(None)
|
self._raise_connection_reset(None)
|
||||||
|
|
||||||
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
|
|
||||||
# The socket timeout is now the maximum total duration to send all data.
|
|
||||||
try:
|
try:
|
||||||
self._socket.sendall(data)
|
await asyncio.wait_for(
|
||||||
except socket.timeout as e:
|
self.sock_sendall(data),
|
||||||
|
timeout=self.timeout,
|
||||||
|
loop=self._loop
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
__log__.debug('socket.timeout "%s" while writing data', e)
|
__log__.debug('socket.timeout "%s" while writing data', e)
|
||||||
raise TimeoutError() from e
|
raise TimeoutError() from e
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
|
@ -143,7 +151,7 @@ class TcpClient:
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def read(self, size):
|
async def read(self, size):
|
||||||
"""
|
"""
|
||||||
Reads (receives) a whole block of size bytes from the connected peer.
|
Reads (receives) a whole block of size bytes from the connected peer.
|
||||||
|
|
||||||
|
@ -153,13 +161,18 @@ class TcpClient:
|
||||||
if self._socket is None:
|
if self._socket is None:
|
||||||
self._raise_connection_reset(None)
|
self._raise_connection_reset(None)
|
||||||
|
|
||||||
# TODO Remove the timeout from this method, always use previous one
|
|
||||||
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
||||||
bytes_left = size
|
bytes_left = size
|
||||||
while bytes_left != 0:
|
while bytes_left != 0:
|
||||||
try:
|
try:
|
||||||
partial = self._socket.recv(bytes_left)
|
if self._socket is None:
|
||||||
except socket.timeout as e:
|
self._raise_connection_reset()
|
||||||
|
partial = await asyncio.wait_for(
|
||||||
|
self.sock_recv(bytes_left),
|
||||||
|
timeout=self.timeout,
|
||||||
|
loop=self._loop
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
# These are somewhat common if the server has nothing
|
# These are somewhat common if the server has nothing
|
||||||
# to send to us, so use a lower logging priority.
|
# to send to us, so use a lower logging priority.
|
||||||
__log__.debug('socket.timeout "%s" while reading data', e)
|
__log__.debug('socket.timeout "%s" while reading data', e)
|
||||||
|
@ -168,7 +181,7 @@ class TcpClient:
|
||||||
__log__.info('ConnectionError "%s" while reading data', e)
|
__log__.info('ConnectionError "%s" while reading data', e)
|
||||||
self._raise_connection_reset(e)
|
self._raise_connection_reset(e)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
if e.errno != errno.EBADF and self._closing_lock.locked():
|
if e.errno != errno.EBADF:
|
||||||
# Ignore bad file descriptor while closing
|
# Ignore bad file descriptor while closing
|
||||||
__log__.info('OSError "%s" while reading data', e)
|
__log__.info('OSError "%s" while reading data', e)
|
||||||
|
|
||||||
|
@ -190,5 +203,56 @@ class TcpClient:
|
||||||
def _raise_connection_reset(self, original):
|
def _raise_connection_reset(self, original):
|
||||||
"""Disconnects the client and raises ConnectionResetError."""
|
"""Disconnects the client and raises ConnectionResetError."""
|
||||||
self.close() # Connection reset -> flag as socket closed
|
self.close() # Connection reset -> flag as socket closed
|
||||||
raise ConnectionResetError('The server has closed the connection.')\
|
raise ConnectionResetError('The server has closed the connection.') from original
|
||||||
from original
|
|
||||||
|
# due to new https://github.com/python/cpython/pull/4386
|
||||||
|
def sock_recv(self, n):
|
||||||
|
fut = self._loop.create_future()
|
||||||
|
self._sock_recv(fut, None, n)
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def _sock_recv(self, fut, registered_fd, n):
|
||||||
|
if registered_fd is not None:
|
||||||
|
self._loop.remove_reader(registered_fd)
|
||||||
|
if fut.cancelled():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = self._socket.recv(n)
|
||||||
|
except (BlockingIOError, InterruptedError):
|
||||||
|
fd = self._socket.fileno()
|
||||||
|
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
|
||||||
|
except Exception as exc:
|
||||||
|
fut.set_exception(exc)
|
||||||
|
else:
|
||||||
|
fut.set_result(data)
|
||||||
|
|
||||||
|
def sock_sendall(self, data):
|
||||||
|
fut = self._loop.create_future()
|
||||||
|
if data:
|
||||||
|
self._sock_sendall(fut, None, data)
|
||||||
|
else:
|
||||||
|
fut.set_result(None)
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def _sock_sendall(self, fut, registered_fd, data):
|
||||||
|
if registered_fd:
|
||||||
|
self._loop.remove_writer(registered_fd)
|
||||||
|
if fut.cancelled():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
n = self._socket.send(data)
|
||||||
|
except (BlockingIOError, InterruptedError):
|
||||||
|
n = 0
|
||||||
|
except Exception as exc:
|
||||||
|
fut.set_exception(exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
if n == len(data):
|
||||||
|
fut.set_result(None)
|
||||||
|
else:
|
||||||
|
if n:
|
||||||
|
data = data[n:]
|
||||||
|
fd = self._socket.fileno()
|
||||||
|
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)
|
||||||
|
|
|
@ -21,7 +21,7 @@ from ..tl.functions import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def do_authentication(connection, retries=5):
|
async def do_authentication(connection, retries=5):
|
||||||
"""
|
"""
|
||||||
Performs the authentication steps on the given connection.
|
Performs the authentication steps on the given connection.
|
||||||
Raises an error if all attempts fail.
|
Raises an error if all attempts fail.
|
||||||
|
@ -36,14 +36,14 @@ def do_authentication(connection, retries=5):
|
||||||
last_error = None
|
last_error = None
|
||||||
while retries:
|
while retries:
|
||||||
try:
|
try:
|
||||||
return _do_authentication(connection)
|
return await _do_authentication(connection)
|
||||||
except (SecurityError, AssertionError, NotImplementedError) as e:
|
except (SecurityError, AssertionError, NotImplementedError) as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
retries -= 1
|
retries -= 1
|
||||||
raise last_error
|
raise last_error
|
||||||
|
|
||||||
|
|
||||||
def _do_authentication(connection):
|
async def _do_authentication(connection):
|
||||||
"""
|
"""
|
||||||
Executes the authentication process with the Telegram servers.
|
Executes the authentication process with the Telegram servers.
|
||||||
|
|
||||||
|
@ -56,8 +56,8 @@ def _do_authentication(connection):
|
||||||
req_pq_request = ReqPqMultiRequest(
|
req_pq_request = ReqPqMultiRequest(
|
||||||
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||||
)
|
)
|
||||||
sender.send(bytes(req_pq_request))
|
await sender.send(bytes(req_pq_request))
|
||||||
with BinaryReader(sender.receive()) as reader:
|
with BinaryReader(await sender.receive()) as reader:
|
||||||
req_pq_request.on_response(reader)
|
req_pq_request.on_response(reader)
|
||||||
|
|
||||||
res_pq = req_pq_request.result
|
res_pq = req_pq_request.result
|
||||||
|
@ -104,10 +104,10 @@ def _do_authentication(connection):
|
||||||
public_key_fingerprint=target_fingerprint,
|
public_key_fingerprint=target_fingerprint,
|
||||||
encrypted_data=cipher_text
|
encrypted_data=cipher_text
|
||||||
)
|
)
|
||||||
sender.send(bytes(req_dh_params))
|
await sender.send(bytes(req_dh_params))
|
||||||
|
|
||||||
# Step 2 response: DH Exchange
|
# Step 2 response: DH Exchange
|
||||||
with BinaryReader(sender.receive()) as reader:
|
with BinaryReader(await sender.receive()) as reader:
|
||||||
req_dh_params.on_response(reader)
|
req_dh_params.on_response(reader)
|
||||||
|
|
||||||
server_dh_params = req_dh_params.result
|
server_dh_params = req_dh_params.result
|
||||||
|
@ -174,10 +174,10 @@ def _do_authentication(connection):
|
||||||
server_nonce=res_pq.server_nonce,
|
server_nonce=res_pq.server_nonce,
|
||||||
encrypted_data=client_dh_encrypted,
|
encrypted_data=client_dh_encrypted,
|
||||||
)
|
)
|
||||||
sender.send(bytes(set_client_dh))
|
await sender.send(bytes(set_client_dh))
|
||||||
|
|
||||||
# Step 3 response: Complete DH Exchange
|
# Step 3 response: Complete DH Exchange
|
||||||
with BinaryReader(sender.receive()) as reader:
|
with BinaryReader(await sender.receive()) as reader:
|
||||||
set_client_dh.on_response(reader)
|
set_client_dh.on_response(reader)
|
||||||
|
|
||||||
dh_gen = set_client_dh.result
|
dh_gen = set_client_dh.result
|
||||||
|
|
|
@ -2,18 +2,17 @@
|
||||||
This module holds both the Connection class and the ConnectionMode enum,
|
This module holds both the Connection class and the ConnectionMode enum,
|
||||||
which specifies the protocol to be used by the Connection.
|
which specifies the protocol to be used by the Connection.
|
||||||
"""
|
"""
|
||||||
|
import errno
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from zlib import crc32
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from zlib import crc32
|
||||||
import errno
|
|
||||||
|
|
||||||
from ..crypto import AESModeCTR
|
from ..crypto import AESModeCTR
|
||||||
from ..extensions import TcpClient
|
|
||||||
from ..errors import InvalidChecksumError
|
from ..errors import InvalidChecksumError
|
||||||
|
from ..extensions import TcpClient
|
||||||
|
|
||||||
__log__ = logging.getLogger(__name__)
|
__log__ = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -52,7 +51,7 @@ class Connection:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mode=ConnectionMode.TCP_FULL,
|
def __init__(self, mode=ConnectionMode.TCP_FULL,
|
||||||
proxy=None, timeout=timedelta(seconds=5)):
|
proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||||
"""
|
"""
|
||||||
Initializes a new connection.
|
Initializes a new connection.
|
||||||
|
|
||||||
|
@ -65,7 +64,7 @@ class Connection:
|
||||||
self._aes_encrypt, self._aes_decrypt = None, None
|
self._aes_encrypt, self._aes_decrypt = None, None
|
||||||
|
|
||||||
# TODO Rename "TcpClient" as some sort of generic socket?
|
# TODO Rename "TcpClient" as some sort of generic socket?
|
||||||
self.conn = TcpClient(proxy=proxy, timeout=timeout)
|
self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
|
||||||
|
|
||||||
# Sending messages
|
# Sending messages
|
||||||
if mode == ConnectionMode.TCP_FULL:
|
if mode == ConnectionMode.TCP_FULL:
|
||||||
|
@ -89,7 +88,7 @@ class Connection:
|
||||||
setattr(self, 'write', self._write_plain)
|
setattr(self, 'write', self._write_plain)
|
||||||
setattr(self, 'read', self._read_plain)
|
setattr(self, 'read', self._read_plain)
|
||||||
|
|
||||||
def connect(self, ip, port):
|
async def connect(self, ip, port):
|
||||||
"""
|
"""
|
||||||
Estabilishes a connection to IP:port.
|
Estabilishes a connection to IP:port.
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ class Connection:
|
||||||
:param port: the port to connect to.
|
:param port: the port to connect to.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.conn.connect(ip, port)
|
await self.conn.connect(ip, port)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
if e.errno == errno.EISCONN:
|
if e.errno == errno.EISCONN:
|
||||||
return # Already connected, no need to re-set everything up
|
return # Already connected, no need to re-set everything up
|
||||||
|
@ -106,17 +105,17 @@ class Connection:
|
||||||
|
|
||||||
self._send_counter = 0
|
self._send_counter = 0
|
||||||
if self._mode == ConnectionMode.TCP_ABRIDGED:
|
if self._mode == ConnectionMode.TCP_ABRIDGED:
|
||||||
self.conn.write(b'\xef')
|
await self.conn.write(b'\xef')
|
||||||
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
|
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
|
||||||
self.conn.write(b'\xee\xee\xee\xee')
|
await self.conn.write(b'\xee\xee\xee\xee')
|
||||||
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
|
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
|
||||||
self._setup_obfuscation()
|
await self._setup_obfuscation()
|
||||||
|
|
||||||
def get_timeout(self):
|
def get_timeout(self):
|
||||||
"""Returns the timeout used by the connection."""
|
"""Returns the timeout used by the connection."""
|
||||||
return self.conn.timeout
|
return self.conn.timeout
|
||||||
|
|
||||||
def _setup_obfuscation(self):
|
async def _setup_obfuscation(self):
|
||||||
"""
|
"""
|
||||||
Sets up the obfuscated protocol.
|
Sets up the obfuscated protocol.
|
||||||
"""
|
"""
|
||||||
|
@ -144,7 +143,7 @@ class Connection:
|
||||||
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
|
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
|
||||||
|
|
||||||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||||
self.conn.write(bytes(random))
|
await self.conn.write(bytes(random))
|
||||||
|
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
"""
|
"""
|
||||||
|
@ -166,12 +165,12 @@ class Connection:
|
||||||
|
|
||||||
# region Receive message implementations
|
# region Receive message implementations
|
||||||
|
|
||||||
def recv(self):
|
async def recv(self):
|
||||||
"""Receives and unpacks a message"""
|
"""Receives and unpacks a message"""
|
||||||
# Default implementation is just an error
|
# Default implementation is just an error
|
||||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||||
|
|
||||||
def _recv_tcp_full(self):
|
async def _recv_tcp_full(self):
|
||||||
"""
|
"""
|
||||||
Receives a message from the network,
|
Receives a message from the network,
|
||||||
internally encoded using the TCP full protocol.
|
internally encoded using the TCP full protocol.
|
||||||
|
@ -181,7 +180,10 @@ class Connection:
|
||||||
|
|
||||||
:return: the read message payload.
|
:return: the read message payload.
|
||||||
"""
|
"""
|
||||||
packet_len_seq = self.read(8) # 4 and 4
|
# TODO We don't want another call to this method that could
|
||||||
|
# potentially await on another self.read(n). Is this guaranteed
|
||||||
|
# by asyncio?
|
||||||
|
packet_len_seq = await self.read(8) # 4 and 4
|
||||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
||||||
|
|
||||||
# Sometimes Telegram seems to send a packet length of 0 (12)
|
# Sometimes Telegram seems to send a packet length of 0 (12)
|
||||||
|
@ -192,15 +194,15 @@ class Connection:
|
||||||
'reading data left:', packet_len)
|
'reading data left:', packet_len)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
__log__.error(repr(self.read(1)))
|
__log__.error(repr(await self.read(1)))
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
break
|
break
|
||||||
# Connection reset and hope it's fixed after
|
# Connection reset and hope it's fixed after
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
raise ConnectionResetError()
|
raise ConnectionResetError()
|
||||||
|
|
||||||
body = self.read(packet_len - 12)
|
body = await self.read(packet_len - 12)
|
||||||
checksum = struct.unpack('<I', self.read(4))[0]
|
checksum = struct.unpack('<I', await self.read(4))[0]
|
||||||
|
|
||||||
valid_checksum = crc32(packet_len_seq + body)
|
valid_checksum = crc32(packet_len_seq + body)
|
||||||
if checksum != valid_checksum:
|
if checksum != valid_checksum:
|
||||||
|
@ -208,38 +210,38 @@ class Connection:
|
||||||
|
|
||||||
return body
|
return body
|
||||||
|
|
||||||
def _recv_intermediate(self):
|
async def _recv_intermediate(self):
|
||||||
"""
|
"""
|
||||||
Receives a message from the network,
|
Receives a message from the network,
|
||||||
internally encoded using the TCP intermediate protocol.
|
internally encoded using the TCP intermediate protocol.
|
||||||
|
|
||||||
:return: the read message payload.
|
:return: the read message payload.
|
||||||
"""
|
"""
|
||||||
return self.read(struct.unpack('<i', self.read(4))[0])
|
return await self.read(struct.unpack('<i', await self.read(4))[0])
|
||||||
|
|
||||||
def _recv_abridged(self):
|
async def _recv_abridged(self):
|
||||||
"""
|
"""
|
||||||
Receives a message from the network,
|
Receives a message from the network,
|
||||||
internally encoded using the TCP abridged protocol.
|
internally encoded using the TCP abridged protocol.
|
||||||
|
|
||||||
:return: the read message payload.
|
:return: the read message payload.
|
||||||
"""
|
"""
|
||||||
length = struct.unpack('<B', self.read(1))[0]
|
length = struct.unpack('<B', await self.read(1))[0]
|
||||||
if length >= 127:
|
if length >= 127:
|
||||||
length = struct.unpack('<i', self.read(3) + b'\0')[0]
|
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
|
||||||
|
|
||||||
return self.read(length << 2)
|
return await self.read(length << 2)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Send message implementations
|
# region Send message implementations
|
||||||
|
|
||||||
def send(self, message):
|
async def send(self, message):
|
||||||
"""Encapsulates and sends the given message"""
|
"""Encapsulates and sends the given message"""
|
||||||
# Default implementation is just an error
|
# Default implementation is just an error
|
||||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||||
|
|
||||||
def _send_tcp_full(self, message):
|
async def _send_tcp_full(self, message):
|
||||||
"""
|
"""
|
||||||
Encapsulates and sends the given message payload
|
Encapsulates and sends the given message payload
|
||||||
using the TCP full mode (length, sequence, message, crc32).
|
using the TCP full mode (length, sequence, message, crc32).
|
||||||
|
@ -252,18 +254,18 @@ class Connection:
|
||||||
data = struct.pack('<ii', length, self._send_counter) + message
|
data = struct.pack('<ii', length, self._send_counter) + message
|
||||||
crc = struct.pack('<I', crc32(data))
|
crc = struct.pack('<I', crc32(data))
|
||||||
self._send_counter += 1
|
self._send_counter += 1
|
||||||
self.write(data + crc)
|
await self.write(data + crc)
|
||||||
|
|
||||||
def _send_intermediate(self, message):
|
async def _send_intermediate(self, message):
|
||||||
"""
|
"""
|
||||||
Encapsulates and sends the given message payload
|
Encapsulates and sends the given message payload
|
||||||
using the TCP intermediate mode (length, message).
|
using the TCP intermediate mode (length, message).
|
||||||
|
|
||||||
:param message: the message to be sent.
|
:param message: the message to be sent.
|
||||||
"""
|
"""
|
||||||
self.write(struct.pack('<i', len(message)) + message)
|
await self.write(struct.pack('<i', len(message)) + message)
|
||||||
|
|
||||||
def _send_abridged(self, message):
|
async def _send_abridged(self, message):
|
||||||
"""
|
"""
|
||||||
Encapsulates and sends the given message payload
|
Encapsulates and sends the given message payload
|
||||||
using the TCP abridged mode (short length, message).
|
using the TCP abridged mode (short length, message).
|
||||||
|
@ -276,57 +278,55 @@ class Connection:
|
||||||
else:
|
else:
|
||||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||||
|
|
||||||
self.write(length + message)
|
await self.write(length + message)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Read implementations
|
# region Read implementations
|
||||||
|
|
||||||
def read(self, length):
|
async def read(self, length):
|
||||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||||
|
|
||||||
def _read_plain(self, length):
|
async def _read_plain(self, length):
|
||||||
"""
|
"""
|
||||||
Reads data from the socket connection.
|
Reads data from the socket connection.
|
||||||
|
|
||||||
:param length: how many bytes should be read.
|
:param length: how many bytes should be read.
|
||||||
:return: a byte sequence with len(data) == length
|
:return: a byte sequence with len(data) == length
|
||||||
"""
|
"""
|
||||||
return self.conn.read(length)
|
return await self.conn.read(length)
|
||||||
|
|
||||||
def _read_obfuscated(self, length):
|
async def _read_obfuscated(self, length):
|
||||||
"""
|
"""
|
||||||
Reads data and decrypts from the socket connection.
|
Reads data and decrypts from the socket connection.
|
||||||
|
|
||||||
:param length: how many bytes should be read.
|
:param length: how many bytes should be read.
|
||||||
:return: the decrypted byte sequence with len(data) == length
|
:return: the decrypted byte sequence with len(data) == length
|
||||||
"""
|
"""
|
||||||
return self._aes_decrypt.encrypt(
|
return self._aes_decrypt.encrypt(await self.conn.read(length))
|
||||||
self.conn.read(length)
|
|
||||||
)
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Write implementations
|
# region Write implementations
|
||||||
|
|
||||||
def write(self, data):
|
async def write(self, data):
|
||||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||||
|
|
||||||
def _write_plain(self, data):
|
async def _write_plain(self, data):
|
||||||
"""
|
"""
|
||||||
Writes the given data through the socket connection.
|
Writes the given data through the socket connection.
|
||||||
|
|
||||||
:param data: the data in bytes to be written.
|
:param data: the data in bytes to be written.
|
||||||
"""
|
"""
|
||||||
self.conn.write(data)
|
await self.conn.write(data)
|
||||||
|
|
||||||
def _write_obfuscated(self, data):
|
async def _write_obfuscated(self, data):
|
||||||
"""
|
"""
|
||||||
Writes the given data through the socket connection,
|
Writes the given data through the socket connection,
|
||||||
using the obfuscated mode (AES encryption is applied on top).
|
using the obfuscated mode (AES encryption is applied on top).
|
||||||
|
|
||||||
:param data: the data in bytes to be written.
|
:param data: the data in bytes to be written.
|
||||||
"""
|
"""
|
||||||
self.conn.write(self._aes_encrypt.encrypt(data))
|
await self.conn.write(self._aes_encrypt.encrypt(data))
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -26,32 +26,32 @@ class MtProtoPlainSender:
|
||||||
self._last_msg_id = 0
|
self._last_msg_id = 0
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
|
|
||||||
def connect(self):
|
async def connect(self):
|
||||||
"""Connects to Telegram's servers."""
|
"""Connects to Telegram's servers."""
|
||||||
self._connection.connect()
|
await self._connection.connect()
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnects from Telegram's servers."""
|
"""Disconnects from Telegram's servers."""
|
||||||
self._connection.close()
|
self._connection.close()
|
||||||
|
|
||||||
def send(self, data):
|
async def send(self, data):
|
||||||
"""
|
"""
|
||||||
Sends a plain packet (auth_key_id = 0) containing the
|
Sends a plain packet (auth_key_id = 0) containing the
|
||||||
given message body (data).
|
given message body (data).
|
||||||
|
|
||||||
:param data: the data to be sent.
|
:param data: the data to be sent.
|
||||||
"""
|
"""
|
||||||
self._connection.send(
|
await self._connection.send(
|
||||||
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
|
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
|
||||||
)
|
)
|
||||||
|
|
||||||
def receive(self):
|
async def receive(self):
|
||||||
"""
|
"""
|
||||||
Receives a plain packet from the network.
|
Receives a plain packet from the network.
|
||||||
|
|
||||||
:return: the response body.
|
:return: the response body.
|
||||||
"""
|
"""
|
||||||
body = self._connection.recv()
|
body = await self._connection.recv()
|
||||||
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
||||||
# Broken authorization, must reset the auth key
|
# Broken authorization, must reset the auth key
|
||||||
raise BrokenAuthKeyError()
|
raise BrokenAuthKeyError()
|
||||||
|
|
|
@ -2,9 +2,10 @@
|
||||||
This module contains the class used to communicate with Telegram's servers
|
This module contains the class used to communicate with Telegram's servers
|
||||||
encrypting every packet, and relies on a valid AuthKey in the used Session.
|
encrypting every packet, and relies on a valid AuthKey in the used Session.
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import gzip
|
import gzip
|
||||||
import logging
|
import logging
|
||||||
from threading import Lock
|
from asyncio import Event
|
||||||
|
|
||||||
from .. import helpers as utils
|
from .. import helpers as utils
|
||||||
from ..errors import (
|
from ..errors import (
|
||||||
|
@ -33,7 +34,7 @@ class MtProtoSender:
|
||||||
in parallel, so thread-safety (hence locking) isn't needed.
|
in parallel, so thread-safety (hence locking) isn't needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, session, connection):
|
def __init__(self, session, connection, loop=None):
|
||||||
"""
|
"""
|
||||||
Initializes a new MTProto sender.
|
Initializes a new MTProto sender.
|
||||||
|
|
||||||
|
@ -42,22 +43,20 @@ class MtProtoSender:
|
||||||
port of the server, salt, ID, and AuthKey,
|
port of the server, salt, ID, and AuthKey,
|
||||||
:param connection:
|
:param connection:
|
||||||
the Connection to be used.
|
the Connection to be used.
|
||||||
|
:param loop:
|
||||||
|
the asyncio loop to be used, or the default one.
|
||||||
"""
|
"""
|
||||||
self.session = session
|
self.session = session
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
# Message IDs that need confirmation
|
self._recv_lock = asyncio.Lock()
|
||||||
self._need_confirmation = set()
|
|
||||||
|
|
||||||
# Requests (as msg_id: Message) sent waiting to be received
|
# Requests (as msg_id: Message) sent waiting to be received
|
||||||
self._pending_receive = {}
|
self._pending_receive = {}
|
||||||
|
|
||||||
# Multithreading
|
async def connect(self):
|
||||||
self._send_lock = Lock()
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
"""Connects to the server."""
|
"""Connects to the server."""
|
||||||
self.connection.connect(self.session.server_address, self.session.port)
|
await self.connection.connect(self.session.server_address, self.session.port)
|
||||||
|
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
"""
|
"""
|
||||||
|
@ -70,18 +69,25 @@ class MtProtoSender:
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnects from the server."""
|
"""Disconnects from the server."""
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
self._need_confirmation.clear()
|
|
||||||
self._clear_all_pending()
|
self._clear_all_pending()
|
||||||
|
|
||||||
# region Send and receive
|
# region Send and receive
|
||||||
|
|
||||||
def send(self, *requests):
|
async def send(self, *requests):
|
||||||
"""
|
"""
|
||||||
Sends the specified TLObject(s) (which must be requests),
|
Sends the specified TLObject(s) (which must be requests),
|
||||||
and acknowledging any message which needed confirmation.
|
and acknowledging any message which needed confirmation.
|
||||||
|
|
||||||
:param requests: the requests to be sent.
|
:param requests: the requests to be sent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Prepare the event of every request
|
||||||
|
for r in requests:
|
||||||
|
if r.confirm_received is None:
|
||||||
|
r.confirm_received = Event(loop=self._loop)
|
||||||
|
else:
|
||||||
|
r.confirm_received.clear()
|
||||||
|
|
||||||
# Finally send our packed request(s)
|
# Finally send our packed request(s)
|
||||||
messages = [TLMessage(self.session, r) for r in requests]
|
messages = [TLMessage(self.session, r) for r in requests]
|
||||||
self._pending_receive.update({m.msg_id: m for m in messages})
|
self._pending_receive.update({m.msg_id: m for m in messages})
|
||||||
|
@ -91,13 +97,6 @@ class MtProtoSender:
|
||||||
for m in messages
|
for m in messages
|
||||||
))
|
))
|
||||||
|
|
||||||
# Pack everything in the same container if we need to send AckRequests
|
|
||||||
if self._need_confirmation:
|
|
||||||
messages.append(
|
|
||||||
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
|
|
||||||
)
|
|
||||||
self._need_confirmation.clear()
|
|
||||||
|
|
||||||
if len(messages) == 1:
|
if len(messages) == 1:
|
||||||
message = messages[0]
|
message = messages[0]
|
||||||
else:
|
else:
|
||||||
|
@ -108,13 +107,13 @@ class MtProtoSender:
|
||||||
for m in messages:
|
for m in messages:
|
||||||
m.container_msg_id = message.msg_id
|
m.container_msg_id = message.msg_id
|
||||||
|
|
||||||
self._send_message(message)
|
await self._send_message(message)
|
||||||
|
|
||||||
def _send_acknowledge(self, msg_id):
|
async def _send_acknowledge(self, msg_id):
|
||||||
"""Sends a message acknowledge for the given msg_id."""
|
"""Sends a message acknowledge for the given msg_id."""
|
||||||
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||||
|
|
||||||
def receive(self, update_state):
|
async def receive(self, update_state):
|
||||||
"""
|
"""
|
||||||
Receives a single message from the connected endpoint.
|
Receives a single message from the connected endpoint.
|
||||||
|
|
||||||
|
@ -130,7 +129,10 @@ class MtProtoSender:
|
||||||
Update and Updates objects.
|
Update and Updates objects.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body = self.connection.recv()
|
with await self._recv_lock:
|
||||||
|
# Receiving items is not an "atomic" operation since we
|
||||||
|
# need to read the length and then upcoming parts separated.
|
||||||
|
body = await self.connection.recv()
|
||||||
except (BufferError, InvalidChecksumError):
|
except (BufferError, InvalidChecksumError):
|
||||||
# TODO BufferError, we should spot the cause...
|
# TODO BufferError, we should spot the cause...
|
||||||
# "No more bytes left"; something wrong happened, clear
|
# "No more bytes left"; something wrong happened, clear
|
||||||
|
@ -147,20 +149,20 @@ class MtProtoSender:
|
||||||
|
|
||||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||||
with BinaryReader(message) as reader:
|
with BinaryReader(message) as reader:
|
||||||
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||||
|
await self._send_acknowledge(remote_msg_id)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Low level processing
|
# region Low level processing
|
||||||
|
|
||||||
def _send_message(self, message):
|
async def _send_message(self, message):
|
||||||
"""
|
"""
|
||||||
Sends the given encrypted through the network.
|
Sends the given encrypted through the network.
|
||||||
|
|
||||||
:param message: the TLMessage to be sent.
|
:param message: the TLMessage to be sent.
|
||||||
"""
|
"""
|
||||||
with self._send_lock:
|
await self.connection.send(utils.pack_message(self.session, message))
|
||||||
self.connection.send(utils.pack_message(self.session, message))
|
|
||||||
|
|
||||||
def _decode_msg(self, body):
|
def _decode_msg(self, body):
|
||||||
"""
|
"""
|
||||||
|
@ -178,7 +180,7 @@ class MtProtoSender:
|
||||||
with BinaryReader(body) as reader:
|
with BinaryReader(body) as reader:
|
||||||
return utils.unpack_message(self.session, reader)
|
return utils.unpack_message(self.session, reader)
|
||||||
|
|
||||||
def _process_msg(self, msg_id, sequence, reader, state):
|
async def _process_msg(self, msg_id, sequence, reader, state):
|
||||||
"""
|
"""
|
||||||
Processes the message read from the network inside reader.
|
Processes the message read from the network inside reader.
|
||||||
|
|
||||||
|
@ -189,7 +191,6 @@ class MtProtoSender:
|
||||||
:return: true if the message was handled correctly, false otherwise.
|
:return: true if the message was handled correctly, false otherwise.
|
||||||
"""
|
"""
|
||||||
# TODO Check salt, session_id and sequence_number
|
# TODO Check salt, session_id and sequence_number
|
||||||
self._need_confirmation.add(msg_id)
|
|
||||||
|
|
||||||
code = reader.read_int(signed=False)
|
code = reader.read_int(signed=False)
|
||||||
reader.seek(-4)
|
reader.seek(-4)
|
||||||
|
@ -197,15 +198,15 @@ class MtProtoSender:
|
||||||
# These are a bit of special case, not yet generated by the code gen
|
# These are a bit of special case, not yet generated by the code gen
|
||||||
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
|
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
|
||||||
__log__.debug('Processing Remote Procedure Call result')
|
__log__.debug('Processing Remote Procedure Call result')
|
||||||
return self._handle_rpc_result(msg_id, sequence, reader)
|
return await self._handle_rpc_result(msg_id, sequence, reader)
|
||||||
|
|
||||||
if code == MessageContainer.CONSTRUCTOR_ID:
|
if code == MessageContainer.CONSTRUCTOR_ID:
|
||||||
__log__.debug('Processing container result')
|
__log__.debug('Processing container result')
|
||||||
return self._handle_container(msg_id, sequence, reader, state)
|
return await self._handle_container(msg_id, sequence, reader, state)
|
||||||
|
|
||||||
if code == GzipPacked.CONSTRUCTOR_ID:
|
if code == GzipPacked.CONSTRUCTOR_ID:
|
||||||
__log__.debug('Processing gzipped result')
|
__log__.debug('Processing gzipped result')
|
||||||
return self._handle_gzip_packed(msg_id, sequence, reader, state)
|
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||||
|
|
||||||
if code not in tlobjects:
|
if code not in tlobjects:
|
||||||
__log__.warning(
|
__log__.warning(
|
||||||
|
@ -218,22 +219,22 @@ class MtProtoSender:
|
||||||
__log__.debug('Processing %s result', type(obj).__name__)
|
__log__.debug('Processing %s result', type(obj).__name__)
|
||||||
|
|
||||||
if isinstance(obj, Pong):
|
if isinstance(obj, Pong):
|
||||||
return self._handle_pong(msg_id, sequence, obj)
|
return await self._handle_pong(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, BadServerSalt):
|
if isinstance(obj, BadServerSalt):
|
||||||
return self._handle_bad_server_salt(msg_id, sequence, obj)
|
return await self._handle_bad_server_salt(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, BadMsgNotification):
|
if isinstance(obj, BadMsgNotification):
|
||||||
return self._handle_bad_msg_notification(msg_id, sequence, obj)
|
return await self._handle_bad_msg_notification(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, MsgDetailedInfo):
|
if isinstance(obj, MsgDetailedInfo):
|
||||||
return self._handle_msg_detailed_info(msg_id, sequence, obj)
|
return await self._handle_msg_detailed_info(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, MsgNewDetailedInfo):
|
if isinstance(obj, MsgNewDetailedInfo):
|
||||||
return self._handle_msg_new_detailed_info(msg_id, sequence, obj)
|
return await self._handle_msg_new_detailed_info(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, NewSessionCreated):
|
if isinstance(obj, NewSessionCreated):
|
||||||
return self._handle_new_session_created(msg_id, sequence, obj)
|
return await self._handle_new_session_created(msg_id, sequence, obj)
|
||||||
|
|
||||||
if isinstance(obj, MsgsAck): # may handle the request we wanted
|
if isinstance(obj, MsgsAck): # may handle the request we wanted
|
||||||
# Ignore every ack request *unless* when logging out, when it's
|
# Ignore every ack request *unless* when logging out, when it's
|
||||||
|
@ -310,7 +311,7 @@ class MtProtoSender:
|
||||||
r.request.confirm_received.set()
|
r.request.confirm_received.set()
|
||||||
self._pending_receive.clear()
|
self._pending_receive.clear()
|
||||||
|
|
||||||
def _resend_request(self, msg_id):
|
async def _resend_request(self, msg_id):
|
||||||
"""
|
"""
|
||||||
Re-sends the request that belongs to a certain msg_id. This may
|
Re-sends the request that belongs to a certain msg_id. This may
|
||||||
also be the msg_id of a container if they were sent in one.
|
also be the msg_id of a container if they were sent in one.
|
||||||
|
@ -319,12 +320,13 @@ class MtProtoSender:
|
||||||
"""
|
"""
|
||||||
request = self._pop_request(msg_id)
|
request = self._pop_request(msg_id)
|
||||||
if request:
|
if request:
|
||||||
return self.send(request)
|
await self.send(request)
|
||||||
|
return
|
||||||
requests = self._pop_requests_of_container(msg_id)
|
requests = self._pop_requests_of_container(msg_id)
|
||||||
if requests:
|
if requests:
|
||||||
return self.send(*requests)
|
await self.send(*requests)
|
||||||
|
|
||||||
def _handle_pong(self, msg_id, sequence, pong):
|
async def _handle_pong(self, msg_id, sequence, pong):
|
||||||
"""
|
"""
|
||||||
Handles a Pong response.
|
Handles a Pong response.
|
||||||
|
|
||||||
|
@ -340,7 +342,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_container(self, msg_id, sequence, reader, state):
|
async def _handle_container(self, msg_id, sequence, reader, state):
|
||||||
"""
|
"""
|
||||||
Handles a MessageContainer response.
|
Handles a MessageContainer response.
|
||||||
|
|
||||||
|
@ -355,7 +357,7 @@ class MtProtoSender:
|
||||||
# Note that this code is IMPORTANT for skipping RPC results of
|
# Note that this code is IMPORTANT for skipping RPC results of
|
||||||
# lost requests (i.e., ones from the previous connection session)
|
# lost requests (i.e., ones from the previous connection session)
|
||||||
try:
|
try:
|
||||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
if not await self._process_msg(inner_msg_id, sequence, reader, state):
|
||||||
reader.set_position(begin_position + inner_len)
|
reader.set_position(begin_position + inner_len)
|
||||||
except:
|
except:
|
||||||
# If any error is raised, something went wrong; skip the packet
|
# If any error is raised, something went wrong; skip the packet
|
||||||
|
@ -364,7 +366,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
|
async def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
|
||||||
"""
|
"""
|
||||||
Handles a BadServerSalt response.
|
Handles a BadServerSalt response.
|
||||||
|
|
||||||
|
@ -378,10 +380,11 @@ class MtProtoSender:
|
||||||
|
|
||||||
# "the bad_server_salt response is received with the
|
# "the bad_server_salt response is received with the
|
||||||
# correct salt, and the message is to be re-sent with it"
|
# correct salt, and the message is to be re-sent with it"
|
||||||
self._resend_request(bad_salt.bad_msg_id)
|
await self._resend_request(bad_salt.bad_msg_id)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
|
async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
|
||||||
"""
|
"""
|
||||||
Handles a BadMessageError response.
|
Handles a BadMessageError response.
|
||||||
|
|
||||||
|
@ -397,25 +400,25 @@ class MtProtoSender:
|
||||||
# Use the current msg_id to determine the right time offset.
|
# Use the current msg_id to determine the right time offset.
|
||||||
self.session.update_time_offset(correct_msg_id=msg_id)
|
self.session.update_time_offset(correct_msg_id=msg_id)
|
||||||
__log__.info('Attempting to use the correct time offset')
|
__log__.info('Attempting to use the correct time offset')
|
||||||
self._resend_request(bad_msg.bad_msg_id)
|
await self._resend_request(bad_msg.bad_msg_id)
|
||||||
return True
|
return True
|
||||||
elif bad_msg.error_code == 32:
|
elif bad_msg.error_code == 32:
|
||||||
# msg_seqno too low, so just pump it up by some "large" amount
|
# msg_seqno too low, so just pump it up by some "large" amount
|
||||||
# TODO A better fix would be to start with a new fresh session ID
|
# TODO A better fix would be to start with a new fresh session ID
|
||||||
self.session.sequence += 64
|
self.session.sequence += 64
|
||||||
__log__.info('Attempting to set the right higher sequence')
|
__log__.info('Attempting to set the right higher sequence')
|
||||||
self._resend_request(bad_msg.bad_msg_id)
|
await self._resend_request(bad_msg.bad_msg_id)
|
||||||
return True
|
return True
|
||||||
elif bad_msg.error_code == 33:
|
elif bad_msg.error_code == 33:
|
||||||
# msg_seqno too high never seems to happen but just in case
|
# msg_seqno too high never seems to happen but just in case
|
||||||
self.session.sequence -= 16
|
self.session.sequence -= 16
|
||||||
__log__.info('Attempting to set the right lower sequence')
|
__log__.info('Attempting to set the right lower sequence')
|
||||||
self._resend_request(bad_msg.bad_msg_id)
|
await self._resend_request(bad_msg.bad_msg_id)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
|
async def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
|
||||||
"""
|
"""
|
||||||
Handles a MsgDetailedInfo response.
|
Handles a MsgDetailedInfo response.
|
||||||
|
|
||||||
|
@ -426,10 +429,10 @@ class MtProtoSender:
|
||||||
"""
|
"""
|
||||||
# TODO For now, simply ack msg_new.answer_msg_id
|
# TODO For now, simply ack msg_new.answer_msg_id
|
||||||
# Relevant tdesktop source code: https://goo.gl/VvpCC6
|
# Relevant tdesktop source code: https://goo.gl/VvpCC6
|
||||||
self._send_acknowledge(msg_new.answer_msg_id)
|
await self._send_acknowledge(msg_new.answer_msg_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
|
async def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
|
||||||
"""
|
"""
|
||||||
Handles a MsgNewDetailedInfo response.
|
Handles a MsgNewDetailedInfo response.
|
||||||
|
|
||||||
|
@ -440,10 +443,10 @@ class MtProtoSender:
|
||||||
"""
|
"""
|
||||||
# TODO For now, simply ack msg_new.answer_msg_id
|
# TODO For now, simply ack msg_new.answer_msg_id
|
||||||
# Relevant tdesktop source code: https://goo.gl/G7DPsR
|
# Relevant tdesktop source code: https://goo.gl/G7DPsR
|
||||||
self._send_acknowledge(msg_new.answer_msg_id)
|
await self._send_acknowledge(msg_new.answer_msg_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_new_session_created(self, msg_id, sequence, new_session):
|
async def _handle_new_session_created(self, msg_id, sequence, new_session):
|
||||||
"""
|
"""
|
||||||
Handles a NewSessionCreated response.
|
Handles a NewSessionCreated response.
|
||||||
|
|
||||||
|
@ -456,7 +459,7 @@ class MtProtoSender:
|
||||||
# TODO https://goo.gl/LMyN7A
|
# TODO https://goo.gl/LMyN7A
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_rpc_result(self, msg_id, sequence, reader):
|
async def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||||
"""
|
"""
|
||||||
Handles a RPCResult response.
|
Handles a RPCResult response.
|
||||||
|
|
||||||
|
@ -484,9 +487,6 @@ class MtProtoSender:
|
||||||
reader.read_int(), reader.tgread_string()
|
reader.read_int(), reader.tgread_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Acknowledge that we received the error
|
|
||||||
self._send_acknowledge(request_id)
|
|
||||||
|
|
||||||
if request:
|
if request:
|
||||||
request.rpc_error = error
|
request.rpc_error = error
|
||||||
request.confirm_received.set()
|
request.confirm_received.set()
|
||||||
|
@ -522,7 +522,7 @@ class MtProtoSender:
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||||
"""
|
"""
|
||||||
Handles a GzipPacked response.
|
Handles a GzipPacked response.
|
||||||
|
|
||||||
|
@ -532,11 +532,6 @@ class MtProtoSender:
|
||||||
:return: the result of processing the packed message.
|
:return: the result of processing the packed message.
|
||||||
"""
|
"""
|
||||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||||
# We are reentering process_msg, which seemingly the same msg_id
|
return await self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||||
# to the self._need_confirmation set. Remove it from there first
|
|
||||||
# to avoid any future conflicts (i.e. if we "ignore" messages
|
|
||||||
# that we are already aware of, see 1a91c02 and old 63dfb1e)
|
|
||||||
self._need_confirmation -= {msg_id}
|
|
||||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -3,7 +3,6 @@ import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from os.path import isfile as file_exists
|
from os.path import isfile as file_exists
|
||||||
from threading import Lock, RLock
|
|
||||||
|
|
||||||
from .memory import MemorySession, _SentFileType
|
from .memory import MemorySession, _SentFileType
|
||||||
from ..crypto import AuthKey
|
from ..crypto import AuthKey
|
||||||
|
@ -39,11 +38,6 @@ class SQLiteSession(MemorySession):
|
||||||
if not self.filename.endswith(EXTENSION):
|
if not self.filename.endswith(EXTENSION):
|
||||||
self.filename += EXTENSION
|
self.filename += EXTENSION
|
||||||
|
|
||||||
# Cross-thread safety
|
|
||||||
self._seq_no_lock = Lock()
|
|
||||||
self._msg_id_lock = Lock()
|
|
||||||
self._db_lock = RLock()
|
|
||||||
|
|
||||||
# Migrating from .json -> SQL
|
# Migrating from .json -> SQL
|
||||||
entities = self._check_migrate_json()
|
entities = self._check_migrate_json()
|
||||||
|
|
||||||
|
@ -189,42 +183,37 @@ class SQLiteSession(MemorySession):
|
||||||
self._update_session_table()
|
self._update_session_table()
|
||||||
|
|
||||||
def _update_session_table(self):
|
def _update_session_table(self):
|
||||||
with self._db_lock:
|
c = self._cursor()
|
||||||
c = self._cursor()
|
# While we can save multiple rows into the sessions table
|
||||||
# While we can save multiple rows into the sessions table
|
# currently we only want to keep ONE as the tables don't
|
||||||
# currently we only want to keep ONE as the tables don't
|
# tell us which auth_key's are usable and will work. Needs
|
||||||
# tell us which auth_key's are usable and will work. Needs
|
# some more work before being able to save auth_key's for
|
||||||
# some more work before being able to save auth_key's for
|
# multiple DCs. Probably done differently.
|
||||||
# multiple DCs. Probably done differently.
|
c.execute('delete from sessions')
|
||||||
c.execute('delete from sessions')
|
c.execute('insert or replace into sessions values (?,?,?,?)', (
|
||||||
c.execute('insert or replace into sessions values (?,?,?,?)', (
|
self._dc_id,
|
||||||
self._dc_id,
|
self._server_address,
|
||||||
self._server_address,
|
self._port,
|
||||||
self._port,
|
self._auth_key.key if self._auth_key else b''
|
||||||
self._auth_key.key if self._auth_key else b''
|
))
|
||||||
))
|
c.close()
|
||||||
c.close()
|
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Saves the current session object as session_user_id.session"""
|
"""Saves the current session object as session_user_id.session"""
|
||||||
with self._db_lock:
|
self._conn.commit()
|
||||||
self._conn.commit()
|
|
||||||
|
|
||||||
def _cursor(self):
|
def _cursor(self):
|
||||||
"""Asserts that the connection is open and returns a cursor"""
|
"""Asserts that the connection is open and returns a cursor"""
|
||||||
with self._db_lock:
|
if self._conn is None:
|
||||||
if self._conn is None:
|
self._conn = sqlite3.connect(self.filename)
|
||||||
self._conn = sqlite3.connect(self.filename,
|
return self._conn.cursor()
|
||||||
check_same_thread=False)
|
|
||||||
return self._conn.cursor()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Closes the connection unless we're working in-memory"""
|
"""Closes the connection unless we're working in-memory"""
|
||||||
if self.filename != ':memory:':
|
if self.filename != ':memory:':
|
||||||
with self._db_lock:
|
if self._conn is not None:
|
||||||
if self._conn is not None:
|
self._conn.close()
|
||||||
self._conn.close()
|
self._conn = None
|
||||||
self._conn = None
|
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
"""Deletes the current session file"""
|
"""Deletes the current session file"""
|
||||||
|
@ -259,11 +248,10 @@ class SQLiteSession(MemorySession):
|
||||||
if not rows:
|
if not rows:
|
||||||
return
|
return
|
||||||
|
|
||||||
with self._db_lock:
|
self._cursor().executemany(
|
||||||
self._cursor().executemany(
|
'insert or replace into entities values (?,?,?,?,?)', rows
|
||||||
'insert or replace into entities values (?,?,?,?,?)', rows
|
)
|
||||||
)
|
self.save()
|
||||||
self.save()
|
|
||||||
|
|
||||||
def _fetchone_entity(self, query, args):
|
def _fetchone_entity(self, query, args):
|
||||||
c = self._cursor()
|
c = self._cursor()
|
||||||
|
@ -302,11 +290,10 @@ class SQLiteSession(MemorySession):
|
||||||
if not isinstance(instance, (InputDocument, InputPhoto)):
|
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||||
raise TypeError('Cannot cache %s instance' % type(instance))
|
raise TypeError('Cannot cache %s instance' % type(instance))
|
||||||
|
|
||||||
with self._db_lock:
|
self._cursor().execute(
|
||||||
self._cursor().execute(
|
'insert or replace into sent_files values (?,?,?,?,?)', (
|
||||||
'insert or replace into sent_files values (?,?,?,?,?)', (
|
md5_digest, file_size,
|
||||||
md5_digest, file_size,
|
_SentFileType.from_type(type(instance)).value,
|
||||||
_SentFileType.from_type(type(instance)).value,
|
instance.id, instance.access_hash
|
||||||
instance.id, instance.access_hash
|
))
|
||||||
))
|
self.save()
|
||||||
self.save()
|
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from asyncio import Lock
|
||||||
|
from datetime import timedelta
|
||||||
import platform
|
import platform
|
||||||
import threading
|
|
||||||
from datetime import timedelta, datetime
|
|
||||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
|
||||||
from threading import Lock
|
|
||||||
from time import sleep
|
|
||||||
from . import version, utils
|
from . import version, utils
|
||||||
from .crypto import rsa
|
from .crypto import rsa
|
||||||
from .errors import (
|
from .errors import (
|
||||||
|
@ -70,8 +68,6 @@ class TelegramBareClient:
|
||||||
connection_mode=ConnectionMode.TCP_FULL,
|
connection_mode=ConnectionMode.TCP_FULL,
|
||||||
use_ipv6=False,
|
use_ipv6=False,
|
||||||
proxy=None,
|
proxy=None,
|
||||||
update_workers=None,
|
|
||||||
spawn_read_thread=False,
|
|
||||||
timeout=timedelta(seconds=5),
|
timeout=timedelta(seconds=5),
|
||||||
loop=None,
|
loop=None,
|
||||||
device_model=None,
|
device_model=None,
|
||||||
|
@ -95,6 +91,8 @@ class TelegramBareClient:
|
||||||
'The given session must be a str or a Session instance.'
|
'The given session must be a str or a Session instance.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
|
|
||||||
# ':' in session.server_address is True if it's an IPv6 address
|
# ':' in session.server_address is True if it's an IPv6 address
|
||||||
if (not session.server_address or
|
if (not session.server_address or
|
||||||
(':' in session.server_address) != use_ipv6):
|
(':' in session.server_address) != use_ipv6):
|
||||||
|
@ -112,13 +110,15 @@ class TelegramBareClient:
|
||||||
# that calls .connect(). Every other thread will spawn a new
|
# that calls .connect(). Every other thread will spawn a new
|
||||||
# temporary connection. The connection on this one is always
|
# temporary connection. The connection on this one is always
|
||||||
# kept open so Telegram can send us updates.
|
# kept open so Telegram can send us updates.
|
||||||
self._sender = MtProtoSender(self.session, Connection(
|
self._sender = MtProtoSender(
|
||||||
mode=connection_mode, proxy=proxy, timeout=timeout
|
self.session,
|
||||||
))
|
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
|
||||||
|
self._loop
|
||||||
|
)
|
||||||
|
|
||||||
# Two threads may be calling reconnect() when the connection is lost,
|
# Two coroutines may be calling reconnect() when the connection
|
||||||
# we only want one to actually perform the reconnection.
|
# is lost, we only want one to actually perform the reconnection.
|
||||||
self._reconnect_lock = Lock()
|
self._reconnect_lock = Lock(loop=self._loop)
|
||||||
|
|
||||||
# Cache "exported" sessions as 'dc_id: Session' not to recreate
|
# Cache "exported" sessions as 'dc_id: Session' not to recreate
|
||||||
# them all the time since generating a new key is a relatively
|
# them all the time since generating a new key is a relatively
|
||||||
|
@ -127,7 +127,7 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# This member will process updates if enabled.
|
# This member will process updates if enabled.
|
||||||
# One may change self.updates.enabled at any later point.
|
# One may change self.updates.enabled at any later point.
|
||||||
self.updates = UpdateState(workers=update_workers)
|
self.updates = UpdateState(self._loop)
|
||||||
|
|
||||||
# Used on connection - the user may modify these and reconnect
|
# Used on connection - the user may modify these and reconnect
|
||||||
system = platform.uname()
|
system = platform.uname()
|
||||||
|
@ -153,34 +153,25 @@ class TelegramBareClient:
|
||||||
# See https://core.telegram.org/api/invoking#saving-client-info.
|
# See https://core.telegram.org/api/invoking#saving-client-info.
|
||||||
self._first_request = True
|
self._first_request = True
|
||||||
|
|
||||||
# Constantly read for results and updates from within the main client,
|
self._recv_loop = None
|
||||||
# if the user has left enabled such option.
|
self._ping_loop = None
|
||||||
self._spawn_read_thread = spawn_read_thread
|
self._state_loop = None
|
||||||
self._recv_thread = None
|
self._idling = asyncio.Event()
|
||||||
self._idling = threading.Event()
|
|
||||||
|
|
||||||
# Default PingRequest delay
|
# Default PingRequest delay
|
||||||
self._last_ping = datetime.now()
|
|
||||||
self._ping_delay = timedelta(minutes=1)
|
self._ping_delay = timedelta(minutes=1)
|
||||||
|
|
||||||
# Also have another delay for GetStateRequest.
|
# Also have another delay for GetStateRequest.
|
||||||
#
|
#
|
||||||
# If the connection is kept alive for long without invoking any
|
# If the connection is kept alive for long without invoking any
|
||||||
# high level request the server simply stops sending updates.
|
# high level request the server simply stops sending updates.
|
||||||
# TODO maybe we can have ._last_request instead if any req works?
|
# TODO maybe we can have ._last_request instead if any req works?
|
||||||
self._last_state = datetime.now()
|
|
||||||
self._state_delay = timedelta(hours=1)
|
self._state_delay = timedelta(hours=1)
|
||||||
|
|
||||||
# Some errors are known but there's nothing we can do from the
|
|
||||||
# background thread. If any of these happens, call .disconnect(),
|
|
||||||
# and raise them next time .invoke() is tried to be called.
|
|
||||||
self._background_error = None
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Connecting
|
# region Connecting
|
||||||
|
|
||||||
def connect(self, _sync_updates=True):
|
async def connect(self, _sync_updates=True):
|
||||||
"""Connects to the Telegram servers, executing authentication if
|
"""Connects to the Telegram servers, executing authentication if
|
||||||
required. Note that authenticating to the Telegram servers is
|
required. Note that authenticating to the Telegram servers is
|
||||||
not the same as authenticating the desired user itself, which
|
not the same as authenticating the desired user itself, which
|
||||||
|
@ -197,10 +188,8 @@ class TelegramBareClient:
|
||||||
__log__.info('Connecting to %s:%d...',
|
__log__.info('Connecting to %s:%d...',
|
||||||
self.session.server_address, self.session.port)
|
self.session.server_address, self.session.port)
|
||||||
|
|
||||||
self._background_error = None # Clear previous errors
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._sender.connect()
|
await self._sender.connect()
|
||||||
__log__.info('Connection success!')
|
__log__.info('Connection success!')
|
||||||
|
|
||||||
# Connection was successful! Try syncing the update state
|
# Connection was successful! Try syncing the update state
|
||||||
|
@ -210,12 +199,12 @@ class TelegramBareClient:
|
||||||
self._user_connected = True
|
self._user_connected = True
|
||||||
if self._authorized is None and _sync_updates:
|
if self._authorized is None and _sync_updates:
|
||||||
try:
|
try:
|
||||||
self.sync_updates()
|
await self.sync_updates()
|
||||||
self._set_connected_and_authorized()
|
await self._set_connected_and_authorized()
|
||||||
except UnauthorizedError:
|
except UnauthorizedError:
|
||||||
self._authorized = False
|
self._authorized = False
|
||||||
elif self._authorized:
|
elif self._authorized:
|
||||||
self._set_connected_and_authorized()
|
await self._set_connected_and_authorized()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -224,7 +213,7 @@ class TelegramBareClient:
|
||||||
__log__.warning('Connection failed, got unexpected type with ID '
|
__log__.warning('Connection failed, got unexpected type with ID '
|
||||||
'%s. Migrating?', hex(e.invalid_constructor_id))
|
'%s. Migrating?', hex(e.invalid_constructor_id))
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
return self.connect(_sync_updates=_sync_updates)
|
return await self.connect(_sync_updates=_sync_updates)
|
||||||
|
|
||||||
except (RPCError, ConnectionError) as e:
|
except (RPCError, ConnectionError) as e:
|
||||||
# Probably errors from the previous session, ignore them
|
# Probably errors from the previous session, ignore them
|
||||||
|
@ -249,24 +238,15 @@ class TelegramBareClient:
|
||||||
))
|
))
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnects from the Telegram server
|
"""Disconnects from the Telegram server"""
|
||||||
and stops all the spawned threads"""
|
|
||||||
__log__.info('Disconnecting...')
|
__log__.info('Disconnecting...')
|
||||||
self._user_connected = False # This will stop recv_thread's loop
|
self._user_connected = False
|
||||||
|
|
||||||
__log__.debug('Stopping all workers...')
|
|
||||||
self.updates.stop_workers()
|
|
||||||
|
|
||||||
# This will trigger a "ConnectionResetError" on the recv_thread,
|
|
||||||
# which won't attempt reconnecting as ._user_connected is False.
|
|
||||||
__log__.debug('Disconnecting the socket...')
|
|
||||||
self._sender.disconnect()
|
self._sender.disconnect()
|
||||||
|
|
||||||
# TODO Shall we clear the _exported_sessions, or may be reused?
|
# TODO Shall we clear the _exported_sessions, or may be reused?
|
||||||
self._first_request = True # On reconnect it will be first again
|
self._first_request = True # On reconnect it will be first again
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
|
||||||
def _reconnect(self, new_dc=None):
|
async def _reconnect(self, new_dc=None):
|
||||||
"""If 'new_dc' is not set, only a call to .connect() will be made
|
"""If 'new_dc' is not set, only a call to .connect() will be made
|
||||||
since it's assumed that the connection has been lost and the
|
since it's assumed that the connection has been lost and the
|
||||||
library is reconnecting.
|
library is reconnecting.
|
||||||
|
@ -276,13 +256,14 @@ class TelegramBareClient:
|
||||||
connects to the new data center.
|
connects to the new data center.
|
||||||
"""
|
"""
|
||||||
if new_dc is None:
|
if new_dc is None:
|
||||||
if self.is_connected():
|
# Assume we are disconnected due to some error, so connect again
|
||||||
__log__.info('Reconnection aborted: already connected')
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if self.is_connected():
|
||||||
|
__log__.info('Reconnection aborted: already connected')
|
||||||
|
return True
|
||||||
|
|
||||||
__log__.info('Attempting reconnection...')
|
__log__.info('Attempting reconnection...')
|
||||||
return self.connect()
|
return await self.connect()
|
||||||
except ConnectionResetError as e:
|
except ConnectionResetError as e:
|
||||||
__log__.warning('Reconnection failed due to %s', e)
|
__log__.warning('Reconnection failed due to %s', e)
|
||||||
return False
|
return False
|
||||||
|
@ -290,7 +271,7 @@ class TelegramBareClient:
|
||||||
# Since we're reconnecting possibly due to a UserMigrateError,
|
# Since we're reconnecting possibly due to a UserMigrateError,
|
||||||
# we need to first know the Data Centers we can connect to. Do
|
# we need to first know the Data Centers we can connect to. Do
|
||||||
# that before disconnecting.
|
# that before disconnecting.
|
||||||
dc = self._get_dc(new_dc)
|
dc = await self._get_dc(new_dc)
|
||||||
__log__.info('Reconnecting to new data center %s', dc)
|
__log__.info('Reconnecting to new data center %s', dc)
|
||||||
|
|
||||||
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
|
@ -299,7 +280,7 @@ class TelegramBareClient:
|
||||||
self.session.auth_key = None
|
self.session.auth_key = None
|
||||||
self.session.save()
|
self.session.save()
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
return self.connect()
|
return await self.connect()
|
||||||
|
|
||||||
def set_proxy(self, proxy):
|
def set_proxy(self, proxy):
|
||||||
"""Change the proxy used by the connections.
|
"""Change the proxy used by the connections.
|
||||||
|
@ -312,19 +293,15 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# region Working with different connections/Data Centers
|
# region Working with different connections/Data Centers
|
||||||
|
|
||||||
def _on_read_thread(self):
|
async def _get_dc(self, dc_id, cdn=False):
|
||||||
return self._recv_thread is not None and \
|
|
||||||
threading.get_ident() == self._recv_thread.ident
|
|
||||||
|
|
||||||
def _get_dc(self, dc_id, cdn=False):
|
|
||||||
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
||||||
if not TelegramBareClient._config:
|
if not TelegramBareClient._config:
|
||||||
TelegramBareClient._config = self(GetConfigRequest())
|
TelegramBareClient._config = await self(GetConfigRequest())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cdn:
|
if cdn:
|
||||||
# Ensure we have the latest keys for the CDNs
|
# Ensure we have the latest keys for the CDNs
|
||||||
for pk in self(GetCdnConfigRequest()).public_keys:
|
for pk in await (self(GetCdnConfigRequest())).public_keys:
|
||||||
rsa.add_key(pk.public_key)
|
rsa.add_key(pk.public_key)
|
||||||
|
|
||||||
return next(
|
return next(
|
||||||
|
@ -336,10 +313,10 @@ class TelegramBareClient:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# New configuration, perhaps a new CDN was added?
|
# New configuration, perhaps a new CDN was added?
|
||||||
TelegramBareClient._config = self(GetConfigRequest())
|
TelegramBareClient._config = await self(GetConfigRequest())
|
||||||
return self._get_dc(dc_id, cdn=cdn)
|
return await self._get_dc(dc_id, cdn=cdn)
|
||||||
|
|
||||||
def _get_exported_client(self, dc_id):
|
async def _get_exported_client(self, dc_id):
|
||||||
"""Creates and connects a new TelegramBareClient for the desired DC.
|
"""Creates and connects a new TelegramBareClient for the desired DC.
|
||||||
|
|
||||||
If it's the first time calling the method with a given dc_id,
|
If it's the first time calling the method with a given dc_id,
|
||||||
|
@ -356,11 +333,11 @@ class TelegramBareClient:
|
||||||
# TODO Add a lock, don't allow two threads to create an auth key
|
# TODO Add a lock, don't allow two threads to create an auth key
|
||||||
# (when calling .connect() if there wasn't a previous session).
|
# (when calling .connect() if there wasn't a previous session).
|
||||||
# for the same data center.
|
# for the same data center.
|
||||||
dc = self._get_dc(dc_id)
|
dc = await self._get_dc(dc_id)
|
||||||
|
|
||||||
# Export the current authorization to the new DC.
|
# Export the current authorization to the new DC.
|
||||||
__log__.info('Exporting authorization for data center %s', dc)
|
__log__.info('Exporting authorization for data center %s', dc)
|
||||||
export_auth = self(ExportAuthorizationRequest(dc_id))
|
export_auth = await self(ExportAuthorizationRequest(dc_id))
|
||||||
|
|
||||||
# Create a temporary session for this IP address, which needs
|
# Create a temporary session for this IP address, which needs
|
||||||
# to be different because each auth_key is unique per DC.
|
# to be different because each auth_key is unique per DC.
|
||||||
|
@ -375,11 +352,12 @@ class TelegramBareClient:
|
||||||
client = TelegramBareClient(
|
client = TelegramBareClient(
|
||||||
session, self.api_id, self.api_hash,
|
session, self.api_id, self.api_hash,
|
||||||
proxy=self._sender.connection.conn.proxy,
|
proxy=self._sender.connection.conn.proxy,
|
||||||
timeout=self._sender.connection.get_timeout()
|
timeout=self._sender.connection.get_timeout(),
|
||||||
|
loop=self._loop
|
||||||
)
|
)
|
||||||
client.connect(_sync_updates=False)
|
await client.connect(_sync_updates=False)
|
||||||
if isinstance(export_auth, ExportedAuthorization):
|
if isinstance(export_auth, ExportedAuthorization):
|
||||||
client(ImportAuthorizationRequest(
|
await client(ImportAuthorizationRequest(
|
||||||
id=export_auth.id, bytes=export_auth.bytes
|
id=export_auth.id, bytes=export_auth.bytes
|
||||||
))
|
))
|
||||||
elif export_auth is not None:
|
elif export_auth is not None:
|
||||||
|
@ -388,11 +366,11 @@ class TelegramBareClient:
|
||||||
client._authorized = True # We exported the auth, so we got auth
|
client._authorized = True # We exported the auth, so we got auth
|
||||||
return client
|
return client
|
||||||
|
|
||||||
def _get_cdn_client(self, cdn_redirect):
|
async def _get_cdn_client(self, cdn_redirect):
|
||||||
"""Similar to ._get_exported_client, but for CDNs"""
|
"""Similar to ._get_exported_client, but for CDNs"""
|
||||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||||
if not session:
|
if not session:
|
||||||
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
|
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||||
session = self.session.clone()
|
session = self.session.clone()
|
||||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||||
|
@ -401,7 +379,8 @@ class TelegramBareClient:
|
||||||
client = TelegramBareClient(
|
client = TelegramBareClient(
|
||||||
session, self.api_id, self.api_hash,
|
session, self.api_id, self.api_hash,
|
||||||
proxy=self._sender.connection.conn.proxy,
|
proxy=self._sender.connection.conn.proxy,
|
||||||
timeout=self._sender.connection.get_timeout()
|
timeout=self._sender.connection.get_timeout(),
|
||||||
|
loop=self._loop
|
||||||
)
|
)
|
||||||
|
|
||||||
# This will make use of the new RSA keys for this specific CDN.
|
# This will make use of the new RSA keys for this specific CDN.
|
||||||
|
@ -409,7 +388,7 @@ class TelegramBareClient:
|
||||||
# We won't be calling GetConfigRequest because it's only called
|
# We won't be calling GetConfigRequest because it's only called
|
||||||
# when needed by ._get_dc, and also it's static so it's likely
|
# when needed by ._get_dc, and also it's static so it's likely
|
||||||
# set already. Avoid invoking non-CDN methods by not syncing updates.
|
# set already. Avoid invoking non-CDN methods by not syncing updates.
|
||||||
client.connect(_sync_updates=False)
|
await client.connect(_sync_updates=False)
|
||||||
client._authorized = self._authorized
|
client._authorized = self._authorized
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -417,7 +396,7 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# region Invoking Telegram requests
|
# region Invoking Telegram requests
|
||||||
|
|
||||||
def __call__(self, *requests, retries=5):
|
async def __call__(self, *requests, retries=5):
|
||||||
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
|
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
|
||||||
|
|
||||||
The invoke will be retried up to 'retries' times before raising
|
The invoke will be retried up to 'retries' times before raising
|
||||||
|
@ -427,11 +406,8 @@ class TelegramBareClient:
|
||||||
x.content_related for x in requests):
|
x.content_related for x in requests):
|
||||||
raise TypeError('You can only invoke requests, not types!')
|
raise TypeError('You can only invoke requests, not types!')
|
||||||
|
|
||||||
if self._background_error:
|
|
||||||
raise self._background_error
|
|
||||||
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
request.resolve(self, utils)
|
await request.resolve(self, utils)
|
||||||
|
|
||||||
# For logging purposes
|
# For logging purposes
|
||||||
if len(requests) == 1:
|
if len(requests) == 1:
|
||||||
|
@ -440,26 +416,23 @@ class TelegramBareClient:
|
||||||
which = '{} requests ({})'.format(
|
which = '{} requests ({})'.format(
|
||||||
len(requests), [type(x).__name__ for x in requests])
|
len(requests), [type(x).__name__ for x in requests])
|
||||||
|
|
||||||
# Determine the sender to be used (main or a new connection)
|
|
||||||
__log__.debug('Invoking %s', which)
|
__log__.debug('Invoking %s', which)
|
||||||
call_receive = \
|
call_receive = \
|
||||||
not self._idling.is_set() or self._reconnect_lock.locked()
|
not self._idling.is_set() or self._reconnect_lock.locked()
|
||||||
|
|
||||||
for retry in range(retries):
|
for retry in range(retries):
|
||||||
result = self._invoke(call_receive, *requests)
|
result = await self._invoke(call_receive, retry, *requests)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
__log__.warning('Invoking %s failed %d times, '
|
__log__.warning('Invoking %s failed %d times, '
|
||||||
'reconnecting and retrying',
|
'reconnecting and retrying',
|
||||||
[str(x) for x in requests], retry + 1)
|
[str(x) for x in requests], retry + 1)
|
||||||
sleep(1)
|
|
||||||
# The ReadThread has priority when attempting reconnection,
|
await asyncio.sleep(retry + 1, loop=self._loop)
|
||||||
# since this thread is constantly running while __call__ is
|
|
||||||
# only done sometimes. Here try connecting only once/retry.
|
|
||||||
if not self._reconnect_lock.locked():
|
if not self._reconnect_lock.locked():
|
||||||
with self._reconnect_lock:
|
with await self._reconnect_lock:
|
||||||
self._reconnect()
|
await self._reconnect()
|
||||||
|
|
||||||
raise RuntimeError('Number of retries reached 0 for {}.'.format(
|
raise RuntimeError('Number of retries reached 0 for {}.'.format(
|
||||||
[type(x).__name__ for x in requests]
|
[type(x).__name__ for x in requests]
|
||||||
|
@ -468,18 +441,17 @@ class TelegramBareClient:
|
||||||
# Let people use client.invoke(SomeRequest()) instead client(...)
|
# Let people use client.invoke(SomeRequest()) instead client(...)
|
||||||
invoke = __call__
|
invoke = __call__
|
||||||
|
|
||||||
def _invoke(self, call_receive, *requests):
|
async def _invoke(self, call_receive, retry, *requests):
|
||||||
try:
|
try:
|
||||||
# Ensure that we start with no previous errors (i.e. resending)
|
# Ensure that we start with no previous errors (i.e. resending)
|
||||||
for x in requests:
|
for x in requests:
|
||||||
x.confirm_received.clear()
|
|
||||||
x.rpc_error = None
|
x.rpc_error = None
|
||||||
|
|
||||||
if not self.session.auth_key:
|
if not self.session.auth_key:
|
||||||
__log__.info('Need to generate new auth key before invoking')
|
__log__.info('Need to generate new auth key before invoking')
|
||||||
self._first_request = True
|
self._first_request = True
|
||||||
self.session.auth_key, self.session.time_offset = \
|
self.session.auth_key, self.session.time_offset = \
|
||||||
authenticator.do_authentication(self._sender.connection)
|
await authenticator.do_authentication(self._sender.connection)
|
||||||
|
|
||||||
if self._first_request:
|
if self._first_request:
|
||||||
__log__.info('Initializing a new connection while invoking')
|
__log__.info('Initializing a new connection while invoking')
|
||||||
|
@ -489,24 +461,21 @@ class TelegramBareClient:
|
||||||
# We need a SINGLE request (like GetConfig) to init conn.
|
# We need a SINGLE request (like GetConfig) to init conn.
|
||||||
# Once that's done, the N original requests will be
|
# Once that's done, the N original requests will be
|
||||||
# invoked.
|
# invoked.
|
||||||
TelegramBareClient._config = self(
|
TelegramBareClient._config = await self(
|
||||||
self._wrap_init_connection(GetConfigRequest())
|
self._wrap_init_connection(GetConfigRequest())
|
||||||
)
|
)
|
||||||
|
|
||||||
self._sender.send(*requests)
|
await self._sender.send(*requests)
|
||||||
|
|
||||||
if not call_receive:
|
if not call_receive:
|
||||||
# TODO This will be slightly troublesome if we allow
|
await asyncio.wait(
|
||||||
# switching between constant read or not on the fly.
|
list(map(lambda x: x.confirm_received.wait(), requests)),
|
||||||
# Must also watch out for calling .read() from two places,
|
timeout=self._sender.connection.get_timeout(),
|
||||||
# in which case a Lock would be required for .receive().
|
loop=self._loop
|
||||||
for x in requests:
|
)
|
||||||
x.confirm_received.wait(
|
|
||||||
self._sender.connection.get_timeout()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
while not all(x.confirm_received.is_set() for x in requests):
|
while not all(x.confirm_received.is_set() for x in requests):
|
||||||
self._sender.receive(update_state=self.updates)
|
await self._sender.receive(update_state=self.updates)
|
||||||
|
|
||||||
except BrokenAuthKeyError:
|
except BrokenAuthKeyError:
|
||||||
__log__.error('Authorization key seems broken and was invalid!')
|
__log__.error('Authorization key seems broken and was invalid!')
|
||||||
|
@ -552,12 +521,8 @@ class TelegramBareClient:
|
||||||
except (PhoneMigrateError, NetworkMigrateError,
|
except (PhoneMigrateError, NetworkMigrateError,
|
||||||
UserMigrateError) as e:
|
UserMigrateError) as e:
|
||||||
|
|
||||||
# TODO What happens with the background thread here?
|
await self._reconnect(new_dc=e.new_dc)
|
||||||
# For normal use cases, this won't happen, because this will only
|
return await self._invoke(call_receive, retry, *requests)
|
||||||
# be on the very first connection (not authorized, not running),
|
|
||||||
# but may be an issue for people who actually travel?
|
|
||||||
self._reconnect(new_dc=e.new_dc)
|
|
||||||
return self._invoke(call_receive, *requests)
|
|
||||||
|
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
# Telegram is having some issues, just retry
|
# Telegram is having some issues, just retry
|
||||||
|
@ -568,7 +533,8 @@ class TelegramBareClient:
|
||||||
if e.seconds > self.session.flood_sleep_threshold | 0:
|
if e.seconds > self.session.flood_sleep_threshold | 0:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
sleep(e.seconds)
|
await asyncio.sleep(e.seconds, loop=self._loop)
|
||||||
|
return None
|
||||||
|
|
||||||
# Some really basic functionality
|
# Some really basic functionality
|
||||||
|
|
||||||
|
@ -588,90 +554,69 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# region Updates handling
|
# region Updates handling
|
||||||
|
|
||||||
def sync_updates(self):
|
async def sync_updates(self):
|
||||||
"""Synchronizes self.updates to their initial state. Will be
|
"""Synchronizes self.updates to their initial state. Will be
|
||||||
called automatically on connection if self.updates.enabled = True,
|
called automatically on connection if self.updates.enabled = True,
|
||||||
otherwise it should be called manually after enabling updates.
|
otherwise it should be called manually after enabling updates.
|
||||||
"""
|
"""
|
||||||
self.updates.process(self(GetStateRequest()))
|
self.updates.process(await self(GetStateRequest()))
|
||||||
self._last_state = datetime.now()
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Constant read
|
# Constant read
|
||||||
|
|
||||||
def _set_connected_and_authorized(self):
|
# This is async so that the overrided version in TelegramClient can be
|
||||||
|
# async without problems.
|
||||||
|
async def _set_connected_and_authorized(self):
|
||||||
self._authorized = True
|
self._authorized = True
|
||||||
self.updates.setup_workers()
|
if self._recv_loop is None:
|
||||||
if self._spawn_read_thread and self._recv_thread is None:
|
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
|
||||||
self._recv_thread = threading.Thread(
|
if self._ping_loop is None:
|
||||||
name='ReadThread', daemon=True,
|
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
|
||||||
target=self._recv_thread_impl
|
if self._state_loop is None:
|
||||||
)
|
self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop)
|
||||||
self._recv_thread.start()
|
|
||||||
|
|
||||||
def _signal_handler(self, signum, frame):
|
async def _ping_loop_impl(self):
|
||||||
if self._user_connected:
|
while self._user_connected:
|
||||||
self.disconnect()
|
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
|
||||||
else:
|
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
|
||||||
os._exit(1)
|
self._ping_loop = None
|
||||||
|
|
||||||
def idle(self, stop_signals=(SIGINT, SIGTERM, SIGABRT)):
|
async def _state_loop_impl(self):
|
||||||
"""
|
while self._user_connected:
|
||||||
Idles the program by looping forever and listening for updates
|
await asyncio.sleep(self._state_delay.seconds, loop=self._loop)
|
||||||
until one of the signals are received, which breaks the loop.
|
await self._sender.send(GetStateRequest())
|
||||||
|
|
||||||
:param stop_signals:
|
|
||||||
Iterable containing signals from the signal module that will
|
|
||||||
be subscribed to TelegramClient.disconnect() (effectively
|
|
||||||
stopping the idle loop), which will be called on receiving one
|
|
||||||
of those signals.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if self._spawn_read_thread and not self._on_read_thread():
|
|
||||||
raise RuntimeError('Can only idle if spawn_read_thread=False')
|
|
||||||
|
|
||||||
|
async def _recv_loop_impl(self):
|
||||||
|
__log__.info('Starting to wait for items from the network')
|
||||||
self._idling.set()
|
self._idling.set()
|
||||||
for sig in stop_signals:
|
need_reconnect = False
|
||||||
signal(sig, self._signal_handler)
|
|
||||||
|
|
||||||
if self._on_read_thread():
|
|
||||||
__log__.info('Starting to wait for items from the network')
|
|
||||||
else:
|
|
||||||
__log__.info('Idling to receive items from the network')
|
|
||||||
|
|
||||||
while self._user_connected:
|
while self._user_connected:
|
||||||
try:
|
try:
|
||||||
if datetime.now() > self._last_ping + self._ping_delay:
|
if need_reconnect:
|
||||||
self._sender.send(PingRequest(
|
__log__.info('Attempting reconnection from read loop')
|
||||||
int.from_bytes(os.urandom(8), 'big', signed=True)
|
need_reconnect = False
|
||||||
))
|
with await self._reconnect_lock:
|
||||||
self._last_ping = datetime.now()
|
while self._user_connected and not await self._reconnect():
|
||||||
|
# Retry forever, this is instant messaging
|
||||||
|
await asyncio.sleep(0.1, loop=self._loop)
|
||||||
|
|
||||||
if datetime.now() > self._last_state + self._state_delay:
|
|
||||||
self._sender.send(GetStateRequest())
|
|
||||||
self._last_state = datetime.now()
|
|
||||||
|
|
||||||
__log__.debug('Receiving items from the network...')
|
|
||||||
self._sender.receive(update_state=self.updates)
|
|
||||||
except TimeoutError:
|
|
||||||
# No problem
|
|
||||||
__log__.debug('Receiving items from the network timed out')
|
|
||||||
except ConnectionResetError:
|
|
||||||
if self._user_connected:
|
|
||||||
__log__.error('Connection was reset while receiving '
|
|
||||||
'items. Reconnecting')
|
|
||||||
with self._reconnect_lock:
|
|
||||||
while self._user_connected and not self._reconnect():
|
|
||||||
sleep(0.1) # Retry forever, this is instant messaging
|
|
||||||
|
|
||||||
if self.is_connected():
|
|
||||||
# Telegram seems to kick us every 1024 items received
|
# Telegram seems to kick us every 1024 items received
|
||||||
# from the network not considering things like bad salt.
|
# from the network not considering things like bad salt.
|
||||||
# We must execute some *high level* request (that's not
|
# We must execute some *high level* request (that's not
|
||||||
# a ping) if we want to receive updates again.
|
# a ping) if we want to receive updates again.
|
||||||
# TODO Test if getDifference works too (better alternative)
|
# TODO Test if getDifference works too (better alternative)
|
||||||
self._sender.send(GetStateRequest())
|
await self._sender.send(GetStateRequest())
|
||||||
|
|
||||||
|
__log__.debug('Receiving items from the network...')
|
||||||
|
await self._sender.receive(update_state=self.updates)
|
||||||
|
except TimeoutError:
|
||||||
|
# No problem.
|
||||||
|
__log__.debug('Receiving items from the network timed out')
|
||||||
|
except ConnectionError:
|
||||||
|
need_reconnect = True
|
||||||
|
__log__.error('Connection was reset while receiving items')
|
||||||
|
await asyncio.sleep(1, loop=self._loop)
|
||||||
except:
|
except:
|
||||||
self._idling.clear()
|
self._idling.clear()
|
||||||
raise
|
raise
|
||||||
|
@ -679,39 +624,4 @@ class TelegramBareClient:
|
||||||
self._idling.clear()
|
self._idling.clear()
|
||||||
__log__.info('Connection closed by the user, not reading anymore')
|
__log__.info('Connection closed by the user, not reading anymore')
|
||||||
|
|
||||||
# By using this approach, another thread will be
|
|
||||||
# created and started upon connection to constantly read
|
|
||||||
# from the other end. Otherwise, manual calls to .receive()
|
|
||||||
# must be performed. The MtProtoSender cannot be connected,
|
|
||||||
# or an error will be thrown.
|
|
||||||
#
|
|
||||||
# This way, sending and receiving will be completely independent.
|
|
||||||
def _recv_thread_impl(self):
|
|
||||||
# This thread is "idle" (only listening for updates), but also
|
|
||||||
# excepts everything unlike the manual idle because it should
|
|
||||||
# not crash.
|
|
||||||
while self._user_connected:
|
|
||||||
try:
|
|
||||||
self.idle(stop_signals=tuple())
|
|
||||||
except Exception as error:
|
|
||||||
__log__.exception('Unknown exception in the read thread! '
|
|
||||||
'Disconnecting and leaving it to main thread')
|
|
||||||
# Unknown exception, pass it to the main thread
|
|
||||||
|
|
||||||
try:
|
|
||||||
import socks
|
|
||||||
if isinstance(error, (
|
|
||||||
socks.GeneralProxyError, socks.ProxyConnectionError
|
|
||||||
)):
|
|
||||||
# This is a known error, and it's not related to
|
|
||||||
# Telegram but rather to the proxy. Disconnect and
|
|
||||||
# hand it over to the main thread.
|
|
||||||
self._background_error = error
|
|
||||||
self.disconnect()
|
|
||||||
break
|
|
||||||
except ImportError:
|
|
||||||
"Not using PySocks, so it can't be a proxy error"
|
|
||||||
|
|
||||||
self._recv_thread = None
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import getpass
|
import getpass
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
|
@ -6,7 +7,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, UserList
|
from collections import OrderedDict, UserList
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -158,18 +158,16 @@ class TelegramClient(TelegramBareClient):
|
||||||
connection_mode=ConnectionMode.TCP_FULL,
|
connection_mode=ConnectionMode.TCP_FULL,
|
||||||
use_ipv6=False,
|
use_ipv6=False,
|
||||||
proxy=None,
|
proxy=None,
|
||||||
update_workers=None,
|
|
||||||
timeout=timedelta(seconds=5),
|
timeout=timedelta(seconds=5),
|
||||||
spawn_read_thread=True,
|
loop=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
session, api_id, api_hash,
|
session, api_id, api_hash,
|
||||||
connection_mode=connection_mode,
|
connection_mode=connection_mode,
|
||||||
use_ipv6=use_ipv6,
|
use_ipv6=use_ipv6,
|
||||||
proxy=proxy,
|
proxy=proxy,
|
||||||
update_workers=update_workers,
|
|
||||||
spawn_read_thread=spawn_read_thread,
|
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
loop=loop,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -190,7 +188,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# region Authorization requests
|
# region Authorization requests
|
||||||
|
|
||||||
def send_code_request(self, phone, force_sms=False):
|
async def send_code_request(self, phone, force_sms=False):
|
||||||
"""
|
"""
|
||||||
Sends a code request to the specified phone number.
|
Sends a code request to the specified phone number.
|
||||||
|
|
||||||
|
@ -208,7 +206,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
phone_hash = self._phone_code_hash.get(phone)
|
phone_hash = self._phone_code_hash.get(phone)
|
||||||
|
|
||||||
if not phone_hash:
|
if not phone_hash:
|
||||||
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
|
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
|
||||||
self._phone_code_hash[phone] = phone_hash = result.phone_code_hash
|
self._phone_code_hash[phone] = phone_hash = result.phone_code_hash
|
||||||
else:
|
else:
|
||||||
force_sms = True
|
force_sms = True
|
||||||
|
@ -216,22 +214,23 @@ class TelegramClient(TelegramBareClient):
|
||||||
self._phone = phone
|
self._phone = phone
|
||||||
|
|
||||||
if force_sms:
|
if force_sms:
|
||||||
result = self(ResendCodeRequest(phone, phone_hash))
|
result = await self(ResendCodeRequest(phone, phone_hash))
|
||||||
self._phone_code_hash[phone] = result.phone_code_hash
|
self._phone_code_hash[phone] = result.phone_code_hash
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def start(self,
|
async def start(self,
|
||||||
phone=lambda: input('Please enter your phone: '),
|
phone=lambda: input('Please enter your phone: '),
|
||||||
password=lambda: getpass.getpass('Please enter your password: '),
|
password=lambda: getpass.getpass(
|
||||||
bot_token=None, force_sms=False, code_callback=None,
|
'Please enter your password: '),
|
||||||
first_name='New User', last_name=''):
|
bot_token=None, force_sms=False, code_callback=None,
|
||||||
|
first_name='New User', last_name=''):
|
||||||
"""
|
"""
|
||||||
Convenience method to interactively connect and sign in if required,
|
Convenience method to interactively connect and sign in if required,
|
||||||
also taking into consideration that 2FA may be enabled in the account.
|
also taking into consideration that 2FA may be enabled in the account.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
>>> client = TelegramClient(session, api_id, api_hash).start(phone)
|
>>> client = await TelegramClient(session, api_id, api_hash).start(phone)
|
||||||
Please enter the code you received: 12345
|
Please enter the code you received: 12345
|
||||||
Please enter your password: *******
|
Please enter your password: *******
|
||||||
(You are now logged in)
|
(You are now logged in)
|
||||||
|
@ -286,14 +285,14 @@ class TelegramClient(TelegramBareClient):
|
||||||
'must only provide one of either')
|
'must only provide one of either')
|
||||||
|
|
||||||
if not self.is_connected():
|
if not self.is_connected():
|
||||||
self.connect()
|
await self.connect()
|
||||||
|
|
||||||
if self.is_user_authorized():
|
if self.is_user_authorized():
|
||||||
self._check_events_pending_resolve()
|
self._check_events_pending_resolve()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if bot_token:
|
if bot_token:
|
||||||
self.sign_in(bot_token=bot_token)
|
await self.sign_in(bot_token=bot_token)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
# Turn the callable into a valid phone number
|
# Turn the callable into a valid phone number
|
||||||
|
@ -305,15 +304,15 @@ class TelegramClient(TelegramBareClient):
|
||||||
max_attempts = 3
|
max_attempts = 3
|
||||||
two_step_detected = False
|
two_step_detected = False
|
||||||
|
|
||||||
sent_code = self.send_code_request(phone, force_sms=force_sms)
|
sent_code = await self.send_code_request(phone, force_sms=force_sms)
|
||||||
sign_up = not sent_code.phone_registered
|
sign_up = not sent_code.phone_registered
|
||||||
while attempts < max_attempts:
|
while attempts < max_attempts:
|
||||||
try:
|
try:
|
||||||
if sign_up:
|
if sign_up:
|
||||||
me = self.sign_up(code_callback(), first_name, last_name)
|
me = await self.sign_up(code_callback(), first_name, last_name)
|
||||||
else:
|
else:
|
||||||
# Raises SessionPasswordNeededError if 2FA enabled
|
# Raises SessionPasswordNeededError if 2FA enabled
|
||||||
me = self.sign_in(phone, code_callback())
|
me = await self.sign_in(phone, code_callback())
|
||||||
break
|
break
|
||||||
except SessionPasswordNeededError:
|
except SessionPasswordNeededError:
|
||||||
two_step_detected = True
|
two_step_detected = True
|
||||||
|
@ -342,15 +341,15 @@ class TelegramClient(TelegramBareClient):
|
||||||
# TODO If callable given make it retry on invalid
|
# TODO If callable given make it retry on invalid
|
||||||
if callable(password):
|
if callable(password):
|
||||||
password = password()
|
password = password()
|
||||||
me = self.sign_in(phone=phone, password=password)
|
me = await self.sign_in(phone=phone, password=password)
|
||||||
|
|
||||||
# We won't reach here if any step failed (exit by exception)
|
# We won't reach here if any step failed (exit by exception)
|
||||||
print('Signed in successfully as', utils.get_display_name(me))
|
print('Signed in successfully as', utils.get_display_name(me))
|
||||||
self._check_events_pending_resolve()
|
self._check_events_pending_resolve()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def sign_in(self, phone=None, code=None,
|
async def sign_in(self, phone=None, code=None,
|
||||||
password=None, bot_token=None, phone_code_hash=None):
|
password=None, bot_token=None, phone_code_hash=None):
|
||||||
"""
|
"""
|
||||||
Starts or completes the sign in process with the given phone number
|
Starts or completes the sign in process with the given phone number
|
||||||
or code that Telegram sent.
|
or code that Telegram sent.
|
||||||
|
@ -385,7 +384,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
return self.get_me()
|
return self.get_me()
|
||||||
|
|
||||||
if phone and not code and not password:
|
if phone and not code and not password:
|
||||||
return self.send_code_request(phone)
|
return await self.send_code_request(phone)
|
||||||
elif code:
|
elif code:
|
||||||
phone = utils.parse_phone(phone) or self._phone
|
phone = utils.parse_phone(phone) or self._phone
|
||||||
phone_code_hash = \
|
phone_code_hash = \
|
||||||
|
@ -400,14 +399,14 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# May raise PhoneCodeEmptyError, PhoneCodeExpiredError,
|
# May raise PhoneCodeEmptyError, PhoneCodeExpiredError,
|
||||||
# PhoneCodeHashEmptyError or PhoneCodeInvalidError.
|
# PhoneCodeHashEmptyError or PhoneCodeInvalidError.
|
||||||
result = self(SignInRequest(phone, phone_code_hash, str(code)))
|
result = await self(SignInRequest(phone, phone_code_hash, str(code)))
|
||||||
elif password:
|
elif password:
|
||||||
salt = self(GetPasswordRequest()).current_salt
|
salt = (await self(GetPasswordRequest())).current_salt
|
||||||
result = self(CheckPasswordRequest(
|
result = await self(CheckPasswordRequest(
|
||||||
helpers.get_password_hash(password, salt)
|
helpers.get_password_hash(password, salt)
|
||||||
))
|
))
|
||||||
elif bot_token:
|
elif bot_token:
|
||||||
result = self(ImportBotAuthorizationRequest(
|
result = await self(ImportBotAuthorizationRequest(
|
||||||
flags=0, bot_auth_token=bot_token,
|
flags=0, bot_auth_token=bot_token,
|
||||||
api_id=self.api_id, api_hash=self.api_hash
|
api_id=self.api_id, api_hash=self.api_hash
|
||||||
))
|
))
|
||||||
|
@ -420,10 +419,10 @@ class TelegramClient(TelegramBareClient):
|
||||||
self._self_input_peer = utils.get_input_peer(
|
self._self_input_peer = utils.get_input_peer(
|
||||||
result.user, allow_self=False
|
result.user, allow_self=False
|
||||||
)
|
)
|
||||||
self._set_connected_and_authorized()
|
await self._set_connected_and_authorized()
|
||||||
return result.user
|
return result.user
|
||||||
|
|
||||||
def sign_up(self, code, first_name, last_name=''):
|
async def sign_up(self, code, first_name, last_name=''):
|
||||||
"""
|
"""
|
||||||
Signs up to Telegram if you don't have an account yet.
|
Signs up to Telegram if you don't have an account yet.
|
||||||
You must call .send_code_request(phone) first.
|
You must call .send_code_request(phone) first.
|
||||||
|
@ -442,10 +441,10 @@ class TelegramClient(TelegramBareClient):
|
||||||
The new created user.
|
The new created user.
|
||||||
"""
|
"""
|
||||||
if self.is_user_authorized():
|
if self.is_user_authorized():
|
||||||
self._check_events_pending_resolve()
|
await self._check_events_pending_resolve()
|
||||||
return self.get_me()
|
return await self.get_me()
|
||||||
|
|
||||||
result = self(SignUpRequest(
|
result = await self(SignUpRequest(
|
||||||
phone_number=self._phone,
|
phone_number=self._phone,
|
||||||
phone_code_hash=self._phone_code_hash.get(self._phone, ''),
|
phone_code_hash=self._phone_code_hash.get(self._phone, ''),
|
||||||
phone_code=str(code),
|
phone_code=str(code),
|
||||||
|
@ -456,10 +455,10 @@ class TelegramClient(TelegramBareClient):
|
||||||
self._self_input_peer = utils.get_input_peer(
|
self._self_input_peer = utils.get_input_peer(
|
||||||
result.user, allow_self=False
|
result.user, allow_self=False
|
||||||
)
|
)
|
||||||
self._set_connected_and_authorized()
|
await self._set_connected_and_authorized()
|
||||||
return result.user
|
return result.user
|
||||||
|
|
||||||
def log_out(self):
|
async def log_out(self):
|
||||||
"""
|
"""
|
||||||
Logs out Telegram and deletes the current ``*.session`` file.
|
Logs out Telegram and deletes the current ``*.session`` file.
|
||||||
|
|
||||||
|
@ -467,7 +466,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
True if the operation was successful.
|
True if the operation was successful.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self(LogOutRequest())
|
await self(LogOutRequest())
|
||||||
except RPCError:
|
except RPCError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -475,7 +474,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
self.session.delete()
|
self.session.delete()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_me(self, input_peer=False):
|
async def get_me(self, input_peer=False):
|
||||||
"""
|
"""
|
||||||
Gets "me" (the self user) which is currently authenticated,
|
Gets "me" (the self user) which is currently authenticated,
|
||||||
or None if the request fails (hence, not authenticated).
|
or None if the request fails (hence, not authenticated).
|
||||||
|
@ -491,9 +490,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
"""
|
"""
|
||||||
if input_peer and self._self_input_peer:
|
if input_peer and self._self_input_peer:
|
||||||
return self._self_input_peer
|
return self._self_input_peer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
me = self(GetUsersRequest([InputUserSelf()]))[0]
|
me = (await self(GetUsersRequest([InputUserSelf()])))[0]
|
||||||
if not self._self_input_peer:
|
if not self._self_input_peer:
|
||||||
self._self_input_peer = utils.get_input_peer(
|
self._self_input_peer = utils.get_input_peer(
|
||||||
me, allow_self=False
|
me, allow_self=False
|
||||||
|
@ -507,8 +505,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# region Dialogs ("chats") requests
|
# region Dialogs ("chats") requests
|
||||||
|
|
||||||
def get_dialogs(self, limit=10, offset_date=None, offset_id=0,
|
async def get_dialogs(self, limit=10, offset_date=None, offset_id=0,
|
||||||
offset_peer=InputPeerEmpty()):
|
offset_peer=InputPeerEmpty()):
|
||||||
"""
|
"""
|
||||||
Gets N "dialogs" (open "chats" or conversations with other people).
|
Gets N "dialogs" (open "chats" or conversations with other people).
|
||||||
|
|
||||||
|
@ -535,7 +533,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
limit = float('inf') if limit is None else int(limit)
|
limit = float('inf') if limit is None else int(limit)
|
||||||
if limit == 0:
|
if limit == 0:
|
||||||
# Special case, get a single dialog and determine count
|
# Special case, get a single dialog and determine count
|
||||||
dialogs = self(GetDialogsRequest(
|
dialogs = await self(GetDialogsRequest(
|
||||||
offset_date=offset_date,
|
offset_date=offset_date,
|
||||||
offset_id=offset_id,
|
offset_id=offset_id,
|
||||||
offset_peer=offset_peer,
|
offset_peer=offset_peer,
|
||||||
|
@ -549,7 +547,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
dialogs = OrderedDict() # Use peer id as identifier to avoid dupes
|
dialogs = OrderedDict() # Use peer id as identifier to avoid dupes
|
||||||
while len(dialogs) < limit:
|
while len(dialogs) < limit:
|
||||||
real_limit = min(limit - len(dialogs), 100)
|
real_limit = min(limit - len(dialogs), 100)
|
||||||
r = self(GetDialogsRequest(
|
r = await self(GetDialogsRequest(
|
||||||
offset_date=offset_date,
|
offset_date=offset_date,
|
||||||
offset_id=offset_id,
|
offset_id=offset_id,
|
||||||
offset_peer=offset_peer,
|
offset_peer=offset_peer,
|
||||||
|
@ -580,7 +578,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
dialogs.total = total_count
|
dialogs.total = total_count
|
||||||
return dialogs
|
return dialogs
|
||||||
|
|
||||||
def get_drafts(self): # TODO: Ability to provide a `filter`
|
async def get_drafts(self): # TODO: Ability to provide a `filter`
|
||||||
"""
|
"""
|
||||||
Gets all open draft messages.
|
Gets all open draft messages.
|
||||||
|
|
||||||
|
@ -589,7 +587,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
You can call ``draft.set_message('text')`` to change the message,
|
You can call ``draft.set_message('text')`` to change the message,
|
||||||
or delete it through :meth:`draft.delete()`.
|
or delete it through :meth:`draft.delete()`.
|
||||||
"""
|
"""
|
||||||
response = self(GetAllDraftsRequest())
|
response = await self(GetAllDraftsRequest())
|
||||||
self.session.process_entities(response)
|
self.session.process_entities(response)
|
||||||
self.session.generate_sequence(response.seq)
|
self.session.generate_sequence(response.seq)
|
||||||
drafts = [Draft._from_update(self, u) for u in response.updates]
|
drafts = [Draft._from_update(self, u) for u in response.updates]
|
||||||
|
@ -636,7 +634,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
if request.id == update.message.id:
|
if request.id == update.message.id:
|
||||||
return update.message
|
return update.message
|
||||||
|
|
||||||
def _parse_message_text(self, message, parse_mode):
|
async def _parse_message_text(self, message, parse_mode):
|
||||||
"""
|
"""
|
||||||
Returns a (parsed message, entities) tuple depending on parse_mode.
|
Returns a (parsed message, entities) tuple depending on parse_mode.
|
||||||
"""
|
"""
|
||||||
|
@ -657,7 +655,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
if m:
|
if m:
|
||||||
try:
|
try:
|
||||||
msg_entities[i] = InputMessageEntityMentionName(
|
msg_entities[i] = InputMessageEntityMentionName(
|
||||||
e.offset, e.length, self.get_input_entity(
|
e.offset, e.length, await self.get_input_entity(
|
||||||
int(m.group(1)) if m.group(1) else e.url
|
int(m.group(1)) if m.group(1) else e.url
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -667,8 +665,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
return message, msg_entities
|
return message, msg_entities
|
||||||
|
|
||||||
def send_message(self, entity, message, reply_to=None, parse_mode='md',
|
async def send_message(self, entity, message, reply_to=None,
|
||||||
link_preview=True):
|
parse_mode='md', link_preview=True):
|
||||||
"""
|
"""
|
||||||
Sends the given message to the specified entity (user/chat/channel).
|
Sends the given message to the specified entity (user/chat/channel).
|
||||||
|
|
||||||
|
@ -695,11 +693,12 @@ class TelegramClient(TelegramBareClient):
|
||||||
Returns:
|
Returns:
|
||||||
the sent message
|
the sent message
|
||||||
"""
|
"""
|
||||||
entity = self.get_input_entity(entity)
|
|
||||||
|
entity = await self.get_input_entity(entity)
|
||||||
if isinstance(message, Message):
|
if isinstance(message, Message):
|
||||||
if (message.media
|
if (message.media
|
||||||
and not isinstance(message.media, MessageMediaWebPage)):
|
and not isinstance(message.media, MessageMediaWebPage)):
|
||||||
return self.send_file(entity, message.media)
|
return await self.send_file(entity, message.media)
|
||||||
|
|
||||||
if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id):
|
if utils.get_peer_id(entity) == utils.get_peer_id(message.to_id):
|
||||||
reply_id = message.reply_to_msg_id
|
reply_id = message.reply_to_msg_id
|
||||||
|
@ -716,7 +715,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
)
|
)
|
||||||
message = message.message
|
message = message.message
|
||||||
else:
|
else:
|
||||||
message, msg_ent = self._parse_message_text(message, parse_mode)
|
message, msg_ent = await self._parse_message_text(message, parse_mode)
|
||||||
request = SendMessageRequest(
|
request = SendMessageRequest(
|
||||||
peer=entity,
|
peer=entity,
|
||||||
message=message,
|
message=message,
|
||||||
|
@ -725,7 +724,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
reply_to_msg_id=self._get_message_id(reply_to)
|
reply_to_msg_id=self._get_message_id(reply_to)
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self(request)
|
result = await self(request)
|
||||||
|
|
||||||
if isinstance(result, UpdateShortSentMessage):
|
if isinstance(result, UpdateShortSentMessage):
|
||||||
return Message(
|
return Message(
|
||||||
id=result.id,
|
id=result.id,
|
||||||
|
@ -739,8 +739,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
return self._get_response_message(request, result)
|
return self._get_response_message(request, result)
|
||||||
|
|
||||||
def edit_message(self, entity, message_id, message=None, parse_mode='md',
|
async def edit_message(self, entity, message_id, message=None,
|
||||||
link_preview=True):
|
parse_mode='md', link_preview=True):
|
||||||
"""
|
"""
|
||||||
Edits the given message ID (to change its contents or disable preview).
|
Edits the given message ID (to change its contents or disable preview).
|
||||||
|
|
||||||
|
@ -773,18 +773,18 @@ class TelegramClient(TelegramBareClient):
|
||||||
Returns:
|
Returns:
|
||||||
the edited message
|
the edited message
|
||||||
"""
|
"""
|
||||||
message, msg_entities = self._parse_message_text(message, parse_mode)
|
message, msg_entities = await self._parse_message_text(message, parse_mode)
|
||||||
request = EditMessageRequest(
|
request = EditMessageRequest(
|
||||||
peer=self.get_input_entity(entity),
|
peer=await self.get_input_entity(entity),
|
||||||
id=self._get_message_id(message_id),
|
id=self._get_message_id(message_id),
|
||||||
message=message,
|
message=message,
|
||||||
no_webpage=not link_preview,
|
no_webpage=not link_preview,
|
||||||
entities=msg_entities
|
entities=msg_entities
|
||||||
)
|
)
|
||||||
result = self(request)
|
result = await self(request)
|
||||||
return self._get_response_message(request, result)
|
return self._get_response_message(request, result)
|
||||||
|
|
||||||
def delete_messages(self, entity, message_ids, revoke=True):
|
async def delete_messages(self, entity, message_ids, revoke=True):
|
||||||
"""
|
"""
|
||||||
Deletes a message from a chat, optionally "for everyone".
|
Deletes a message from a chat, optionally "for everyone".
|
||||||
|
|
||||||
|
@ -812,18 +812,18 @@ class TelegramClient(TelegramBareClient):
|
||||||
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
|
message_ids = [m.id if isinstance(m, Message) else int(m) for m in message_ids]
|
||||||
|
|
||||||
if entity is None:
|
if entity is None:
|
||||||
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||||
|
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
|
|
||||||
if isinstance(entity, InputPeerChannel):
|
if isinstance(entity, InputPeerChannel):
|
||||||
return self(channels.DeleteMessagesRequest(entity, message_ids))
|
return await self(channels.DeleteMessagesRequest(entity, message_ids))
|
||||||
else:
|
else:
|
||||||
return self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
return await self(messages.DeleteMessagesRequest(message_ids, revoke=revoke))
|
||||||
|
|
||||||
def get_message_history(self, entity, limit=20, offset_date=None,
|
async def get_message_history(self, entity, limit=20, offset_date=None,
|
||||||
offset_id=0, max_id=0, min_id=0, add_offset=0,
|
offset_id=0, max_id=0, min_id=0, add_offset=0,
|
||||||
batch_size=100, wait_time=None):
|
batch_size=100, wait_time=None):
|
||||||
"""
|
"""
|
||||||
Gets the message history for the specified entity
|
Gets the message history for the specified entity
|
||||||
|
|
||||||
|
@ -884,13 +884,12 @@ class TelegramClient(TelegramBareClient):
|
||||||
second is the default for this limit (or above). You may need
|
second is the default for this limit (or above). You may need
|
||||||
an higher limit, so you're free to set the ``batch_size`` that
|
an higher limit, so you're free to set the ``batch_size`` that
|
||||||
you think may be good.
|
you think may be good.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
limit = float('inf') if limit is None else int(limit)
|
limit = float('inf') if limit is None else int(limit)
|
||||||
if limit == 0:
|
if limit == 0:
|
||||||
# No messages, but we still need to know the total message count
|
# No messages, but we still need to know the total message count
|
||||||
result = self(GetHistoryRequest(
|
result = await self(GetHistoryRequest(
|
||||||
peer=entity, limit=1,
|
peer=entity, limit=1,
|
||||||
offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0
|
offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0
|
||||||
))
|
))
|
||||||
|
@ -906,7 +905,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
while len(messages) < limit:
|
while len(messages) < limit:
|
||||||
# Telegram has a hard limit of 100
|
# Telegram has a hard limit of 100
|
||||||
real_limit = min(limit - len(messages), batch_size)
|
real_limit = min(limit - len(messages), batch_size)
|
||||||
result = self(GetHistoryRequest(
|
result = await self(GetHistoryRequest(
|
||||||
peer=entity,
|
peer=entity,
|
||||||
limit=real_limit,
|
limit=real_limit,
|
||||||
offset_date=offset_date,
|
offset_date=offset_date,
|
||||||
|
@ -931,7 +930,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
offset_id = result.messages[-1].id
|
offset_id = result.messages[-1].id
|
||||||
offset_date = result.messages[-1].date
|
offset_date = result.messages[-1].date
|
||||||
time.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
# Add a few extra attributes to the Message to make it friendlier.
|
# Add a few extra attributes to the Message to make it friendlier.
|
||||||
messages.total = total_messages
|
messages.total = total_messages
|
||||||
|
@ -959,8 +958,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def send_read_acknowledge(self, entity, message=None, max_id=None,
|
async def send_read_acknowledge(self, entity, message=None, max_id=None,
|
||||||
clear_mentions=False):
|
clear_mentions=False):
|
||||||
"""
|
"""
|
||||||
Sends a "read acknowledge" (i.e., notifying the given peer that we've
|
Sends a "read acknowledge" (i.e., notifying the given peer that we've
|
||||||
read their messages, also known as the "double check").
|
read their messages, also known as the "double check").
|
||||||
|
@ -993,17 +992,17 @@ class TelegramClient(TelegramBareClient):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Either a message list or a max_id must be provided.')
|
'Either a message list or a max_id must be provided.')
|
||||||
|
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
if clear_mentions:
|
if clear_mentions:
|
||||||
self(ReadMentionsRequest(entity))
|
await self(ReadMentionsRequest(entity))
|
||||||
if max_id is None:
|
if max_id is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if max_id is not None:
|
if max_id is not None:
|
||||||
if isinstance(entity, InputPeerChannel):
|
if isinstance(entity, InputPeerChannel):
|
||||||
return self(channels.ReadHistoryRequest(entity, max_id=max_id))
|
return await self(channels.ReadHistoryRequest(entity, max_id=max_id))
|
||||||
else:
|
else:
|
||||||
return self(messages.ReadHistoryRequest(entity, max_id=max_id))
|
return await self(messages.ReadHistoryRequest(entity, max_id=max_id))
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -1025,8 +1024,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
raise TypeError('Invalid message type: {}'.format(type(message)))
|
raise TypeError('Invalid message type: {}'.format(type(message)))
|
||||||
|
|
||||||
def get_participants(self, entity, limit=None, search='',
|
async def get_participants(self, entity, limit=None, search='',
|
||||||
aggressive=False):
|
aggressive=False):
|
||||||
"""
|
"""
|
||||||
Gets the list of participants from the specified entity.
|
Gets the list of participants from the specified entity.
|
||||||
|
|
||||||
|
@ -1054,12 +1053,12 @@ class TelegramClient(TelegramBareClient):
|
||||||
A list of participants with an additional .total variable on the
|
A list of participants with an additional .total variable on the
|
||||||
list indicating the total amount of members in this group/channel.
|
list indicating the total amount of members in this group/channel.
|
||||||
"""
|
"""
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
limit = float('inf') if limit is None else int(limit)
|
limit = float('inf') if limit is None else int(limit)
|
||||||
if isinstance(entity, InputPeerChannel):
|
if isinstance(entity, InputPeerChannel):
|
||||||
total = self(GetFullChannelRequest(
|
total = (await self(GetFullChannelRequest(
|
||||||
entity
|
entity
|
||||||
)).full_chat.participants_count
|
))).full_chat.participants_count
|
||||||
|
|
||||||
all_participants = {}
|
all_participants = {}
|
||||||
if total > 10000 and aggressive:
|
if total > 10000 and aggressive:
|
||||||
|
@ -1091,9 +1090,9 @@ class TelegramClient(TelegramBareClient):
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(requests) == 1:
|
if len(requests) == 1:
|
||||||
results = (self(requests[0]),)
|
results = (await self(requests[0]),)
|
||||||
else:
|
else:
|
||||||
results = self(*requests)
|
results = await self(*requests)
|
||||||
for i in reversed(range(len(requests))):
|
for i in reversed(range(len(requests))):
|
||||||
participants = results[i]
|
participants = results[i]
|
||||||
if not participants.users:
|
if not participants.users:
|
||||||
|
@ -1111,7 +1110,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
users = UserList(values)
|
users = UserList(values)
|
||||||
users.total = total
|
users.total = total
|
||||||
elif isinstance(entity, InputPeerChat):
|
elif isinstance(entity, InputPeerChat):
|
||||||
users = self(GetFullChatRequest(entity.chat_id)).users
|
users = (await self(GetFullChatRequest(entity.chat_id))).users
|
||||||
if len(users) > limit:
|
if len(users) > limit:
|
||||||
users = users[:limit]
|
users = users[:limit]
|
||||||
users = UserList(users)
|
users = UserList(users)
|
||||||
|
@ -1125,14 +1124,14 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# region Uploading files
|
# region Uploading files
|
||||||
|
|
||||||
def send_file(self, entity, file, caption=None,
|
async def send_file(self, entity, file, caption=None,
|
||||||
force_document=False, progress_callback=None,
|
force_document=False, progress_callback=None,
|
||||||
reply_to=None,
|
reply_to=None,
|
||||||
attributes=None,
|
attributes=None,
|
||||||
thumb=None,
|
thumb=None,
|
||||||
allow_cache=True,
|
allow_cache=True,
|
||||||
parse_mode='md',
|
parse_mode='md',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Sends a file to the specified entity.
|
Sends a file to the specified entity.
|
||||||
|
|
||||||
|
@ -1201,14 +1200,14 @@ class TelegramClient(TelegramBareClient):
|
||||||
# Convert to tuple so we can iterate several times
|
# Convert to tuple so we can iterate several times
|
||||||
file = tuple(x for x in file)
|
file = tuple(x for x in file)
|
||||||
if all(utils.is_image(x) for x in file):
|
if all(utils.is_image(x) for x in file):
|
||||||
return self._send_album(
|
return await self._send_album(
|
||||||
entity, file, caption=caption,
|
entity, file, caption=caption,
|
||||||
progress_callback=progress_callback, reply_to=reply_to,
|
progress_callback=progress_callback, reply_to=reply_to,
|
||||||
parse_mode=parse_mode
|
parse_mode=parse_mode
|
||||||
)
|
)
|
||||||
# Not all are images, so send all the files one by one
|
# Not all are images, so send all the files one by one
|
||||||
return [
|
return [
|
||||||
self.send_file(
|
await self.send_file(
|
||||||
entity, x, allow_cache=False,
|
entity, x, allow_cache=False,
|
||||||
caption=caption, force_document=force_document,
|
caption=caption, force_document=force_document,
|
||||||
progress_callback=progress_callback, reply_to=reply_to,
|
progress_callback=progress_callback, reply_to=reply_to,
|
||||||
|
@ -1216,7 +1215,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
) for x in file
|
) for x in file
|
||||||
]
|
]
|
||||||
|
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
reply_to = self._get_message_id(reply_to)
|
reply_to = self._get_message_id(reply_to)
|
||||||
caption, msg_entities = self._parse_message_text(caption, parse_mode)
|
caption, msg_entities = self._parse_message_text(caption, parse_mode)
|
||||||
|
|
||||||
|
@ -1233,11 +1232,11 @@ class TelegramClient(TelegramBareClient):
|
||||||
reply_to_msg_id=reply_to,
|
reply_to_msg_id=reply_to,
|
||||||
message=caption,
|
message=caption,
|
||||||
entities=msg_entities)
|
entities=msg_entities)
|
||||||
return self._get_response_message(request, self(request))
|
return self._get_response_message(request, await self(request))
|
||||||
|
|
||||||
as_image = utils.is_image(file) and not force_document
|
as_image = utils.is_image(file) and not force_document
|
||||||
use_cache = InputPhoto if as_image else InputDocument
|
use_cache = InputPhoto if as_image else InputDocument
|
||||||
file_handle = self.upload_file(
|
file_handle = await self.upload_file(
|
||||||
file, progress_callback=progress_callback,
|
file, progress_callback=progress_callback,
|
||||||
use_cache=use_cache if allow_cache else None
|
use_cache=use_cache if allow_cache else None
|
||||||
)
|
)
|
||||||
|
@ -1314,7 +1313,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
input_kw = {}
|
input_kw = {}
|
||||||
if thumb:
|
if thumb:
|
||||||
input_kw['thumb'] = self.upload_file(thumb)
|
input_kw['thumb'] = await self.upload_file(thumb)
|
||||||
|
|
||||||
media = InputMediaUploadedDocument(
|
media = InputMediaUploadedDocument(
|
||||||
file=file_handle,
|
file=file_handle,
|
||||||
|
@ -1327,7 +1326,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
# send the media message to the desired entity.
|
# send the media message to the desired entity.
|
||||||
request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to,
|
request = SendMediaRequest(entity, media, reply_to_msg_id=reply_to,
|
||||||
message=caption, entities=msg_entities)
|
message=caption, entities=msg_entities)
|
||||||
msg = self._get_response_message(request, self(request))
|
msg = self._get_response_message(request, await self(request))
|
||||||
if msg and isinstance(file_handle, InputSizedFile):
|
if msg and isinstance(file_handle, InputSizedFile):
|
||||||
# There was a response message and we didn't use cached
|
# There was a response message and we didn't use cached
|
||||||
# version, so cache whatever we just sent to the database.
|
# version, so cache whatever we just sent to the database.
|
||||||
|
@ -1345,7 +1344,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
kwargs['is_voice_note'] = True
|
kwargs['is_voice_note'] = True
|
||||||
return self.send_file(*args, **kwargs)
|
return self.send_file(*args, **kwargs)
|
||||||
|
|
||||||
def _send_album(self, entity, files, caption=None,
|
async def _send_album(self, entity, files, caption=None,
|
||||||
progress_callback=None, reply_to=None,
|
progress_callback=None, reply_to=None,
|
||||||
parse_mode='md'):
|
parse_mode='md'):
|
||||||
"""Specialized version of .send_file for albums"""
|
"""Specialized version of .send_file for albums"""
|
||||||
|
@ -1354,7 +1353,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
# we need to produce right now to send albums (uploadMedia), and
|
# we need to produce right now to send albums (uploadMedia), and
|
||||||
# cache only makes a difference for documents where the user may
|
# cache only makes a difference for documents where the user may
|
||||||
# want the attributes used on them to change.
|
# want the attributes used on them to change.
|
||||||
entity = self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
if not utils.is_list_like(caption):
|
if not utils.is_list_like(caption):
|
||||||
caption = (caption,)
|
caption = (caption,)
|
||||||
captions = [
|
captions = [
|
||||||
|
@ -1367,11 +1366,11 @@ class TelegramClient(TelegramBareClient):
|
||||||
media = []
|
media = []
|
||||||
for file in files:
|
for file in files:
|
||||||
# fh will either be InputPhoto or a modified InputFile
|
# fh will either be InputPhoto or a modified InputFile
|
||||||
fh = self.upload_file(file, use_cache=InputPhoto)
|
fh = await self.upload_file(file, use_cache=InputPhoto)
|
||||||
if not isinstance(fh, InputPhoto):
|
if not isinstance(fh, InputPhoto):
|
||||||
input_photo = utils.get_input_photo(self(UploadMediaRequest(
|
input_photo = utils.get_input_photo((await self(UploadMediaRequest(
|
||||||
entity, media=InputMediaUploadedPhoto(fh)
|
entity, media=InputMediaUploadedPhoto(fh)
|
||||||
)).photo)
|
))).photo)
|
||||||
self.session.cache_file(fh.md5, fh.size, input_photo)
|
self.session.cache_file(fh.md5, fh.size, input_photo)
|
||||||
fh = input_photo
|
fh = input_photo
|
||||||
|
|
||||||
|
@ -1383,7 +1382,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
entities=msg_entities))
|
entities=msg_entities))
|
||||||
|
|
||||||
# Now we can construct the multi-media request
|
# Now we can construct the multi-media request
|
||||||
result = self(SendMultiMediaRequest(
|
result = await self(SendMultiMediaRequest(
|
||||||
entity, reply_to_msg_id=reply_to, multi_media=media
|
entity, reply_to_msg_id=reply_to, multi_media=media
|
||||||
))
|
))
|
||||||
return [
|
return [
|
||||||
|
@ -1392,12 +1391,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
if isinstance(update, UpdateMessageID)
|
if isinstance(update, UpdateMessageID)
|
||||||
]
|
]
|
||||||
|
|
||||||
def upload_file(self,
|
async def upload_file(self, file, part_size_kb=None, file_name=None,
|
||||||
file,
|
use_cache=None, progress_callback=None):
|
||||||
part_size_kb=None,
|
|
||||||
file_name=None,
|
|
||||||
use_cache=None,
|
|
||||||
progress_callback=None):
|
|
||||||
"""
|
"""
|
||||||
Uploads the specified file and returns a handle (an instance of
|
Uploads the specified file and returns a handle (an instance of
|
||||||
InputFile or InputFileBig, as required) which can be later used
|
InputFile or InputFileBig, as required) which can be later used
|
||||||
|
@ -1510,7 +1505,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
else:
|
else:
|
||||||
request = SaveFilePartRequest(file_id, part_index, part)
|
request = SaveFilePartRequest(file_id, part_index, part)
|
||||||
|
|
||||||
result = self(request)
|
result = await self(request)
|
||||||
if result:
|
if result:
|
||||||
__log__.debug('Uploaded %d/%d', part_index + 1,
|
__log__.debug('Uploaded %d/%d', part_index + 1,
|
||||||
part_count)
|
part_count)
|
||||||
|
@ -1531,7 +1526,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# region Downloading media requests
|
# region Downloading media requests
|
||||||
|
|
||||||
def download_profile_photo(self, entity, file=None, download_big=True):
|
async def download_profile_photo(self, entity, file=None, download_big=True):
|
||||||
"""
|
"""
|
||||||
Downloads the profile photo of the given entity (user/chat/channel).
|
Downloads the profile photo of the given entity (user/chat/channel).
|
||||||
|
|
||||||
|
@ -1565,12 +1560,12 @@ class TelegramClient(TelegramBareClient):
|
||||||
# The hexadecimal numbers above are simply:
|
# The hexadecimal numbers above are simply:
|
||||||
# hex(crc32(x.encode('ascii'))) for x in
|
# hex(crc32(x.encode('ascii'))) for x in
|
||||||
# ('User', 'Chat', 'UserFull', 'ChatFull')
|
# ('User', 'Chat', 'UserFull', 'ChatFull')
|
||||||
entity = self.get_entity(entity)
|
entity = await self.get_entity(entity)
|
||||||
if not hasattr(entity, 'photo'):
|
if not hasattr(entity, 'photo'):
|
||||||
# Special case: may be a ChatFull with photo:Photo
|
# Special case: may be a ChatFull with photo:Photo
|
||||||
# This is different from a normal UserProfilePhoto and Chat
|
# This is different from a normal UserProfilePhoto and Chat
|
||||||
if hasattr(entity, 'chat_photo'):
|
if hasattr(entity, 'chat_photo'):
|
||||||
return self._download_photo(
|
return await self._download_photo(
|
||||||
entity.chat_photo, file,
|
entity.chat_photo, file,
|
||||||
date=None, progress_callback=None
|
date=None, progress_callback=None
|
||||||
)
|
)
|
||||||
|
@ -1595,7 +1590,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# Download the media with the largest size input file location
|
# Download the media with the largest size input file location
|
||||||
try:
|
try:
|
||||||
self.download_file(
|
await self.download_file(
|
||||||
InputFileLocation(
|
InputFileLocation(
|
||||||
volume_id=photo_location.volume_id,
|
volume_id=photo_location.volume_id,
|
||||||
local_id=photo_location.local_id,
|
local_id=photo_location.local_id,
|
||||||
|
@ -1606,10 +1601,10 @@ class TelegramClient(TelegramBareClient):
|
||||||
except LocationInvalidError:
|
except LocationInvalidError:
|
||||||
# See issue #500, Android app fails as of v4.6.0 (1155).
|
# See issue #500, Android app fails as of v4.6.0 (1155).
|
||||||
# The fix seems to be using the full channel chat photo.
|
# The fix seems to be using the full channel chat photo.
|
||||||
ie = self.get_input_entity(entity)
|
ie = await self.get_input_entity(entity)
|
||||||
if isinstance(ie, InputPeerChannel):
|
if isinstance(ie, InputPeerChannel):
|
||||||
full = self(GetFullChannelRequest(ie))
|
full = await self(GetFullChannelRequest(ie))
|
||||||
return self._download_photo(
|
return await self._download_photo(
|
||||||
full.full_chat.chat_photo, file,
|
full.full_chat.chat_photo, file,
|
||||||
date=None, progress_callback=None
|
date=None, progress_callback=None
|
||||||
)
|
)
|
||||||
|
@ -1618,7 +1613,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
return None
|
return None
|
||||||
return file
|
return file
|
||||||
|
|
||||||
def download_media(self, message, file=None, progress_callback=None):
|
async def download_media(self, message, file=None, progress_callback=None):
|
||||||
"""
|
"""
|
||||||
Downloads the given media, or the media from a specified Message.
|
Downloads the given media, or the media from a specified Message.
|
||||||
|
|
||||||
|
@ -1646,19 +1641,19 @@ class TelegramClient(TelegramBareClient):
|
||||||
media = message
|
media = message
|
||||||
|
|
||||||
if isinstance(media, (MessageMediaPhoto, Photo)):
|
if isinstance(media, (MessageMediaPhoto, Photo)):
|
||||||
return self._download_photo(
|
return await self._download_photo(
|
||||||
media, file, date, progress_callback
|
media, file, date, progress_callback
|
||||||
)
|
)
|
||||||
elif isinstance(media, (MessageMediaDocument, Document)):
|
elif isinstance(media, (MessageMediaDocument, Document)):
|
||||||
return self._download_document(
|
return await self._download_document(
|
||||||
media, file, date, progress_callback
|
media, file, date, progress_callback
|
||||||
)
|
)
|
||||||
elif isinstance(media, MessageMediaContact):
|
elif isinstance(media, MessageMediaContact):
|
||||||
return self._download_contact(
|
return await self._download_contact(
|
||||||
media, file
|
media, file
|
||||||
)
|
)
|
||||||
|
|
||||||
def _download_photo(self, photo, file, date, progress_callback):
|
async def _download_photo(self, photo, file, date, progress_callback):
|
||||||
"""Specialized version of .download_media() for photos"""
|
"""Specialized version of .download_media() for photos"""
|
||||||
# Determine the photo and its largest size
|
# Determine the photo and its largest size
|
||||||
if isinstance(photo, MessageMediaPhoto):
|
if isinstance(photo, MessageMediaPhoto):
|
||||||
|
@ -1673,7 +1668,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
|
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
|
||||||
|
|
||||||
# Download the media with the largest size input file location
|
# Download the media with the largest size input file location
|
||||||
self.download_file(
|
await self.download_file(
|
||||||
InputFileLocation(
|
InputFileLocation(
|
||||||
volume_id=largest_size.volume_id,
|
volume_id=largest_size.volume_id,
|
||||||
local_id=largest_size.local_id,
|
local_id=largest_size.local_id,
|
||||||
|
@ -1685,7 +1680,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
)
|
)
|
||||||
return file
|
return file
|
||||||
|
|
||||||
def _download_document(self, document, file, date, progress_callback):
|
async def _download_document(self, document, file, date, progress_callback):
|
||||||
"""Specialized version of .download_media() for documents"""
|
"""Specialized version of .download_media() for documents"""
|
||||||
if isinstance(document, MessageMediaDocument):
|
if isinstance(document, MessageMediaDocument):
|
||||||
document = document.document
|
document = document.document
|
||||||
|
@ -1718,7 +1713,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
date=date, possible_names=possible_names
|
date=date, possible_names=possible_names
|
||||||
)
|
)
|
||||||
|
|
||||||
self.download_file(
|
await self.download_file(
|
||||||
InputDocumentFileLocation(
|
InputDocumentFileLocation(
|
||||||
id=document.id,
|
id=document.id,
|
||||||
access_hash=document.access_hash,
|
access_hash=document.access_hash,
|
||||||
|
@ -1825,12 +1820,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
return result
|
return result
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
def download_file(self,
|
async def download_file(self, input_location, file, part_size_kb=None,
|
||||||
input_location,
|
file_size=None, progress_callback=None):
|
||||||
file,
|
|
||||||
part_size_kb=None,
|
|
||||||
file_size=None,
|
|
||||||
progress_callback=None):
|
|
||||||
"""
|
"""
|
||||||
Downloads the given input location to a file.
|
Downloads the given input location to a file.
|
||||||
|
|
||||||
|
@ -1889,23 +1880,24 @@ class TelegramClient(TelegramBareClient):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if cdn_decrypter:
|
if cdn_decrypter:
|
||||||
result = cdn_decrypter.get_file()
|
result = await cdn_decrypter.get_file()
|
||||||
else:
|
else:
|
||||||
result = client(GetFileRequest(
|
result = await client(GetFileRequest(
|
||||||
input_location, offset, part_size
|
input_location, offset, part_size
|
||||||
))
|
))
|
||||||
|
|
||||||
if isinstance(result, FileCdnRedirect):
|
if isinstance(result, FileCdnRedirect):
|
||||||
__log__.info('File lives in a CDN')
|
__log__.info('File lives in a CDN')
|
||||||
cdn_decrypter, result = \
|
cdn_decrypter, result = \
|
||||||
CdnDecrypter.prepare_decrypter(
|
await CdnDecrypter.prepare_decrypter(
|
||||||
client, self._get_cdn_client(result),
|
client,
|
||||||
|
await self._get_cdn_client(result),
|
||||||
result
|
result
|
||||||
)
|
)
|
||||||
|
|
||||||
except FileMigrateError as e:
|
except FileMigrateError as e:
|
||||||
__log__.info('File lives in another DC')
|
__log__.info('File lives in another DC')
|
||||||
client = self._get_exported_client(e.new_dc)
|
client = await self._get_exported_client(e.new_dc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
offset += part_size
|
offset += part_size
|
||||||
|
@ -1947,25 +1939,25 @@ class TelegramClient(TelegramBareClient):
|
||||||
The event builder class or instance to be used,
|
The event builder class or instance to be used,
|
||||||
for instance ``events.NewMessage``.
|
for instance ``events.NewMessage``.
|
||||||
"""
|
"""
|
||||||
def decorator(f):
|
async def decorator(f):
|
||||||
self.add_event_handler(f, event)
|
await self.add_event_handler(f, event)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def _check_events_pending_resolve(self):
|
async def _check_events_pending_resolve(self):
|
||||||
if self._events_pending_resolve:
|
if self._events_pending_resolve:
|
||||||
for event in self._events_pending_resolve:
|
for event in self._events_pending_resolve:
|
||||||
event.resolve(self)
|
await event.resolve(self)
|
||||||
self._events_pending_resolve.clear()
|
self._events_pending_resolve.clear()
|
||||||
|
|
||||||
def _on_handler(self, update):
|
async def _on_handler(self, update):
|
||||||
for builder, callback in self._event_builders:
|
for builder, callback in self._event_builders:
|
||||||
event = builder.build(update)
|
event = builder.build(update)
|
||||||
if event:
|
if event:
|
||||||
event._client = self
|
event._client = self
|
||||||
try:
|
try:
|
||||||
callback(event)
|
await callback(event)
|
||||||
except events.StopPropagation:
|
except events.StopPropagation:
|
||||||
__log__.debug(
|
__log__.debug(
|
||||||
"Event handler '{}' stopped chain of "
|
"Event handler '{}' stopped chain of "
|
||||||
|
@ -1974,7 +1966,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
def add_event_handler(self, callback, event=None):
|
async def add_event_handler(self, callback, event=None):
|
||||||
"""
|
"""
|
||||||
Registers the given callback to be called on the specified event.
|
Registers the given callback to be called on the specified event.
|
||||||
|
|
||||||
|
@ -1989,12 +1981,6 @@ class TelegramClient(TelegramBareClient):
|
||||||
If left unspecified, ``events.Raw`` (the ``Update`` objects
|
If left unspecified, ``events.Raw`` (the ``Update`` objects
|
||||||
with no further processing) will be passed instead.
|
with no further processing) will be passed instead.
|
||||||
"""
|
"""
|
||||||
if self.updates.workers is None:
|
|
||||||
warnings.warn(
|
|
||||||
"You have not setup any workers, so you won't receive updates."
|
|
||||||
" Pass update_workers=1 when creating the TelegramClient,"
|
|
||||||
" or set client.self.updates.workers = 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.updates.handler = self._on_handler
|
self.updates.handler = self._on_handler
|
||||||
if isinstance(event, type):
|
if isinstance(event, type):
|
||||||
|
@ -2003,8 +1989,8 @@ class TelegramClient(TelegramBareClient):
|
||||||
event = events.Raw()
|
event = events.Raw()
|
||||||
|
|
||||||
if self.is_user_authorized():
|
if self.is_user_authorized():
|
||||||
event.resolve(self)
|
await event.resolve(self)
|
||||||
self._check_events_pending_resolve()
|
await self._check_events_pending_resolve()
|
||||||
else:
|
else:
|
||||||
self._events_pending_resolve.append(event)
|
self._events_pending_resolve.append(event)
|
||||||
|
|
||||||
|
@ -2031,11 +2017,11 @@ class TelegramClient(TelegramBareClient):
|
||||||
|
|
||||||
# region Small utilities to make users' life easier
|
# region Small utilities to make users' life easier
|
||||||
|
|
||||||
def _set_connected_and_authorized(self):
|
async def _set_connected_and_authorized(self):
|
||||||
super()._set_connected_and_authorized()
|
await super()._set_connected_and_authorized()
|
||||||
self._check_events_pending_resolve()
|
await self._check_events_pending_resolve()
|
||||||
|
|
||||||
def get_entity(self, entity):
|
async def get_entity(self, entity):
|
||||||
"""
|
"""
|
||||||
Turns the given entity into a valid Telegram user or chat.
|
Turns the given entity into a valid Telegram user or chat.
|
||||||
|
|
||||||
|
@ -2069,7 +2055,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
# input channels (get channels) to get the most entities
|
# input channels (get channels) to get the most entities
|
||||||
# in the less amount of calls possible.
|
# in the less amount of calls possible.
|
||||||
inputs = [
|
inputs = [
|
||||||
x if isinstance(x, str) else self.get_input_entity(x)
|
x if isinstance(x, str) else await self.get_input_entity(x)
|
||||||
for x in entity
|
for x in entity
|
||||||
]
|
]
|
||||||
users = [x for x in inputs if isinstance(x, InputPeerUser)]
|
users = [x for x in inputs if isinstance(x, InputPeerUser)]
|
||||||
|
@ -2080,12 +2066,12 @@ class TelegramClient(TelegramBareClient):
|
||||||
tmp = []
|
tmp = []
|
||||||
while users:
|
while users:
|
||||||
curr, users = users[:200], users[200:]
|
curr, users = users[:200], users[200:]
|
||||||
tmp.extend(self(GetUsersRequest(curr)))
|
tmp.extend(await self(GetUsersRequest(curr)))
|
||||||
users = tmp
|
users = tmp
|
||||||
if chats: # TODO Handle chats slice?
|
if chats: # TODO Handle chats slice?
|
||||||
chats = self(GetChatsRequest(chats)).chats
|
chats = (await self(GetChatsRequest(chats))).chats
|
||||||
if channels:
|
if channels:
|
||||||
channels = self(GetChannelsRequest(channels)).chats
|
channels = (await self(GetChannelsRequest(channels))).chats
|
||||||
|
|
||||||
# Merge users, chats and channels into a single dictionary
|
# Merge users, chats and channels into a single dictionary
|
||||||
id_entity = {
|
id_entity = {
|
||||||
|
@ -2098,33 +2084,31 @@ class TelegramClient(TelegramBareClient):
|
||||||
# the amount of ResolveUsername calls, it would fail to catch
|
# the amount of ResolveUsername calls, it would fail to catch
|
||||||
# username changes.
|
# username changes.
|
||||||
result = [
|
result = [
|
||||||
self._get_entity_from_string(x) if isinstance(x, str)
|
await self._get_entity_from_string(x) if isinstance(x, str)
|
||||||
else id_entity[utils.get_peer_id(x)]
|
else id_entity[utils.get_peer_id(x)]
|
||||||
for x in inputs
|
for x in inputs
|
||||||
]
|
]
|
||||||
return result[0] if single else result
|
return result[0] if single else result
|
||||||
|
|
||||||
def _get_entity_from_string(self, string):
|
async def _get_entity_from_string(self, string):
|
||||||
"""
|
"""
|
||||||
Gets a full entity from the given string, which may be a phone or
|
Gets a full entity from the given string, which may be a phone or
|
||||||
an username, and processes all the found entities on the session.
|
an username, and processes all the found entities on the session.
|
||||||
The string may also be a user link, or a channel/chat invite link.
|
The string may also be a user link, or a channel/chat invite link.
|
||||||
|
|
||||||
This method has the side effect of adding the found users to the
|
This method has the side effect of adding the found users to the
|
||||||
session database, so it can be queried later without API calls,
|
session database, so it can be queried later without API calls,
|
||||||
if this option is enabled on the session.
|
if this option is enabled on the session.
|
||||||
|
|
||||||
Returns the found entity, or raises TypeError if not found.
|
Returns the found entity, or raises TypeError if not found.
|
||||||
"""
|
"""
|
||||||
phone = utils.parse_phone(string)
|
phone = utils.parse_phone(string)
|
||||||
if phone:
|
if phone:
|
||||||
for user in self(GetContactsRequest(0)).users:
|
for user in (await self(GetContactsRequest(0))).users:
|
||||||
if user.phone == phone:
|
if user.phone == phone:
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
username, is_join_chat = utils.parse_username(string)
|
username, is_join_chat = utils.parse_username(string)
|
||||||
if is_join_chat:
|
if is_join_chat:
|
||||||
invite = self(CheckChatInviteRequest(username))
|
invite = await self(CheckChatInviteRequest(username))
|
||||||
if isinstance(invite, ChatInvite):
|
if isinstance(invite, ChatInvite):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Cannot get entity from a channel '
|
'Cannot get entity from a channel '
|
||||||
|
@ -2134,14 +2118,15 @@ class TelegramClient(TelegramBareClient):
|
||||||
return invite.chat
|
return invite.chat
|
||||||
elif username:
|
elif username:
|
||||||
if username in ('me', 'self'):
|
if username in ('me', 'self'):
|
||||||
return self.get_me()
|
return await self.get_me()
|
||||||
result = self(ResolveUsernameRequest(username))
|
result = await self(ResolveUsernameRequest(username))
|
||||||
for entity in itertools.chain(result.users, result.chats):
|
for entity in itertools.chain(result.users, result.chats):
|
||||||
if entity.username.lower() == username:
|
if entity.username.lower() == username:
|
||||||
return entity
|
return entity
|
||||||
try:
|
try:
|
||||||
# Nobody with this username, maybe it's an exact name/title
|
# Nobody with this username, maybe it's an exact name/title
|
||||||
return self.get_entity(self.session.get_input_entity(string))
|
return await self.get_entity(
|
||||||
|
self.session.get_input_entity(string))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -2149,24 +2134,20 @@ class TelegramClient(TelegramBareClient):
|
||||||
'Cannot turn "{}" into any entity (user or chat)'.format(string)
|
'Cannot turn "{}" into any entity (user or chat)'.format(string)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_entity(self, peer):
|
async def get_input_entity(self, peer):
|
||||||
"""
|
"""
|
||||||
Turns the given peer into its input entity version. Most requests
|
Turns the given peer into its input entity version. Most requests
|
||||||
use this kind of InputUser, InputChat and so on, so this is the
|
use this kind of InputUser, InputChat and so on, so this is the
|
||||||
most suitable call to make for those cases.
|
most suitable call to make for those cases.
|
||||||
|
|
||||||
entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`):
|
entity (:obj:`str` | :obj:`int` | :obj:`Peer` | :obj:`InputPeer`):
|
||||||
The integer ID of an user or otherwise either of a
|
The integer ID of an user or otherwise either of a
|
||||||
``PeerUser``, ``PeerChat`` or ``PeerChannel``, for
|
``PeerUser``, ``PeerChat`` or ``PeerChannel``, for
|
||||||
which to get its ``Input*`` version.
|
which to get its ``Input*`` version.
|
||||||
|
|
||||||
If this ``Peer`` hasn't been seen before by the library, the top
|
If this ``Peer`` hasn't been seen before by the library, the top
|
||||||
dialogs will be loaded and their entities saved to the session
|
dialogs will be loaded and their entities saved to the session
|
||||||
file (unless this feature was disabled explicitly).
|
file (unless this feature was disabled explicitly).
|
||||||
|
|
||||||
If in the end the access hash required for the peer was not found,
|
If in the end the access hash required for the peer was not found,
|
||||||
a ValueError will be raised.
|
a ValueError will be raised.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``.
|
``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``.
|
||||||
"""
|
"""
|
||||||
|
@ -2179,7 +2160,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
if isinstance(peer, str):
|
if isinstance(peer, str):
|
||||||
if peer in ('me', 'self'):
|
if peer in ('me', 'self'):
|
||||||
return InputPeerSelf()
|
return InputPeerSelf()
|
||||||
return utils.get_input_peer(self._get_entity_from_string(peer))
|
return utils.get_input_peer(await self._get_entity_from_string(peer))
|
||||||
|
|
||||||
if isinstance(peer, int):
|
if isinstance(peer, int):
|
||||||
peer, kind = utils.resolve_id(peer)
|
peer, kind = utils.resolve_id(peer)
|
||||||
|
@ -2206,7 +2187,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
limit=100
|
limit=100
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
result = self(req)
|
result = await self(req)
|
||||||
entities = {}
|
entities = {}
|
||||||
for x in itertools.chain(result.users, result.chats):
|
for x in itertools.chain(result.users, result.chats):
|
||||||
x_id = utils.get_peer_id(x)
|
x_id = utils.get_peer_id(x)
|
||||||
|
@ -2222,7 +2203,7 @@ class TelegramClient(TelegramBareClient):
|
||||||
req.offset_peer = entities[utils.get_peer_id(
|
req.offset_peer = entities[utils.get_peer_id(
|
||||||
result.dialogs[-1].peer
|
result.dialogs[-1].peer
|
||||||
)]
|
)]
|
||||||
time.sleep(1)
|
asyncio.sleep(1)
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'Could not find the input entity corresponding to "{}". '
|
'Could not find the input entity corresponding to "{}". '
|
||||||
|
|
|
@ -26,9 +26,9 @@ class Dialog:
|
||||||
|
|
||||||
self.draft = Draft(client, dialog.peer, dialog.draft)
|
self.draft = Draft(client, dialog.peer, dialog.draft)
|
||||||
|
|
||||||
def send_message(self, *args, **kwargs):
|
async def send_message(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Sends a message to this dialog. This is just a wrapper around
|
Sends a message to this dialog. This is just a wrapper around
|
||||||
client.send_message(dialog.input_entity, *args, **kwargs).
|
client.send_message(dialog.input_entity, *args, **kwargs).
|
||||||
"""
|
"""
|
||||||
return self._client.send_message(self.input_entity, *args, **kwargs)
|
return await self._client.send_message(self.input_entity, *args, **kwargs)
|
||||||
|
|
|
@ -31,14 +31,14 @@ class Draft:
|
||||||
return cls(client=client, peer=update.peer, draft=update.draft)
|
return cls(client=client, peer=update.peer, draft=update.draft)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity(self):
|
async def entity(self):
|
||||||
return self._client.get_entity(self._peer)
|
return await self._client.get_entity(self._peer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_entity(self):
|
async def input_entity(self):
|
||||||
return self._client.get_input_entity(self._peer)
|
return await self._client.get_input_entity(self._peer)
|
||||||
|
|
||||||
def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
|
async def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
|
||||||
"""
|
"""
|
||||||
Changes the draft message on the Telegram servers. The changes are
|
Changes the draft message on the Telegram servers. The changes are
|
||||||
reflected in this object. Changing only individual attributes like for
|
reflected in this object. Changing only individual attributes like for
|
||||||
|
@ -58,7 +58,7 @@ class Draft:
|
||||||
:param list entities: A list of formatting entities
|
:param list entities: A list of formatting entities
|
||||||
:return bool: ``True`` on success
|
:return bool: ``True`` on success
|
||||||
"""
|
"""
|
||||||
result = self._client(SaveDraftRequest(
|
result = await self._client(SaveDraftRequest(
|
||||||
peer=self._peer,
|
peer=self._peer,
|
||||||
message=text,
|
message=text,
|
||||||
no_webpage=no_webpage,
|
no_webpage=no_webpage,
|
||||||
|
@ -74,9 +74,9 @@ class Draft:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def delete(self):
|
async def delete(self):
|
||||||
"""
|
"""
|
||||||
Deletes this draft
|
Deletes this draft
|
||||||
:return bool: ``True`` on success
|
:return bool: ``True`` on success
|
||||||
"""
|
"""
|
||||||
return self.set_message(text='')
|
return await self.set_message(text='')
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import struct
|
import struct
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
from threading import Event
|
|
||||||
|
|
||||||
|
|
||||||
class TLObject:
|
class TLObject:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.confirm_received = Event()
|
self.confirm_received = None
|
||||||
self.rpc_error = None
|
self.rpc_error = None
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
|
@ -157,7 +156,7 @@ class TLObject:
|
||||||
return TLObject.pretty_format(self, indent=0)
|
return TLObject.pretty_format(self, indent=0)
|
||||||
|
|
||||||
# These should be overrode
|
# These should be overrode
|
||||||
def resolve(self, client, utils):
|
async def resolve(self, client, utils):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from queue import Queue, Empty
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from threading import RLock, Thread
|
|
||||||
|
|
||||||
from .tl import types as tl
|
from .tl import types as tl
|
||||||
|
|
||||||
|
@ -16,125 +15,40 @@ class UpdateState:
|
||||||
"""
|
"""
|
||||||
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
|
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
|
||||||
|
|
||||||
def __init__(self, workers=None):
|
def __init__(self, loop=None):
|
||||||
"""
|
|
||||||
:param workers: This integer parameter has three possible cases:
|
|
||||||
workers is None: Updates will *not* be stored on self.
|
|
||||||
workers = 0: Another thread is responsible for calling self.poll()
|
|
||||||
workers > 0: 'workers' background threads will be spawned, any
|
|
||||||
any of them will invoke the self.handler.
|
|
||||||
"""
|
|
||||||
self._workers = workers
|
|
||||||
self._worker_threads = []
|
|
||||||
|
|
||||||
self.handler = None
|
self.handler = None
|
||||||
self._updates_lock = RLock()
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
self._updates = Queue()
|
|
||||||
|
|
||||||
# https://core.telegram.org/api/updates
|
# https://core.telegram.org/api/updates
|
||||||
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
|
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
|
||||||
|
|
||||||
def can_poll(self):
|
def handle_update(self, update):
|
||||||
"""Returns True if a call to .poll() won't lock"""
|
if self.handler:
|
||||||
return not self._updates.empty()
|
asyncio.ensure_future(self.handler(update), loop=self._loop)
|
||||||
|
|
||||||
def poll(self, timeout=None):
|
|
||||||
"""
|
|
||||||
Polls an update or blocks until an update object is available.
|
|
||||||
If 'timeout is not None', it should be a floating point value,
|
|
||||||
and the method will 'return None' if waiting times out.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self._updates.get(timeout=timeout)
|
|
||||||
except Empty:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_workers(self):
|
|
||||||
return self._workers
|
|
||||||
|
|
||||||
def set_workers(self, n):
|
|
||||||
"""Changes the number of workers running.
|
|
||||||
If 'n is None', clears all pending updates from memory.
|
|
||||||
"""
|
|
||||||
if n is None:
|
|
||||||
self.stop_workers()
|
|
||||||
else:
|
|
||||||
self._workers = n
|
|
||||||
self.setup_workers()
|
|
||||||
|
|
||||||
workers = property(fget=get_workers, fset=set_workers)
|
|
||||||
|
|
||||||
def stop_workers(self):
|
|
||||||
"""
|
|
||||||
Waits for all the worker threads to stop.
|
|
||||||
"""
|
|
||||||
# Put dummy ``None`` objects so that they don't need to timeout.
|
|
||||||
n = self._workers
|
|
||||||
self._workers = None
|
|
||||||
if n:
|
|
||||||
with self._updates_lock:
|
|
||||||
for _ in range(n):
|
|
||||||
self._updates.put(None)
|
|
||||||
|
|
||||||
for t in self._worker_threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
self._worker_threads.clear()
|
|
||||||
|
|
||||||
def setup_workers(self):
|
|
||||||
if self._worker_threads or not self._workers:
|
|
||||||
# There already are workers, or workers is None or 0. Do nothing.
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(self._workers):
|
|
||||||
thread = Thread(
|
|
||||||
target=UpdateState._worker_loop,
|
|
||||||
name='UpdateWorker{}'.format(i),
|
|
||||||
daemon=True,
|
|
||||||
args=(self, i)
|
|
||||||
)
|
|
||||||
self._worker_threads.append(thread)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
def _worker_loop(self, wid):
|
|
||||||
while self._workers is not None:
|
|
||||||
try:
|
|
||||||
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
|
|
||||||
if update and self.handler:
|
|
||||||
self.handler(update)
|
|
||||||
except StopIteration:
|
|
||||||
break
|
|
||||||
except:
|
|
||||||
# We don't want to crash a worker thread due to any reason
|
|
||||||
__log__.exception('Unhandled exception on worker %d', wid)
|
|
||||||
|
|
||||||
def process(self, update):
|
def process(self, update):
|
||||||
"""Processes an update object. This method is normally called by
|
"""Processes an update object. This method is normally called by
|
||||||
the library itself.
|
the library itself.
|
||||||
"""
|
"""
|
||||||
if self._workers is None:
|
if isinstance(update, tl.updates.State):
|
||||||
return # No processing needs to be done if nobody's working
|
__log__.debug('Saved new updates state')
|
||||||
|
self._state = update
|
||||||
|
return # Nothing else to be done
|
||||||
|
|
||||||
with self._updates_lock:
|
if hasattr(update, 'pts'):
|
||||||
if isinstance(update, tl.updates.State):
|
self._state.pts = update.pts
|
||||||
__log__.debug('Saved new updates state')
|
|
||||||
self._state = update
|
|
||||||
return # Nothing else to be done
|
|
||||||
|
|
||||||
if hasattr(update, 'pts'):
|
# After running the script for over an hour and receiving over
|
||||||
self._state.pts = update.pts
|
# 1000 updates, the only duplicates received were users going
|
||||||
|
# online or offline. We can trust the server until new reports.
|
||||||
# After running the script for over an hour and receiving over
|
if isinstance(update, tl.UpdateShort):
|
||||||
# 1000 updates, the only duplicates received were users going
|
self.handle_update(update.update)
|
||||||
# online or offline. We can trust the server until new reports.
|
# Expand "Updates" into "Update", and pass these to callbacks.
|
||||||
if isinstance(update, tl.UpdateShort):
|
# Since .users and .chats have already been processed, we
|
||||||
self._updates.put(update.update)
|
# don't need to care about those either.
|
||||||
# Expand "Updates" into "Update", and pass these to callbacks.
|
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
||||||
# Since .users and .chats have already been processed, we
|
for u in update.updates:
|
||||||
# don't need to care about those either.
|
self.handle_update(u)
|
||||||
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
# TODO Handle "tl.UpdatesTooLong"
|
||||||
for u in update.updates:
|
else:
|
||||||
self._updates.put(u)
|
self.handle_update(update)
|
||||||
# TODO Handle "tl.UpdatesTooLong"
|
|
||||||
else:
|
|
||||||
self._updates.put(update)
|
|
||||||
|
|
|
@ -11,9 +11,9 @@ AUTO_GEN_NOTICE = \
|
||||||
|
|
||||||
|
|
||||||
AUTO_CASTS = {
|
AUTO_CASTS = {
|
||||||
'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))',
|
'InputPeer': 'utils.get_input_peer(await client.get_input_entity({}))',
|
||||||
'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))',
|
'InputChannel': 'utils.get_input_channel(await client.get_input_entity({}))',
|
||||||
'InputUser': 'utils.get_input_user(client.get_input_entity({}))',
|
'InputUser': 'utils.get_input_user(await client.get_input_entity({}))',
|
||||||
'InputMedia': 'utils.get_input_media({})',
|
'InputMedia': 'utils.get_input_media({})',
|
||||||
'InputPhoto': 'utils.get_input_photo({})'
|
'InputPhoto': 'utils.get_input_photo({})'
|
||||||
}
|
}
|
||||||
|
@ -289,7 +289,7 @@ class TLGenerator:
|
||||||
|
|
||||||
# Write the resolve(self, client, utils) method
|
# Write the resolve(self, client, utils) method
|
||||||
if any(arg.type in AUTO_CASTS for arg in args):
|
if any(arg.type in AUTO_CASTS for arg in args):
|
||||||
builder.writeln('def resolve(self, client, utils):')
|
builder.writeln('async def resolve(self, client, utils):')
|
||||||
for arg in args:
|
for arg in args:
|
||||||
ac = AUTO_CASTS.get(arg.type, None)
|
ac = AUTO_CASTS.get(arg.type, None)
|
||||||
if ac:
|
if ac:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user